Skip to content

Commit afa82d9

Browse files
authored
fix output dtype for nvfuserex cumsum (#2580)
1 parent 4cc9705 commit afa82d9

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

thunder/executors/nvfuserex_impl.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3268,11 +3268,12 @@ def cumsum_transform(
32683268
mask = fd.ops.triu(mask)
32693269

32703270
out = fd.ops.matmul(nv_a, mask)
3271-
out = fd.ops.cast(out, out_dtype)
32723271
else:
32733272
out = fd.ops.cast(nv_a, out_dtype)
32743273
if a.ndim >= 1:
32753274
out = fd.ops.cumsum(out, dim)
3275+
# restore output dtype in case nvfuser cumsum does implicit type promotion
3276+
out = fd.ops.cast(out, out_dtype)
32763277
return out
32773278

32783279

0 commit comments

Comments
 (0)