Commit fae564f

David Zheng (dzhengAP) authored and committed
fix(smoothquant): fix NCCL timeout in DDP example sample generation
dispatch_model and generate require all ranks to participate. Add dist.barrier() before generation and only log output on rank 0.

Signed-off-by: David Zheng <dzheng@apple.com>
1 parent 10daccc commit fae564f

1 file changed: +6 -4 lines changed

examples/quantization_w8a8_int8/smoothquant_ddp_example.py

Lines changed: 6 additions & 4 deletions
@@ -115,12 +115,14 @@ def tokenize(sample):
 # ---------------------------------------------------------------------------
 # Sample generation (rank 0 only)
 # ---------------------------------------------------------------------------
+# Sample generation (all ranks must participate)
+dist.barrier()
+dispatch_model(model)
+sample = tokenizer("Hello my name is", return_tensors="pt")
+sample = {k: v.to(model.device) for k, v in sample.items()}
+output = model.generate(**sample, max_new_tokens=50)
 if rank == 0:
     logger.info("\n========== SAMPLE GENERATION ==========")
-    dispatch_model(model)
-    sample = tokenizer("Hello my name is", return_tensors="pt")
-    sample = {k: v.to(model.device) for k, v in sample.items()}
-    output = model.generate(**sample, max_new_tokens=50)
     logger.info(tokenizer.decode(output[0]))
     logger.info("========================================\n")
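
The failure mode this commit addresses can be illustrated with a stdlib-only analogy. This is a sketch, not code from the example script: threads stand in for DDP ranks, `threading.Barrier` stands in for a collective operation that completes only when every rank calls it, and `WORLD_SIZE`, `run`, and the status strings are hypothetical names introduced for illustration.

```python
import threading

WORLD_SIZE = 4  # hypothetical number of ranks, simulated here with threads


def run(rank0_only: bool, timeout: float = 0.5) -> dict:
    """Simulate one collective step and return a status string per rank."""
    # Stand-in for a collective op: it completes only if every rank joins.
    collective = threading.Barrier(WORLD_SIZE)
    status = {}

    def worker(rank: int) -> None:
        if rank0_only and rank != 0:
            # Old pattern: the collective call sits inside `if rank == 0:`,
            # so the other ranks never reach it.
            status[rank] = "skipped"
            return
        try:
            collective.wait(timeout=timeout)
            status[rank] = "ok"
        except threading.BrokenBarrierError:
            # Analogous to the NCCL timeout: this rank waited for peers
            # that never arrived.
            status[rank] = "timeout"

    threads = [
        threading.Thread(target=worker, args=(r,)) for r in range(WORLD_SIZE)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return status


buggy = run(rank0_only=True)   # only rank 0 enters the collective
fixed = run(rank0_only=False)  # every rank enters the collective

if __name__ == "__main__":
    # Mirror the commit: all ranks do the work, only "rank 0" reports.
    print("rank0-only:", buggy)
    print("all ranks: ", fixed)
```

In the rank-0-only run, rank 0 waits until the timeout expires because its peers never join, which mirrors the hang described in the commit message. When every rank participates the step completes, and restricting only the logging to rank 0 keeps the output from being duplicated.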
