@@ -1202,16 +1202,24 @@ def infer_architecture_from_checkpoint(checkpoint):
12021202 else :
12031203 print ("⚠️ Could not infer architecture, using fallback methods..." )
12041204
1205- # Fallback 1: Try model design
1206- model_design = best_model .get_model_design ()
1207- if model_design :
1208- print (f"📋 Using model design: { model_design } " )
1209- hidden_size = int (model_design .get ('hidden_size' , 256 ))
1210- num_layers = int (model_design .get ('num_layers' , 4 ))
1211- dropout_rate = float (model_design .get ('dropout_rate' , 0.1 ))
1212- use_layer_norm = bool (model_design .get ('use_layer_norm' , False ))
1213- attention_dropout = float (model_design .get ('attention_dropout' , 0.1 ))
1214- batch_size = int (model_design .get ('batch_size' , 32 ))
1205+ # Fallback 1: Try model config
1206+ try :
1207+ # Try to get model configuration from the model object
1208+ model_config = model .config_dict if hasattr (model , 'config_dict' ) else {}
1209+ if not model_config and hasattr (model , 'get_model_config' ):
1210+ model_config = model .get_model_config () or {}
1211+ except Exception as config_error :
1212+ print (f"⚠️ Could not retrieve model config: { config_error } " )
1213+ model_config = {}
1214+
1215+ if model_config :
1216+ print (f"📋 Using model config: { model_config } " )
1217+ hidden_size = int (model_config .get ('hidden_size' , 256 ))
1218+ num_layers = int (model_config .get ('num_layers' , 4 ))
1219+ dropout_rate = float (model_config .get ('dropout_rate' , 0.1 ))
1220+ use_layer_norm = bool (model_config .get ('use_layer_norm' , False ))
1221+ attention_dropout = float (model_config .get ('attention_dropout' , 0.1 ))
1222+ batch_size = int (model_config .get ('batch_size' , 32 ))
12151223 else :
12161224 print ("📋 Using task parameters..." )
12171225 # Fallback 2: Task parameters
@@ -1459,7 +1467,7 @@ def deploy_model_github(
14591467 "model_id" : best_model_id ,
14601468 "test_accuracy" : float (test_accuracy ),
14611469 "training_task_id" : str (best_task .id ) if best_task else "unknown" ,
1462- "architecture" : model .get_model_design () or {},
1470+ "architecture" : model .config_dict if hasattr ( model , 'config_dict' ) else {},
14631471 "hyperparameters" : hyperparams ,
14641472 "checkpoint_keys" : list (checkpoint .keys ()) if checkpoint else [],
14651473 "input_size" : hyperparams .get ("General/input_size" , {}).get ("value" , 34 ),
@@ -1526,7 +1534,7 @@ def deploy_model_github(
15261534 "weights_file_id" : weights_file_id ,
15271535 "hyperparameters" : hyperparams ,
15281536 "deployment_status" : "deployed" ,
1529- "architecture" : model .get_model_design () or {},
1537+ "architecture" : model .config_dict if hasattr ( model , 'config_dict' ) else {},
15301538 "checkpoint_keys" : list (checkpoint .keys ()) if checkpoint else [],
15311539 "file_size_mb" : os .path .getsize (model_path ) / (1024 * 1024 ),
15321540 "status" : "available" ,
0 commit comments