Skip to content

Commit b516134

Browse files
make measure_start_step an argument.
1 parent 6234a37 commit b516134

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

examples/research_projects/pytorch_xla/train_text_to_image_xla.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def run_optimizer(self):
141141

142142
def start_training(self):
143143
dataloader_exception = False
144-
measure_start_step = 10
144+
measure_start_step = args.measure_start_step
145145
assert measure_start_step < self.args.max_train_steps
146146
total_time = 0
147147
for step in range(0, self.args.max_train_steps):
@@ -380,6 +380,7 @@ def parse_args():
380380
default=1,
381381
help=("Number of subprocesses to use for data loading to tpu from cpu. "),
382382
)
383+
parser.add_argument("--measure_start_step", type=int, default=10, help="Step to start profiling.")
383384
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
384385
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
385386
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")

0 commit comments

Comments
 (0)