Skip to content

Commit 6f4bc33

Browse files
authored
Make grad_output contiguous in cross_entropy.py (#2402)
Signed-off-by: Jack <lityangweiguang@163.com>
1 parent 15dead1 commit 6f4bc33

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

transformer_engine/pytorch/triton/cross_entropy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def cross_entropy_backward(
121121
element_mul_kernel[(n_rows,)](
122122
_input,
123123
_input.stride(-2),
124-
grad_output,
124+
grad_output.contiguous(),
125125
1 if grad_output.numel() > 1 else 0,
126126
V,
127127
BLOCK_SIZE=BLOCK_SIZE,

0 commit comments

Comments
 (0)