Skip to content

Commit 91cd305

Browse files
authored
Merge pull request #9 from AshishKumar4/feat/datapipeline-refactor
fix: fixed alias for linking
2 parents 9c995d6 + 10de123 commit 91cd305

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

flaxdiff/trainer/general_diffusion_trainer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def _log_image_samples(self, samples, current_step):
484484
def push_to_registry(
485485
self,
486486
registry_name: str = 'wandb-registry-model',
487-
aliases: List[str] = ['latest'],
487+
aliases: List[str] = [],
488488
):
489489
"""
490490
Push the model to wandb registry.
@@ -504,14 +504,15 @@ def push_to_registry(
504504
artifact_or_path=latest_checkpoint_path,
505505
name=modelname,
506506
type="model",
507-
aliases=aliases,
507+
aliases=['latest'] + aliases,
508508
)
509509

510510
target_path = f"{registry_name}/{modelname}"
511511

512512
self.wandb.link_artifact(
513513
artifact=logged_artifact,
514514
target_path=target_path,
515+
aliases=aliases,
515516
)
516517
print(f"Model pushed to registry at {target_path}")
517518
return logged_artifact
@@ -582,7 +583,7 @@ def save(self, epoch=0, step=0, state=None, rngstate=None):
582583
is_good, is_best = self.__compare_run_against_best__(top_k=5, metric="train/best_loss")
583584
if is_good:
584585
# Push to registry with appropriate aliases
585-
aliases = ["latest"]
586+
aliases = []
586587
if is_best:
587588
aliases.append("best")
588589
self.push_to_registry(aliases=aliases)

0 commit comments

Comments
 (0)