Skip to content

Commit 7693fd9

Browse files
authored
torch.cumsum api change (#2507)
1 parent ce17a92 commit 7693fd9

File tree

1 file changed

+18
-17
lines changed

1 file changed

+18
-17
lines changed

thunder/executors/nvfuserex_impl.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3227,16 +3227,12 @@ def _grouped_mm_transform(
32273227

32283228

32293229
def _cumsum_check(a: TensorProxy, dim: int, /, dtype: dtypes.dtype | None = None) -> bool:
    """Decide whether the nvFuser executor can handle this cumsum call.

    Before nvFuser 0.2.33, cumsum was emulated (via matmul against a
    triangular mask) and only 1-D inputs were accepted; from 0.2.33 on a
    native cumsum op exists, so the rank restriction is lifted.
    """
    # Pre-0.2.33 fallback path only supports rank-1 tensors.
    lacks_native_cumsum = nvfuser_version() < LooseVersion("0.2.33")
    if lacks_native_cumsum and a.ndim != 1:
        return False

    # Regardless of version, the tensor itself must be representable in nvFuser.
    return is_supported_tensor(a)
32343234

32353235

3236-
# Emulate cumsum using matmul: cumsum(a) = a @ triu(ones)
3237-
#
3238-
# This is suboptimal. Revisit this after nvFuser has a scan-based cumsum
3239-
# implementation.
32403236
def cumsum_transform(
32413237
a: TensorProxy,
32423238
dim: int,
@@ -3248,26 +3244,31 @@ def cumsum_transform(
32483244
fd: FusionDefinition,
32493245
lc_to_nv_map: dict,
32503246
) -> TensorProxy:
3251-
if dtypes.is_integer_dtype(a.dtype):
3252-
# torch.matmul can't do integers on GPU so we convert `a` to
3253-
# float.
3254-
compute_dtype = DataType.Float
3255-
else:
3256-
compute_dtype = lcdtype_to_nvdtype(a.dtype)
3257-
32583247
if dtype is None:
32593248
out_dtype = lcdtype_to_nvdtype(a.dtype if a.dtype not in dtypes.integer_dtypes else dtypes.int64)
32603249
else:
32613250
out_dtype = lcdtype_to_nvdtype(dtypes.to_dtype(dtype))
32623251

32633252
nv_a = getnv(a, fd, lc_to_nv_map)
3264-
nv_a = fd.ops.cast(nv_a, compute_dtype)
32653253

3266-
mask = fd.ops.full((a.numel, a.numel), fd.define_scalar(1), compute_dtype)
3267-
mask = fd.ops.triu(mask)
3254+
if nvfuser_version() < LooseVersion("0.2.33"):
3255+
if dtypes.is_integer_dtype(a.dtype):
3256+
# torch.matmul can't do integers on GPU so we convert `a` to
3257+
# float.
3258+
compute_dtype = DataType.Float
3259+
else:
3260+
compute_dtype = lcdtype_to_nvdtype(a.dtype)
3261+
nv_a = fd.ops.cast(nv_a, compute_dtype)
3262+
3263+
mask = fd.ops.full((a.numel, a.numel), fd.define_scalar(1), compute_dtype)
3264+
mask = fd.ops.triu(mask)
32683265

3269-
out = fd.ops.matmul(nv_a, mask)
3270-
out = fd.ops.cast(out, out_dtype)
3266+
out = fd.ops.matmul(nv_a, mask)
3267+
out = fd.ops.cast(out, out_dtype)
3268+
else:
3269+
out = fd.ops.cast(nv_a, out_dtype)
3270+
if a.ndim >= 1:
3271+
out = fd.ops.cumsum(out, dim)
32713272
return out
32723273

32733274

0 commit comments

Comments (0)