44import time
55from collections .abc import Callable
66from contextlib import AbstractContextManager
7+ from dataclasses import dataclass
78from pathlib import Path
89from uuid import UUID
910
1415
1516from app .core .run import ExecutionContext
1617from app .models import DatasetItemAnnotationStatus
17- from app .services import BaseWeightsService , DatasetService
18+ from app .schemas .model import TrainingStatus
19+ from app .schemas .project import TaskBase
20+ from app .services import BaseWeightsService , DatasetService , ModelRevisionMetadata , ModelService
1821
1922from .base import Trainer , step
23+ from .models import TrainingParams
2024from .subset_assignment import SplitRatios , SubsetAssigner , SubsetService
2125
2226MODEL_WEIGHTS_PATH = "model_weights_path"
2327
2428
29+ @dataclass (frozen = True )
30+ class DatasetInfo :
31+ training : Dataset
32+ validation : Dataset
33+ testing : Dataset
34+ revision_id : UUID
35+
36+
2537class OTXTrainer (Trainer ):
2638 """OTX-specific trainer implementation."""
2739
@@ -31,6 +43,7 @@ def __init__(
3143 base_weights_service : BaseWeightsService ,
3244 subset_service : SubsetService ,
3345 dataset_service : DatasetService ,
46+ model_service : ModelService ,
3447 subset_assigner : SubsetAssigner ,
3548 db_session_factory : Callable [[], AbstractContextManager [Session ]],
3649 ):
@@ -39,26 +52,22 @@ def __init__(
3952 self ._base_weights_service = base_weights_service
4053 self ._subset_service = subset_service
4154 self ._dataset_service = dataset_service
55+ self ._model_service = model_service
4256 self ._subset_assigner = subset_assigner
4357 self ._db_session_factory = db_session_factory
44- self ._training_dataset : Dataset | None = None
45- self ._validation_dataset : Dataset | None = None
46- self ._testing_dataset : Dataset | None = None
4758
4859 @step ("Prepare Model Weights" )
49- def prepare_weights (self ) -> Path :
60+ def prepare_weights (self , training_params : TrainingParams ) -> Path :
5061 """
5162 Prepare weights for training based on training parameters.
5263
5364 If a parent model revision ID is provided, it fetches the weights from the parent model.
5465 Otherwise, it retrieves the base weights for the specified model architecture.
5566 """
56- if self ._training_params is None :
57- raise ValueError ("Training parameters not set" )
58- parent_model_revision_id = self ._training_params .parent_model_revision_id
59- task = self ._training_params .task
60- model_architecture_id = self ._training_params .model_architecture_id
61- project_id = self ._training_params .project_id
67+ parent_model_revision_id = training_params .parent_model_revision_id
68+ task = training_params .task
69+ model_architecture_id = training_params .model_architecture_id
70+ project_id = training_params .project_id
6271 if parent_model_revision_id is None :
6372 return self ._base_weights_service .get_local_weights_path (
6473 task = task .task_type , model_manifest_id = model_architecture_id
@@ -74,17 +83,11 @@ def prepare_weights(self) -> Path:
7483 return weights_path
7584
7685 @step ("Assign Dataset Subsets" )
77- def assign_subsets (self ) -> None :
86+ def assign_subsets (self , project_id : UUID ) -> None :
7887 """Assigning subsets to all unassigned dataset items in the project dataset."""
79- if self ._training_params is None :
80- raise ValueError ("Training parameters not set" )
81- project_id = self ._training_params .project_id
82- self .report_progress ("Retrieving unassigned items" )
83- if project_id is None :
84- raise ValueError ("Project ID must be provided for subset assignment" )
85-
8688 with self ._db_session_factory () as db :
8789 self ._subset_service .set_db_session (db )
90+ self .report_progress ("Retrieving unassigned items" )
8891 unassigned_items = self ._subset_service .get_unassigned_items_with_labels (project_id )
8992
9093 if not unassigned_items :
@@ -112,30 +115,41 @@ def assign_subsets(self) -> None:
112115 self .report_progress (f"Successfully assigned { len (assignments )} items to subsets" )
113116
114117 @step ("Create Training Dataset" )
115- def create_training_dataset (self ) -> None :
118+ def create_training_dataset (self , project_id : UUID , task : TaskBase ) -> DatasetInfo :
116119 """Create datasets for training, validation, and testing."""
117- if self ._training_params is None :
118- raise ValueError ("Training parameters not set" )
119- project_id = self ._training_params .project_id
120- if project_id is None :
121- raise ValueError ("Project ID must be provided" )
122- task = self ._training_params .task
123-
124120 with self ._db_session_factory () as db :
125121 self ._dataset_service .set_db_session (db )
126122 dm_dataset = self ._dataset_service .get_dm_dataset (project_id , task , DatasetItemAnnotationStatus .REVIEWED )
127- self ._training_dataset = dm_dataset .filter_by_subset (Subset .TRAINING )
128- self ._validation_dataset = dm_dataset .filter_by_subset (Subset .VALIDATION )
129- self ._testing_dataset = dm_dataset .filter_by_subset (Subset .TESTING )
130- self ._dataset_service .save_revision (project_id , dm_dataset )
123+ return DatasetInfo (
124+ training = dm_dataset .filter_by_subset (Subset .TRAINING ),
125+ validation = dm_dataset .filter_by_subset (Subset .VALIDATION ),
126+ testing = dm_dataset .filter_by_subset (Subset .TESTING ),
127+ revision_id = self ._dataset_service .save_revision (project_id , dm_dataset ),
128+ )
129+
130+ @step ("Prepare Model Metadata" )
131+ def prepare_model (self , training_params : TrainingParams , dataset_revision_id : UUID ) -> None :
132+ if training_params .project_id is None :
133+ raise ValueError ("Project ID must be provided for model preparation" )
134+ with self ._db_session_factory () as db :
135+ self ._model_service .set_db_session (db )
136+ self ._model_service .create_revision (
137+ ModelRevisionMetadata (
138+ model_id = training_params .model_id ,
139+ project_id = training_params .project_id ,
140+ architecture_id = training_params .model_architecture_id ,
141+ parent_revision_id = training_params .parent_model_revision_id ,
142+ training_configuration = None , # TODO: to be set when config is added
143+ dataset_revision_id = dataset_revision_id ,
144+ training_status = TrainingStatus .NOT_STARTED ,
145+ )
146+ )
131147
132148 @step ("Train Model with OTX" )
133- def train_model (self ) -> None :
149+ def train_model (self , training_params : TrainingParams ) -> None :
134150 """Execute OTX model training."""
135- if self ._training_params is None :
136- raise ValueError ("Training parameters not set" )
137151 # Simulate training with progress reporting
138- job_id = self . _training_params .job_id
152+ job_id = training_params .job_id
139153 step_count = 20
140154 for i in range (step_count ):
141155 time .sleep (1 )
@@ -145,12 +159,17 @@ def train_model(self) -> None:
145159
146160 def run (self , ctx : ExecutionContext ) -> None :
147161 self ._ctx = ctx
148- self ._training_params = self ._get_training_params (ctx )
149-
150- self .prepare_weights ()
151- self .assign_subsets ()
152- self .create_training_dataset ()
153- self .train_model ()
162+ training_params = self ._get_training_params (ctx )
163+ project_id = training_params .project_id
164+ if project_id is None :
165+ raise ValueError ("Project ID must be provided in training parameters" )
166+ task = training_params .task
167+
168+ self .prepare_weights (training_params )
169+ self .assign_subsets (project_id )
170+ dataset_info = self .create_training_dataset (project_id , task )
171+ self .prepare_model (training_params , dataset_info .revision_id )
172+ self .train_model (training_params )
154173
155174 @staticmethod
156175 def __build_model_weights_path (data_dir : Path , project_id : UUID , model_id : UUID ) -> Path :
0 commit comments