@@ -367,8 +367,14 @@ def test(
367367 # NOTE, trainer.test takes only lightning based checkpoint.
368368 # So, it can't take the OTX1.x checkpoint.
369369 if checkpoint is not None and not is_ir_ckpt :
370+ kwargs_user_input : dict [str , Any ] = {}
371+ if self .task == OTXTaskType .ZERO_SHOT_VISUAL_PROMPTING :
372+ # to update user's custom infer_reference_info_root through cli for zero-shot learning
373+ # TODO (sungchul): revisit for better solution
374+ kwargs_user_input .update (infer_reference_info_root = self .model .infer_reference_info_root )
375+
370376 model_cls = model .__class__
371- model = model_cls .load_from_checkpoint (checkpoint_path = checkpoint , ** model . hparams )
377+ model = model_cls .load_from_checkpoint (checkpoint_path = checkpoint , ** kwargs_user_input )
372378
373379 if model .label_info != self .datamodule .label_info :
374380 if (
@@ -462,8 +468,14 @@ def predict(
462468 datamodule = self ._auto_configurator .update_ov_subset_pipeline (datamodule = datamodule , subset = "test" )
463469
464470 if checkpoint is not None and not is_ir_ckpt :
471+ kwargs_user_input : dict [str , Any ] = {}
472+ if self .task == OTXTaskType .ZERO_SHOT_VISUAL_PROMPTING :
473+ # to update user's custom infer_reference_info_root through cli for zero-shot learning
474+ # TODO (sungchul): revisit for better solution
475+ kwargs_user_input .update (infer_reference_info_root = self .model .infer_reference_info_root )
476+
465477 model_cls = model .__class__
466- model = model_cls .load_from_checkpoint (checkpoint_path = checkpoint , ** model . hparams )
478+ model = model_cls .load_from_checkpoint (checkpoint_path = checkpoint , ** kwargs_user_input )
467479
468480 if model .label_info != self .datamodule .label_info :
469481 msg = (
@@ -574,11 +586,17 @@ def export(
574586 )
575587
576588 if not is_ir_ckpt :
589+ kwargs_user_input : dict [str , Any ] = {}
590+ if self .task == OTXTaskType .ZERO_SHOT_VISUAL_PROMPTING :
591+ # to update user's custom infer_reference_info_root through cli for zero-shot learning
592+ # TODO (sungchul): revisit for better solution
593+ kwargs_user_input .update (infer_reference_info_root = self .model .infer_reference_info_root )
594+
577595 model_cls = self .model .__class__
578596 self .model = model_cls .load_from_checkpoint (
579597 checkpoint_path = checkpoint ,
580598 map_location = "cpu" ,
581- ** self . model . hparams ,
599+ ** kwargs_user_input ,
582600 )
583601 self .model .eval ()
584602
@@ -742,8 +760,14 @@ def explain(
742760 model = self ._auto_configurator .get_ov_model (model_name = str (checkpoint ), label_info = datamodule .label_info )
743761
744762 if checkpoint is not None and not is_ir_ckpt :
763+ kwargs_user_input : dict [str , Any ] = {}
764+ if self .task == OTXTaskType .ZERO_SHOT_VISUAL_PROMPTING :
765+ # to update user's custom infer_reference_info_root through cli for zero-shot learning
766+ # TODO (sungchul): revisit for better solution
767+ kwargs_user_input .update (infer_reference_info_root = self .model .infer_reference_info_root )
768+
745769 model_cls = model .__class__
746- model = model_cls .load_from_checkpoint (checkpoint_path = checkpoint , ** model . hparams )
770+ model = model_cls .load_from_checkpoint (checkpoint_path = checkpoint , ** kwargs_user_input )
747771
748772 if model .label_info != self .datamodule .label_info :
749773 msg = (
@@ -845,11 +869,17 @@ def benchmark(
845869 )
846870
847871 if not is_ir_ckpt :
872+ kwargs_user_input : dict [str , Any ] = {}
873+ if self .task == OTXTaskType .ZERO_SHOT_VISUAL_PROMPTING :
874+ # to update user's custom infer_reference_info_root through cli for zero-shot learning
875+ # TODO (sungchul): revisit for better solution
876+ kwargs_user_input .update (infer_reference_info_root = self .model .infer_reference_info_root )
877+
848878 model_cls = self .model .__class__
849879 self .model = model_cls .load_from_checkpoint (
850880 checkpoint_path = checkpoint ,
851881 map_location = "cpu" ,
852- ** self . model . hparams ,
882+ ** kwargs_user_input ,
853883 )
854884 elif isinstance (self .model , OVModel ):
855885 msg = "To run benchmark on OV model, checkpoint must be specified."
0 commit comments