@@ -12,6 +12,7 @@
 import logging
 import warnings
 from abc import abstractmethod
+from dataclasses import asdict
 from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence
 
 import numpy as np
@@ -106,13 +107,13 @@ class OTXModel(LightningModule):
 
     def __init__(
        self,
-        label_info: LabelInfoTypes,
+        label_info: LabelInfoTypes | dict,
         input_size: tuple[int, int] | None = None,
         optimizer: OptimizerCallable = DefaultOptimizerCallable,
         scheduler: LRSchedulerCallable | LRSchedulerListCallable = DefaultSchedulerCallable,
         metric: MetricCallable = NullMetricCallable,
         torch_compile: bool = False,
-        tile_config: TileConfig = TileConfig(enable_tiler=False),
+        tile_config: TileConfig | dict = TileConfig(enable_tiler=False),
     ) -> None:
         super().__init__()
 
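The widened annotations let callers, and values deserialized from a checkpoint, pass plain dicts where the OTX dataclasses were previously required. A minimal sketch of the idea, using a simplified stand-in for `TileConfig` (the real OTX dataclass carries more fields; only `**dict` construction and `.clone()` are relied on here):

```python
from dataclasses import asdict, dataclass, replace

# Simplified stand-in for illustration; not the real OTX TileConfig.
@dataclass
class TileConfig:
    enable_tiler: bool = False

    def clone(self) -> "TileConfig":
        # Shallow copy, mirroring the clone() call used in __init__.
        return replace(self)

# Both forms are now accepted by the widened signature:
from_obj = TileConfig(enable_tiler=True)
from_dict = TileConfig(**{"enable_tiler": True})  # the new dict path
assert asdict(from_obj) == asdict(from_dict)
```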
@@ -121,21 +122,21 @@ def __init__(
         self.input_size = input_size
         self.classification_layers: dict[str, dict[str, Any]] = {}
         self.model = self._create_model()
-        self.optimizer_callable = ensure_callable(optimizer)
-        self.scheduler_callable = ensure_callable(scheduler)
-        self.metric_callable = ensure_callable(metric)
+        self.optimizer_callable: OptimizerCallable = ensure_callable(optimizer)
+        self.scheduler_callable: LRSchedulerCallable = ensure_callable(scheduler)
+        self.metric_callable: MetricCallable = ensure_callable(metric)
 
         self.torch_compile = torch_compile
         self._explain_mode = False
 
         # NOTE: To guarantee immutability of the default value
+        if isinstance(tile_config, dict):
+            tile_config = TileConfig(**tile_config)
         self._tile_config = tile_config.clone()
-
-        # this line allows to access init params with 'self.hparams' attribute
-        # also ensures init params will be stored in ckpt
-        # TODO(vinnamki): Ticket no. 138995: MetricCallable should be saved in the checkpoint
-        # so that it can retrieve it from the checkpoint
-        self.save_hyperparameters(logger=False, ignore=["optimizer", "scheduler", "metric"])
+        self.save_hyperparameters(
+            logger=False,
+            ignore=["optimizer", "scheduler", "metric", "label_info", "tile_config"],
+        )
 
     def training_step(self, batch: T_OTXBatchDataEntity, batch_idx: int) -> Tensor | None:
         """Step for model training."""
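Adding `label_info` and `tile_config` to the `ignore` list keeps Lightning from pickling OTX objects into `checkpoint["hyper_parameters"]`; `on_save_checkpoint` below re-adds them as plain dicts. A toy sketch of the `ignore` mechanism (`ToyModule` is hypothetical, assuming the lightning 2.x package layout):

```python
import lightning.pytorch as pl  # assumed install; OTX itself depends on lightning

class ToyModule(pl.LightningModule):  # hypothetical module, for illustration only
    def __init__(self, lr: float = 1e-3, tile_config: dict | None = None):
        super().__init__()
        # Ignored init args never reach self.hparams, and therefore never
        # reach checkpoint["hyper_parameters"]; they must be re-added by hand.
        self.save_hyperparameters(logger=False, ignore=["tile_config"])

module = ToyModule(lr=1e-2, tile_config={"enable_tiler": False})
print(dict(module.hparams))  # {'lr': 0.01} -- tile_config was excluded
```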
@@ -376,38 +377,42 @@ def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
             compiled_state_dict = checkpoint["state_dict"]
             checkpoint["state_dict"] = remove_state_dict_prefix(compiled_state_dict, "_orig_mod.")
         super().on_save_checkpoint(checkpoint)
-
-        checkpoint["label_info"] = self.label_info
+        checkpoint["hyper_parameters"]["label_info"] = asdict(self.label_info)
         checkpoint["otx_version"] = __version__
-        checkpoint["tile_config"] = self.tile_config
+        checkpoint["hyper_parameters"]["tile_config"] = asdict(self.tile_config)
+        checkpoint.pop("datamodule_hparams_name", None)
+        checkpoint.pop(
+            "datamodule_hyper_parameters",
+            None,
+        )  # Remove datamodule_hyper_parameters to prevent storing OTX classes
 
     def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
         """Callback on loading checkpoint."""
         super().on_load_checkpoint(checkpoint)
-
-        if ckpt_label_info := checkpoint.get("label_info"):
-            if isinstance(ckpt_label_info, LabelInfo) and not hasattr(ckpt_label_info, "label_ids"):
-                # NOTE: This is for backward compatibility
-                ckpt_label_info = LabelInfo(
-                    label_groups=ckpt_label_info.label_groups,
-                    label_names=ckpt_label_info.label_names,
-                    label_ids=ckpt_label_info.label_names,
-                )
-            self._label_info = ckpt_label_info
-
-        if ckpt_tile_config := checkpoint.get("tile_config"):
-            self.tile_config = ckpt_tile_config
+        hyper_parameters = checkpoint.get("hyper_parameters", None)
+        if hyper_parameters:
+            if ckpt_label_info := hyper_parameters.get("label_info"):
+                self._label_info = self._dispatch_label_info(ckpt_label_info)
+            if ckpt_tile_config := hyper_parameters.get("tile_config"):
+                if isinstance(ckpt_tile_config, dict):
+                    ckpt_tile_config = TileConfig(**ckpt_tile_config)
+                self.tile_config = ckpt_tile_config
 
     def load_state_dict_incrementally(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
         """Load state dict incrementally."""
         ckpt_label_info: LabelInfo | None = (
-            ckpt.get("label_info") if not is_ckpt_from_otx_v1(ckpt) else self.get_ckpt_label_info_v1(ckpt)
+            ckpt.get("hyper_parameters", {}).get("label_info")
+            if not is_ckpt_from_otx_v1(ckpt)
+            else self.get_ckpt_label_info_v1(ckpt)
         )
 
         if ckpt_label_info is None:
             msg = "Checkpoint should have `label_info`."
             raise ValueError(msg, ckpt_label_info)
 
+        if isinstance(ckpt_label_info, dict):
+            ckpt_label_info = LabelInfo(**ckpt_label_info)
+
         if not hasattr(ckpt_label_info, "label_ids"):
             msg = "Loading checkpoint from OTX < 2.2.1, label_ids are assigned automatically"
             logger.info(msg)
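Taken together, the save/load changes mean `label_info` and `tile_config` travel through the checkpoint as plain dicts instead of pickled OTX objects. A minimal round-trip sketch, with a reduced `LabelInfo` stand-in limited to the three fields the diff touches:

```python
from dataclasses import asdict, dataclass

@dataclass
class LabelInfo:  # reduced stand-in, not the real OTX class
    label_names: list[str]
    label_groups: list[list[str]]
    label_ids: list[str]

info = LabelInfo(["cat", "dog"], [["cat", "dog"]], ["0", "1"])

# on_save_checkpoint: only builtin types end up in the checkpoint.
payload = asdict(info)
assert payload == {
    "label_names": ["cat", "dog"],
    "label_groups": [["cat", "dog"]],
    "label_ids": ["0", "1"],
}

# on_load_checkpoint / load_state_dict_incrementally: rebuild from the dict.
assert LabelInfo(**payload) == info
```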
@@ -447,10 +452,10 @@ def load_state_dict(self, ckpt: dict[str, Any], *args, **kwargs) -> None:
             warnings.warn(msg, stacklevel=2)
             state_dict = self.load_from_otx_v1_ckpt(ckpt)
         elif is_ckpt_for_finetuning(ckpt):
+            self.on_load_checkpoint(ckpt)
             state_dict = ckpt["state_dict"]
         else:
             state_dict = ckpt
-
         return super().load_state_dict(state_dict, *args, **kwargs)
 
     def load_from_otx_v1_ckpt(self, ckpt: dict[str, Any]) -> dict:
@@ -828,6 +833,11 @@ def get_dummy_input(self, batch_size: int = 1) -> OTXBatchDataEntity:
 
     @staticmethod
     def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo:
+        if isinstance(label_info, dict):
+            if "label_ids" not in label_info:
+                # NOTE: This is for backward compatibility
+                label_info["label_ids"] = label_info["label_names"]
+            return LabelInfo(**label_info)
         if isinstance(label_info, int):
             return LabelInfo.from_num_classes(num_classes=label_info)
         if isinstance(label_info, Sequence) and all(isinstance(name, str) for name in label_info):
@@ -837,6 +847,9 @@ def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo:
                 label_ids=[str(i) for i in range(len(label_info))],
             )
         if isinstance(label_info, LabelInfo):
+            if not hasattr(label_info, "label_ids"):
+                # NOTE: This is for backward compatibility
+                label_info.label_ids = label_info.label_names
             return label_info
 
         raise TypeError(label_info)
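With the same reduced `LabelInfo` stand-in as above, the new dict branch's backward-compatibility fallback behaves roughly like this (checkpoints from OTX < 2.2.1 serialized `label_info` without `label_ids`):

```python
# Legacy payload without label_ids, as an OTX < 2.2.1 checkpoint might carry.
legacy = {"label_names": ["cat", "dog"], "label_groups": [["cat", "dog"]]}

# The dict branch backfills label_ids from label_names before constructing.
if "label_ids" not in legacy:
    legacy["label_ids"] = legacy["label_names"]

info = LabelInfo(**legacy)  # stand-in class from the earlier sketch
assert info.label_ids == ["cat", "dog"]
```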