Skip to content
14 changes: 14 additions & 0 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -6229,6 +6229,20 @@ def cumsum_sample_generator(op, device, dtype, requires_grad, **kwargs):
dtypes=(datatypes.float16,),
devicetypes=(devices.DeviceType.CPU,),
),
# nvfuserex follows the PyTorch convention of running cumsum in reduced
# precision, which causes numerical mismatches in the opinfo tests for
# bf16/fp16.
# NOTE: Even though both nvfuserex and torch use reduced-precision math,
# the reduction order differs between the two implementations, so the
# rounding error accumulates differently.
DecorateInfo(
custom_comparator(partial(assert_close, atol=1e-1, rtol=1e-1)),
"test_core_vs_torch_consistency",
dtypes=(
datatypes.bfloat16,
datatypes.float16,
),
executors=("nvfuser",),
),
),
)
# Register the cumsum OpInfo by appending it to reduction_ops.
reduction_ops.append(cumsum_opinfo)
Expand Down
Loading