@@ -484,11 +484,13 @@ def _log_image_samples(self, samples, current_step):
484484 def push_to_registry (
485485 self ,
486486 registry_name : str = 'wandb-registry-model' ,
487+ aliases : List [str ] = ['latest' ],
487488 ):
488489 """
489490 Push the model to wandb registry.
490491 Args:
491492 registry_name: Name of the model registry.
493+ aliases: List of aliases for the model.
492494 """
493495 if self .wandb is None :
494496 raise ValueError ("Wandb is not initialized. Cannot push to registry." )
@@ -502,6 +504,7 @@ def push_to_registry(
502504 artifact_or_path = latest_checkpoint_path ,
503505 name = modelname ,
504506 type = "model" ,
507+ aliases = aliases ,
505508 )
506509
507510 target_path = f"{ registry_name } /{ modelname } "
@@ -541,37 +544,49 @@ def __get_best_sweep_runs__(
541544 return best_runs , (min (lower_bound , upper_bound ), max (lower_bound , upper_bound ))
542545
543546 def __compare_run_against_best__ (self , top_k = 2 , metric = "train/best_loss" ):
547+ """
548+ Compare the current run against the best runs from the sweep.
549+ Args:
550+ top_k: Number of top runs to consider.
551+ metric: Metric to compare against.
552+ Returns:
553+ is_good: Whether the current run is among the best.
554+ is_best: Whether the current run is the best.
555+ """
544556 # Get best runs
545557 best_runs , bounds = self .__get_best_sweep_runs__ (metric = metric , top_k = top_k )
546558
547559 # Determine if lower or higher values are better (for loss, lower is better)
548560 is_lower_better = "loss" in metric .lower ()
549561
550562 # Check if current run is one of the best
551- current_run_metric = self .wandb .summary .get (metric , float ('inf' ) if is_lower_better else float ('-inf' ))
552-
553- # Direct check if current run is in best runs
554- for run in best_runs :
555- if run .id == self .wandb .id :
556- print (f"Current run { self .wandb .id } is one of the best runs." )
557- return True
563+ if metric == "train/best_loss" :
564+ current_run_metric = self .best_loss
565+ else :
566+ current_run_metric = self .wandb .summary .get (metric , float ('inf' ) if is_lower_better else float ('-inf' ))
558567
559- # Backup check based on metric value
568+ # Check based on bounds
560569 if (is_lower_better and current_run_metric < bounds [1 ]) or (not is_lower_better and current_run_metric > bounds [0 ]):
561570 print (f"Current run { self .wandb .id } meets performance criteria." )
562- return True
571+ is_best = (is_lower_better and current_run_metric < bounds [0 ]) or (not is_lower_better and current_run_metric > bounds [1 ])
572+ return True , is_best
563573
564- return False
574+ return False , False
565575
566576 def save (self , epoch = 0 , step = 0 , state = None , rngstate = None ):
567577 super ().save (epoch = epoch , step = step , state = state , rngstate = rngstate )
568578
569579 if self .wandb is not None and hasattr (self , "wandb_sweep" ):
570580 checkpoint = get_latest_checkpoint (self .checkpoint_path ())
571581 try :
572- if self .__compare_run_against_best__ (top_k = 5 , metric = "train/best_loss" ):
573- self .push_to_registry ()
574- print ("Model pushed to registry successfully" )
582+ is_good , is_best = self .__compare_run_against_best__ (top_k = 5 , metric = "train/best_loss" )
583+ if is_good :
584+ # Push to registry with appropriate aliases
585+ aliases = ["latest" ]
586+ if is_best :
587+ aliases .append ("best" )
588+ self .push_to_registry (aliases = aliases )
589+ print ("Model pushed to registry successfully with aliases:" , aliases )
575590 else :
576591 print ("Current run is not one of the best runs. Not saving model." )
577592
0 commit comments