Skip to content

Commit c988548

Browse files
[PyTorch] Fix garbage initialized permuted_scale (#2547)
Signed-off-by: xiaoxi-wangfj <[email protected]> Co-authored-by: Teddy Do <[email protected]>
1 parent 27dc83b commit c988548

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

transformer_engine/pytorch/triton/permutation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def permute_with_mask_map(
165165
alloc((num_out_tokens,), dtype=probs.dtype, device="cuda") if probs is not None else None
166166
)
167167
permuted_scale = (
168-
torch.empty((num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda")
168+
alloc((num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda")
169169
if scale is not None
170170
else None
171171
)

0 commit comments

Comments
 (0)