Skip to content

Commit 9093a4b

Browse files
committed
fix filename of embedding
1 parent 9bdb6a1 commit 9093a4b

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def test_dreambooth_lora_pivotal_tuning_flux_clip_t5(self):
178178

179179
# make sure the state_dict has the correct naming in the parameters.
180180
textual_inversion_state_dict = safetensors.torch.load_file(
181-
os.path.join(tmpdir, f"{os.path(tmpdir).name}_emb.safetensors")
181+
os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")
182182
)
183183
is_te = all(("clip_l" in k or "t5" in k) for k in textual_inversion_state_dict.keys())
184184
self.assertTrue(is_te)

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2405,7 +2405,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
24052405
)
24062406

24072407
if args.train_text_encoder_ti:
2408-
embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors"
2408+
embeddings_path = f"{args.output_dir}/{os.path.basename(args.output_dir)}_emb.safetensors"
24092409
embedding_handler.save_embeddings(embeddings_path)
24102410

24112411
# Final inference

0 commit comments

Comments
 (0)