Skip to content

Commit ffc85d5

Browse files
ClearML model attribute update
1 parent 8fdb74b commit ffc85d5

File tree

1 file changed

+20
-12
lines changed

1 file changed

+20
-12
lines changed

Guardian_pipeline_github.py

Lines changed: 20 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1202,16 +1202,24 @@ def infer_architecture_from_checkpoint(checkpoint):
12021202
else:
12031203
print("⚠️ Could not infer architecture, using fallback methods...")
12041204

1205-
# Fallback 1: Try model design
1206-
model_design = best_model.get_model_design()
1207-
if model_design:
1208-
print(f"📋 Using model design: {model_design}")
1209-
hidden_size = int(model_design.get('hidden_size', 256))
1210-
num_layers = int(model_design.get('num_layers', 4))
1211-
dropout_rate = float(model_design.get('dropout_rate', 0.1))
1212-
use_layer_norm = bool(model_design.get('use_layer_norm', False))
1213-
attention_dropout = float(model_design.get('attention_dropout', 0.1))
1214-
batch_size = int(model_design.get('batch_size', 32))
1205+
# Fallback 1: Try model config
1206+
try:
1207+
# Try to get model configuration from the model object
1208+
model_config = model.config_dict if hasattr(model, 'config_dict') else {}
1209+
if not model_config and hasattr(model, 'get_model_config'):
1210+
model_config = model.get_model_config() or {}
1211+
except Exception as config_error:
1212+
print(f"⚠️ Could not retrieve model config: {config_error}")
1213+
model_config = {}
1214+
1215+
if model_config:
1216+
print(f"📋 Using model config: {model_config}")
1217+
hidden_size = int(model_config.get('hidden_size', 256))
1218+
num_layers = int(model_config.get('num_layers', 4))
1219+
dropout_rate = float(model_config.get('dropout_rate', 0.1))
1220+
use_layer_norm = bool(model_config.get('use_layer_norm', False))
1221+
attention_dropout = float(model_config.get('attention_dropout', 0.1))
1222+
batch_size = int(model_config.get('batch_size', 32))
12151223
else:
12161224
print("📋 Using task parameters...")
12171225
# Fallback 2: Task parameters
@@ -1459,7 +1467,7 @@ def deploy_model_github(
14591467
"model_id": best_model_id,
14601468
"test_accuracy": float(test_accuracy),
14611469
"training_task_id": str(best_task.id) if best_task else "unknown",
1462-
"architecture": model.get_model_design() or {},
1470+
"architecture": model.config_dict if hasattr(model, 'config_dict') else {},
14631471
"hyperparameters": hyperparams,
14641472
"checkpoint_keys": list(checkpoint.keys()) if checkpoint else [],
14651473
"input_size": hyperparams.get("General/input_size", {}).get("value", 34),
@@ -1526,7 +1534,7 @@ def deploy_model_github(
15261534
"weights_file_id": weights_file_id,
15271535
"hyperparameters": hyperparams,
15281536
"deployment_status": "deployed",
1529-
"architecture": model.get_model_design() or {},
1537+
"architecture": model.config_dict if hasattr(model, 'config_dict') else {},
15301538
"checkpoint_keys": list(checkpoint.keys()) if checkpoint else [],
15311539
"file_size_mb": os.path.getsize(model_path) / (1024 * 1024),
15321540
"status": "available",

0 commit comments

Comments (0)