Skip to content

Commit 9a73dd1

Browse files
committed
Remove hasher hacks
1 parent c6ff80c commit 9a73dd1

File tree

1 file changed

+1
-9
lines changed

1 file changed

+1
-9
lines changed

examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_sdxl.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def start_training(self):
206206
def print_loss_closure(step, loss):
207207
print(f"Step: {step}, Loss: {loss}")
208208

209-
if self.args.print_loss:
209+
if True:
210210
xm.add_step_closure(
211211
print_loss_closure,
212212
args=(
@@ -748,16 +748,8 @@ def preprocess_train(examples):
748748
)
749749
compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)
750750
from datasets.fingerprint import Hasher
751-
# import pdb; pdb.set_trace()
752-
old_batch_size = args.train_batch_size
753-
old_arg = args.output_dir
754-
args.output_dir = '/tmp/trained-model/'
755-
args.train_batch_size=22
756751
new_fingerprint = Hasher.hash(args)
757-
args.train_batch_size=64
758752
new_fingerprint_for_vae = Hasher.hash((args.pretrained_model_name_or_path, args))
759-
args.train_batch_size=old_batch_size
760-
args.output_dir = old_arg
761753
train_dataset_with_embeddings = train_dataset.map(
762754
compute_embeddings_fn, batched=True, batch_size=50, new_fingerprint=new_fingerprint
763755
)

0 commit comments

Comments
 (0)