from pathlib import Path
from uuid import UUID

import yaml
from datumaro.experimental import Dataset
from datumaro.experimental.fields import Subset
from loguru import logger
from otx.types.task import OTXTaskType
from sqlalchemy.orm import Session

from app.core.jobs import JobType
from app.core.run import ExecutionContext
from app.models import DatasetItemAnnotationStatus, TaskType
from app.models.training_configuration.configuration import TrainingConfiguration
from app.schemas.model import TrainingStatus
from app.schemas.project import TaskBase
from app.services import (
    BaseWeightsService,
    DatasetService,
    ModelRevisionMetadata,
    ModelService,
    TrainingConfigurationService,
)

from .base import Trainer, step
from .models import TrainingParams
# Context key under which the prepared model-weights path is passed between pipeline steps.
MODEL_WEIGHTS_PATH = "model_weights_path"
2737
2838
# TODO: Consider adopting a lightweight DI framework.
# As the number of constructor dependencies grows and starts violating ruff rules,
# we should evaluate DI frameworks like:
# - dependency-injector (https://python-dependency-injector.ets-labs.org/)
# - injector (https://github.com/python-injector/injector)
# - python-inject (https://github.com/ivankorobkov/python-inject)
@dataclass(frozen=True)
class TrainingDependencies:
    """Immutable bundle of the collaborators an OTXTrainer needs.

    Groups the constructor dependencies into a single value object so the
    trainer's signature stays small as new services are added.
    """

    # Root directory under which per-project model artifacts are stored.
    data_dir: Path
    base_weights_service: BaseWeightsService
    subset_service: SubsetService
    dataset_service: DatasetService
    model_service: ModelService
    training_configuration_service: TrainingConfigurationService
    subset_assigner: SubsetAssigner
    # Factory producing a context-managed DB session per unit of work.
    db_session_factory: Callable[[], AbstractContextManager[Session]]
55+
2956@dataclass (frozen = True )
3057class DatasetInfo :
3158 training : Dataset
@@ -39,22 +66,17 @@ class OTXTrainer(Trainer):
3966
4067 def __init__ (
4168 self ,
42- data_dir : Path ,
43- base_weights_service : BaseWeightsService ,
44- subset_service : SubsetService ,
45- dataset_service : DatasetService ,
46- model_service : ModelService ,
47- subset_assigner : SubsetAssigner ,
48- db_session_factory : Callable [[], AbstractContextManager [Session ]],
69+ training_deps : TrainingDependencies ,
4970 ):
5071 super ().__init__ ()
51- self ._data_dir = data_dir
52- self ._base_weights_service = base_weights_service
53- self ._subset_service = subset_service
54- self ._dataset_service = dataset_service
55- self ._model_service = model_service
56- self ._subset_assigner = subset_assigner
57- self ._db_session_factory = db_session_factory
72+ self ._data_dir = training_deps .data_dir
73+ self ._base_weights_service = training_deps .base_weights_service
74+ self ._subset_service = training_deps .subset_service
75+ self ._dataset_service = training_deps .dataset_service
76+ self ._model_service = training_deps .model_service
77+ self ._training_configuration_service = training_deps .training_configuration_service
78+ self ._subset_assigner = training_deps .subset_assigner
79+ self ._db_session_factory = training_deps .db_session_factory
5880
5981 @step ("Prepare Model Weights" )
6082 def prepare_weights (self , training_params : TrainingParams ) -> Path :
@@ -127,19 +149,28 @@ def create_training_dataset(self, project_id: UUID, task: TaskBase) -> DatasetIn
127149 revision_id = self ._dataset_service .save_revision (project_id , dm_dataset ),
128150 )
129151
130- @step ("Prepare Model Metadata " )
152+ @step ("Prepare Model and Training Configuration " )
131153 def prepare_model (self , training_params : TrainingParams , dataset_revision_id : UUID ) -> None :
132154 if training_params .project_id is None :
133155 raise ValueError ("Project ID must be provided for model preparation" )
134156 with self ._db_session_factory () as db :
157+ self ._training_configuration_service .set_db_session (db )
135158 self ._model_service .set_db_session (db )
159+ configuration = self ._training_configuration_service .get_training_configuration (
160+ project_id = training_params .project_id ,
161+ model_architecture_id = training_params .model_architecture_id ,
162+ )
163+ config_path = self .__build_model_config_path (
164+ self ._data_dir , training_params .project_id , training_params .model_id
165+ )
166+ self .__persist_configuration (configuration , config_path , training_params .task )
136167 self ._model_service .create_revision (
137168 ModelRevisionMetadata (
138169 model_id = training_params .model_id ,
139170 project_id = training_params .project_id ,
140171 architecture_id = training_params .model_architecture_id ,
141172 parent_revision_id = training_params .parent_model_revision_id ,
142- training_configuration = None , # TODO: to be set when config is added
173+ training_configuration = configuration ,
143174 dataset_revision_id = dataset_revision_id ,
144175 training_status = TrainingStatus .NOT_STARTED ,
145176 )
@@ -172,5 +203,31 @@ def run(self, ctx: ExecutionContext) -> None:
172203 self .train_model (training_params )
173204
174205 @staticmethod
175- def __build_model_weights_path (data_dir : Path , project_id : UUID , model_id : UUID ) -> Path :
176- return data_dir / "projects" / str (project_id ) / "models" / str (model_id ) / "model.pth"
206+ def __base_model_path (data_dir : Path , project_id : UUID , model_id : UUID ) -> Path :
207+ return data_dir / "projects" / str (project_id ) / "models" / str (model_id )
208+
209+ @classmethod
210+ def __build_model_weights_path (cls , data_dir : Path , project_id : UUID , model_id : UUID ) -> Path :
211+ return cls .__base_model_path (data_dir , project_id , model_id ) / "model.pth"
212+
213+ @classmethod
214+ def __build_model_config_path (cls , data_dir : Path , project_id : UUID , model_id : UUID ) -> Path :
215+ return cls .__base_model_path (data_dir , project_id , model_id ) / "config.yaml"
216+
217+ @staticmethod
218+ def __persist_configuration (configuration : TrainingConfiguration , config_path : Path , task : TaskBase ) -> None :
219+ extended_config = configuration .model_dump (exclude_none = True )
220+ extended_config ["job_type" ] = JobType .TRAIN .value
221+ match task .task_type :
222+ case TaskType .CLASSIFICATION :
223+ if task .exclusive_labels :
224+ extended_config ["sub_task_type" ] = OTXTaskType .MULTI_CLASS_CLS .value
225+ else :
226+ extended_config ["sub_task_type" ] = OTXTaskType .MULTI_LABEL_CLS .value
227+ case TaskType .DETECTION :
228+ extended_config ["sub_task_type" ] = OTXTaskType .DETECTION .value
229+ case TaskType .INSTANCE_SEGMENTATION :
230+ extended_config ["sub_task_type" ] = OTXTaskType .INSTANCE_SEGMENTATION .value
231+ config_path .parent .mkdir (parents = True , exist_ok = True )
232+ with open (config_path , "w" ) as f :
233+ yaml .dump (extended_config , f , default_flow_style = False )