
Commit c6f12e1 (1 parent: f05fd06)

Review suggestion from @greptile-apps

Signed-off-by: Tim Moon <tmoon@nvidia.com>


transformer_engine/pytorch/distributed.py

Lines changed: 2 additions & 1 deletion
@@ -1079,7 +1079,8 @@ def _start_all_gather_fp8_blockwise(
     if quantizer is None or not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1:
         out = torch.empty(out_shape, dtype=dtype, device=device)
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
-        out = quantizer(out)
+        if quantizer is not None:
+            out = quantizer(out)
         return out, None
 
     # Quantize input tensor if needed
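The change guards the quantize call on this fallback all-gather path. The branch is entered precisely when `quantizer is None` or the input cannot be 1D block-quantized, so the unconditional `out = quantizer(out)` could call `None`. Below is a minimal, self-contained sketch of the guarded pattern; the function name `gather_then_maybe_quantize` and its signature are hypothetical illustrations, not Transformer Engine's actual API.

from typing import Callable, Optional

import torch
import torch.distributed as dist


def gather_then_maybe_quantize(
    inp: torch.Tensor,
    out_shape: tuple,
    process_group: Optional[dist.ProcessGroup] = None,
    quantizer: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
) -> torch.Tensor:
    """Gather `inp` from all ranks, then quantize only if a quantizer exists.

    Hypothetical sketch of the pattern in the diff above; requires an
    initialized process group (e.g. via torch.distributed.init_process_group).
    """
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
    dist.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
    # The guard added by this commit: this code path is reachable with
    # quantizer=None, so calling quantizer(out) unconditionally would raise
    # TypeError ('NoneType' object is not callable).
    if quantizer is not None:
        out = quantizer(out)
    return out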
