Skip to content

Commit 4839584

Browse files
committed
Diag print
1 parent b67fb6e commit 4839584

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

fms_fsdp/utils/train_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def train(
8484
for batch_idx, (input, label) in enumerate(train_loader, start=start_step + 1):
8585
if batch_idx > cfg.num_steps:
8686
break
87+
if rank == 0:
88+
print(input.shape)
8789
input = input.to(local_rank)
8890
label = label.to(local_rank)
8991

0 commit comments

Comments
 (0)