Commit 0bdefba

adds missing arguments for Flux.
1 parent 68ff594 · commit 0bdefba


examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_flux.py

Lines changed: 35 additions & 18 deletions
@@ -152,7 +152,6 @@ def start_training(self):
         for step in range(0, self.args.max_train_steps):
             print("step: ", step)
             batch = next(self.dataloader)
-            breakpoint()
             if step == measure_start_step:
                 if PROFILE_DIR is not None:
                     xm.wait_device_ops()
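
For context on the measurement pattern in this hunk: xm.wait_device_ops() blocks until all queued device work has drained, so starting the timer right after it keeps warm-up steps out of the measured window. A minimal self-contained sketch of that pattern on a toy workload (the matmul and the step counts are placeholders, not from this script):

import time

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn(512, 512, device=device)

warmup_steps, total_steps = 5, 20
for step in range(total_steps):
    if step == warmup_steps:
        xm.wait_device_ops()  # drain queued device work before starting the clock
        start = time.time()
    x = torch.matmul(x, x)  # stand-in for one training step
    xm.mark_step()  # cut and dispatch this step's XLA graph

xm.wait_device_ops()  # make sure every timed step actually finished
print(f"average step time: {(time.time() - start) / (total_steps - warmup_steps):.4f}s")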
@@ -164,22 +163,22 @@ def start_training(self):
             def print_loss_closure(step, loss):
                 print(f"Step: {step}, Loss: {loss}")
 
-            # if self.args.print_loss:
-            #     xm.add_step_closure(
-            #         print_loss_closure,
-            #         args=(
-            #             self.global_step,
-            #             loss,
-            #         ),
-            #     )
-            # xm.mark_step()
-            # if not dataloader_exception:
-            #     xm.wait_device_ops()
-            #     total_time = time.time() - last_time
-            #     print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}")
-            # else:
-            #     print("dataloader exception happen, skip result")
-            #     return
+            if self.args.print_loss:
+                xm.add_step_closure(
+                    print_loss_closure,
+                    args=(
+                        self.global_step,
+                        loss,
+                    ),
+                )
+        xm.mark_step()
+        if not dataloader_exception:
+            xm.wait_device_ops()
+            total_time = time.time() - last_time
+            print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}")
+        else:
+            print("dataloader exception happen, skip result")
+            return
     def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
         sigmas = self.noise_scheduler_copy.sigmas.to(device=self.device, dtype=dtype)
         schedule_timesteps = self.noise_scheduler_copy.timesteps.to(self.device)
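
The change above re-enables step closures instead of printing the loss inline. On XLA, reading a lazy tensor from the host forces an early device sync; xm.add_step_closure defers the host-side print until the step's graph has executed. A minimal sketch of the same pattern on a toy model (the linear model and SGD loop are placeholders, not this script's training step):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(8, 1).to(device)
opt = torch.optim.SGD(model.parameters(), lr=1e-3)

def report(step, loss):
    # Runs after the step executes, so reading `loss` does not cut the graph early.
    print(f"step {step}: loss {loss}")

for step in range(10):
    opt.zero_grad()
    loss = model(torch.randn(4, 8, device=device)).pow(2).mean()
    loss.backward()
    opt.step()
    xm.add_step_closure(report, args=(step, loss))
    xm.mark_step()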
@@ -307,6 +306,24 @@ def parse_args():
         choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
         help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
     )
+    parser.add_argument(
+        "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+    )
+    parser.add_argument(
+        "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+    )
+    parser.add_argument(
+        "--mode_scale",
+        type=float,
+        default=1.29,
+        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+    )
+    parser.add_argument(
+        "--guidance_scale",
+        type=float,
+        default=3.5,
+        help="the FLUX.1 dev variant is a guidance distilled model",
+    )
     parser.add_argument(
         "--revision",
         type=str,
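
These four flags match the knobs consumed by diffusers' flow-matching training utilities. A hedged sketch of how they typically feed into timestep sampling and the guidance input (the args, bsz, device, and transformer names here are assumptions; this hunk only adds the parser entries):

import torch
from diffusers.training_utils import compute_density_for_timestep_sampling

def sample_u(args, bsz):
    # Density over [0, 1] used to pick training timesteps.
    return compute_density_for_timestep_sampling(
        weighting_scheme=args.weighting_scheme,  # "logit_normal" reads logit_mean/logit_std
        batch_size=bsz,
        logit_mean=args.logit_mean,
        logit_std=args.logit_std,
        mode_scale=args.mode_scale,  # only read by the "mode" scheme
    )

def make_guidance(args, bsz, device, transformer):
    # FLUX.1-dev is guidance-distilled, so the guidance scale is a model input.
    if transformer.config.guidance_embeds:
        return torch.full([bsz], args.guidance_scale, device=device)
    return None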
@@ -793,7 +810,7 @@ def preprocess_train(examples):
         compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
     )
     train_dataset_with_tensors = train_dataset.map(
-        pixels_to_tensors_fn, batched=True, new_fingerprint=new_fingerprint_two, batch_size=64
+        pixels_to_tensors_fn, batched=True, new_fingerprint=new_fingerprint_two, batch_size=256
     )
     precomputed_dataset = concatenate_datasets(
         [train_dataset_with_embeddings, train_dataset_with_tensors.remove_columns(["text", "image"])], axis=1
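
The batch_size bump only changes how many rows each call to the mapped function receives; the precomputed dataset it produces is identical. A small sketch of the trade-off with a toy function (double is hypothetical, not the script's pixels_to_tensors_fn):

from datasets import Dataset

ds = Dataset.from_dict({"x": list(range(1024))})

def double(batch):
    # Invoked once per chunk of `batch_size` rows.
    return {"x": [v * 2 for v in batch["x"]]}

# Larger batch_size means fewer Python-level calls and less per-batch overhead,
# at the cost of higher peak memory while a chunk is in flight.
out = ds.map(double, batched=True, batch_size=256)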
