@@ -3227,16 +3227,12 @@ def _grouped_mm_transform(
3227
3227
3228
3228
3229
3229
def _cumsum_check(a: TensorProxy, dim: int, /, dtype: dtypes.dtype | None = None) -> bool:
    """Decide whether this executor accepts a cumsum call on ``a``.

    With nvFuser older than 0.2.33 only 1-D inputs are accepted; from
    0.2.33 onwards there is no rank restriction here. In both cases the
    final answer is whatever ``is_supported_tensor`` reports for ``a``.

    ``dim`` and ``dtype`` belong to the checked signature but are not
    inspected by this function.
    """
    # The version is queried first; the rank is only looked at when the
    # version comparison says we are on the older nvFuser.
    legacy_lowering = nvfuser_version() < LooseVersion("0.2.33")
    rank_ok = not legacy_lowering or a.ndim == 1

    return rank_ok and is_supported_tensor(a)
3234
3234
3235
3235
3236
- # Emulate cumsum using matmul: cumsum(a) = a @ triu(ones)
3237
- #
3238
- # This is suboptimal. Revisit this after nvFuser has a scan-based cumsum
3239
- # implementation.
3240
3236
def cumsum_transform (
3241
3237
a : TensorProxy ,
3242
3238
dim : int ,
@@ -3248,26 +3244,31 @@ def cumsum_transform(
3248
3244
fd : FusionDefinition ,
3249
3245
lc_to_nv_map : dict ,
3250
3246
) -> TensorProxy :
3251
- if dtypes .is_integer_dtype (a .dtype ):
3252
- # torch.matmul can't do integers on GPU so we convert `a` to
3253
- # float.
3254
- compute_dtype = DataType .Float
3255
- else :
3256
- compute_dtype = lcdtype_to_nvdtype (a .dtype )
3257
-
3258
3247
if dtype is None :
3259
3248
out_dtype = lcdtype_to_nvdtype (a .dtype if a .dtype not in dtypes .integer_dtypes else dtypes .int64 )
3260
3249
else :
3261
3250
out_dtype = lcdtype_to_nvdtype (dtypes .to_dtype (dtype ))
3262
3251
3263
3252
nv_a = getnv (a , fd , lc_to_nv_map )
3264
- nv_a = fd .ops .cast (nv_a , compute_dtype )
3265
3253
3266
- mask = fd .ops .full ((a .numel , a .numel ), fd .define_scalar (1 ), compute_dtype )
3267
- mask = fd .ops .triu (mask )
3254
+ if nvfuser_version () < LooseVersion ("0.2.33" ):
3255
+ if dtypes .is_integer_dtype (a .dtype ):
3256
+ # torch.matmul can't do integers on GPU so we convert `a` to
3257
+ # float.
3258
+ compute_dtype = DataType .Float
3259
+ else :
3260
+ compute_dtype = lcdtype_to_nvdtype (a .dtype )
3261
+ nv_a = fd .ops .cast (nv_a , compute_dtype )
3262
+
3263
+ mask = fd .ops .full ((a .numel , a .numel ), fd .define_scalar (1 ), compute_dtype )
3264
+ mask = fd .ops .triu (mask )
3268
3265
3269
- out = fd .ops .matmul (nv_a , mask )
3270
- out = fd .ops .cast (out , out_dtype )
3266
+ out = fd .ops .matmul (nv_a , mask )
3267
+ out = fd .ops .cast (out , out_dtype )
3268
+ else :
3269
+ out = fd .ops .cast (nv_a , out_dtype )
3270
+ if a .ndim >= 1 :
3271
+ out = fd .ops .cumsum (out , dim )
3271
3272
return out
3272
3273
3273
3274
0 commit comments