Skip to content

Commit ad43cae

Browse files
Add hyperparameter cleanup functionality to keep only best model files and standardize filenames
1 parent 58af8ac commit ad43cae

File tree

1 file changed

+108
-1
lines changed

1 file changed

+108
-1
lines changed

Guardian_pipeline_github.py

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import json
1111
from datetime import datetime
1212
import ssl
13+
import glob
1314

1415
# Configure logging
1516
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(module)s - %(message)s')
@@ -772,7 +773,9 @@ def make_torch_dataset_for_loader(split_data, split_labels, use_noise=False):
772773

773774
# Save hyperparameters as JSON file
774775
hyperparams_filename = f"hyperparams_{task.id}.json"
776+
standard_hyperparams_filename = "train_hyperparams.json"
775777
hyperparams_filepath = os.path.join(os.getcwd(), hyperparams_filename)
778+
standard_hyperparams_filepath = os.path.join(os.getcwd(), standard_hyperparams_filename)
776779

777780
hyperparams_data = {
778781
"model_id": task.id,
@@ -808,10 +811,16 @@ def make_torch_dataset_for_loader(split_data, split_labels, use_noise=False):
808811
}
809812

810813
try:
814+
# Save with task ID (for ClearML tracking)
811815
with open(hyperparams_filepath, 'w') as f:
812816
json.dump(hyperparams_data, f, indent=2)
813817
print(f"💾 Hyperparameters saved to {hyperparams_filepath}")
814818

819+
# Also save with a standard name for easy reference
820+
with open(standard_hyperparams_filepath, 'w') as f:
821+
json.dump(hyperparams_data, f, indent=2)
822+
print(f"💾 Hyperparameters also saved to {standard_hyperparams_filepath}")
823+
815824
# Upload hyperparameters as artifact
816825
task.upload_artifact(
817826
name="training_hyperparameters",
@@ -824,8 +833,20 @@ def make_torch_dataset_for_loader(split_data, split_labels, use_noise=False):
824833
)
825834
print(f"📤 Hyperparameters uploaded as ClearML artifact")
826835

836+
# Clean up any other hyperparameter files (optional during training)
837+
try:
838+
import glob
839+
for hp_file in glob.glob("hyperparams_*.json"):
840+
# Skip the current hyperparams file
841+
if hp_file == hyperparams_filename:
842+
continue
843+
os.remove(hp_file)
844+
print(f"🧹 Removed old hyperparameter file: {hp_file}")
845+
except Exception as cleanup_error:
846+
print(f"⚠️ Error during hyperparameter cleanup: {cleanup_error}")
847+
827848
except Exception as hyperparams_error:
828-
print(f"⚠️ Error saving hyperparameters: {hyperparams_error}")
849+
print(f"⚠️ Error saving hyperparameters: {hyperparams_error}")
829850

830851
# Publish model
831852
output_model = OutputModel(task=task, name="BiLSTM_ActionRecognition", framework="PyTorch")
@@ -1020,6 +1041,23 @@ def bilstm_hyperparam_optimizer_github(
10201041
}
10211042
)
10221043
print(f"📤 Best hyperparameters uploaded as ClearML artifact")
1044+
1045+
# Clean up any other hyperparameter files to keep only the best one
1046+
try:
1047+
# Find and remove other hyperparameter files in the current directory
1048+
import glob
1049+
for hp_file in glob.glob("hyperparams_*.json"):
1050+
# Skip the best hyperparams file
1051+
if hp_file == best_hyperparams_filename:
1052+
continue
1053+
try:
1054+
os.remove(hp_file)
1055+
print(f"🧹 Removed unnecessary hyperparameter file: {hp_file}")
1056+
except Exception as e:
1057+
print(f"⚠️ Could not remove {hp_file}: {e}")
1058+
except Exception as cleanup_error:
1059+
print(f"⚠️ Error during hyperparameter cleanup: {cleanup_error}")
1060+
10231061
except Exception as hyperparams_error:
10241062
print(f"⚠️ Error saving best hyperparameters: {hyperparams_error}")
10251063

@@ -2006,6 +2044,15 @@ def guardian_github_pipeline():
20062044
accuracy_value = float(test_accuracy) if hasattr(test_accuracy, '__float__') else test_accuracy
20072045
logging.info(f"Evaluation completed. Test accuracy: {accuracy_value:.2f}%")
20082046

2047+
# Clean up hyperparameter files before deployment - keep only the best model's hyperparameters
2048+
logging.info("Cleaning up hyperparameter files...")
2049+
try:
2050+
deleted_count = cleanup_hyperparameter_files(best_task_id)
2051+
logging.info(f"Removed {deleted_count} unnecessary hyperparameter files")
2052+
except Exception as cleanup_error:
2053+
logging.error(f"Error during hyperparameter cleanup: {cleanup_error}")
2054+
logging.info("Continuing with deployment despite cleanup error")
2055+
20092056
# Step 6: Deploy model if it meets threshold
20102057
logging.info("Starting model deployment...")
20112058
try:
@@ -2093,6 +2140,66 @@ def safe_extract_hyperparameter_value(param_value, default_value):
20932140

20942141
return all_passed
20952142

2143+
# ============================================================================
2144+
# CLEAN UP UTILITY
2145+
# ============================================================================
2146+
2147+
def cleanup_hyperparameter_files(best_task_id):
    """
    Keep only the best model's hyperparameter files and delete the rest.

    The two files belonging to *best_task_id* are renamed to stable,
    task-id-free names (``best_hyperparams.json`` / ``model_hyperparams.json``)
    so downstream steps can find them without knowing the task ID; every
    other ``hyperparams_*.json`` / ``best_hyperparams_*.json`` in the
    current working directory is removed.

    Args:
        best_task_id (str): The ID of the best task/model to keep.

    Returns:
        int: Number of hyperparameter files that were deleted.
    """
    import os
    import glob

    # Task-specific filenames we want to preserve (pre-rename).
    best_hyperparams_filename = f"best_hyperparams_{best_task_id}.json"
    hyperparams_filename = f"hyperparams_{best_task_id}.json"

    # Give the kept files standard names; a failed rename is reported
    # but never fatal — cleanup continues regardless.
    for source_name, target_name in (
        (best_hyperparams_filename, "best_hyperparams.json"),
        (hyperparams_filename, "model_hyperparams.json"),
    ):
        if not os.path.exists(source_name):
            continue
        try:
            os.rename(source_name, target_name)
            print(f"✅ Renamed {source_name} to {target_name}")
        except Exception as e:
            print(f"⚠️ Could not rename {source_name}: {e}")

    deleted_count = 0

    # Sweep both naming patterns and drop every other task's files.
    for pattern in ("hyperparams_*.json", "best_hyperparams_*.json"):
        for file_path in glob.glob(pattern):
            # Defensive skip: protects the kept files if a rename above failed.
            if file_path in (best_hyperparams_filename, hyperparams_filename):
                continue
            try:
                os.remove(file_path)
                deleted_count += 1
            except Exception as e:
                print(f"⚠️ Could not delete {file_path}: {e}")

    print(f"🧹 Cleanup completed: Removed {deleted_count} unnecessary hyperparameter files")
    return deleted_count
2202+
20962203
# ============================================================================
20972204
# MAIN EXECUTION
20982205
# ============================================================================

0 commit comments

Comments
 (0)