Skip to content

Commit 65ee070

Browse files
committed
feat: push best alias as well
1 parent 4e1b2fc commit 65ee070

File tree

2 files changed: 29 additions and 14 deletions.

flaxdiff/data/sources/av_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77
import subprocess
88
import numpy as np
99
from typing import Tuple, Optional, Union, List
10-
from video_reader import PyVideoReader
1110
from .audio_utils import read_audio
1211

1312
def get_video_fps(video_path: str):
@@ -113,6 +112,7 @@ def read_av_improved(
113112
Returns:
114113
Tuple of (audio_data, video_frames) where video_frames is a numpy array.
115114
"""
115+
from video_reader import PyVideoReader
116116
# Calculate time information for audio extraction
117117
start_time = start / fps if start > 0 else 0
118118
duration = None

flaxdiff/trainer/general_diffusion_trainer.py

Lines changed: 28 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -484,11 +484,13 @@ def _log_image_samples(self, samples, current_step):
484484
def push_to_registry(
485485
self,
486486
registry_name: str = 'wandb-registry-model',
487+
aliases: List[str] = ['latest'],
487488
):
488489
"""
489490
Push the model to wandb registry.
490491
Args:
491492
registry_name: Name of the model registry.
493+
aliases: List of aliases for the model.
492494
"""
493495
if self.wandb is None:
494496
raise ValueError("Wandb is not initialized. Cannot push to registry.")
@@ -502,6 +504,7 @@ def push_to_registry(
502504
artifact_or_path=latest_checkpoint_path,
503505
name=modelname,
504506
type="model",
507+
aliases=aliases,
505508
)
506509

507510
target_path = f"{registry_name}/{modelname}"
@@ -541,37 +544,49 @@ def __get_best_sweep_runs__(
541544
return best_runs, (min(lower_bound, upper_bound), max(lower_bound, upper_bound))
542545

543546
def __compare_run_against_best__(self, top_k=2, metric="train/best_loss"):
547+
"""
548+
Compare the current run against the best runs from the sweep.
549+
Args:
550+
top_k: Number of top runs to consider.
551+
metric: Metric to compare against.
552+
Returns:
553+
is_good: Whether the current run is among the best.
554+
is_best: Whether the current run is the best.
555+
"""
544556
# Get best runs
545557
best_runs, bounds = self.__get_best_sweep_runs__(metric=metric, top_k=top_k)
546558

547559
# Determine if lower or higher values are better (for loss, lower is better)
548560
is_lower_better = "loss" in metric.lower()
549561

550562
# Check if current run is one of the best
551-
current_run_metric = self.wandb.summary.get(metric, float('inf') if is_lower_better else float('-inf'))
552-
553-
# Direct check if current run is in best runs
554-
for run in best_runs:
555-
if run.id == self.wandb.id:
556-
print(f"Current run {self.wandb.id} is one of the best runs.")
557-
return True
563+
if metric == "train/best_loss":
564+
current_run_metric = self.best_loss
565+
else:
566+
current_run_metric = self.wandb.summary.get(metric, float('inf') if is_lower_better else float('-inf'))
558567

559-
# Backup check based on metric value
568+
# Check based on bounds
560569
if (is_lower_better and current_run_metric < bounds[1]) or (not is_lower_better and current_run_metric > bounds[0]):
561570
print(f"Current run {self.wandb.id} meets performance criteria.")
562-
return True
571+
is_best = (is_lower_better and current_run_metric < bounds[0]) or (not is_lower_better and current_run_metric > bounds[1])
572+
return True, is_best
563573

564-
return False
574+
return False, False
565575

566576
def save(self, epoch=0, step=0, state=None, rngstate=None):
567577
super().save(epoch=epoch, step=step, state=state, rngstate=rngstate)
568578

569579
if self.wandb is not None and hasattr(self, "wandb_sweep"):
570580
checkpoint = get_latest_checkpoint(self.checkpoint_path())
571581
try:
572-
if self.__compare_run_against_best__(top_k=5, metric="train/best_loss"):
573-
self.push_to_registry()
574-
print("Model pushed to registry successfully")
582+
is_good, is_best = self.__compare_run_against_best__(top_k=5, metric="train/best_loss")
583+
if is_good:
584+
# Push to registry with appropriate aliases
585+
aliases = ["latest"]
586+
if is_best:
587+
aliases.append("best")
588+
self.push_to_registry(aliases=aliases)
589+
print("Model pushed to registry successfully with aliases:", aliases)
575590
else:
576591
print("Current run is not one of the best runs. Not saving model.")
577592

0 commit comments

Comments
 (0)