diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index f51e303380..4b6b59d588 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -122,6 +122,14 @@ } ) +if nvfuser_version() >= LooseVersion("0.2.28"): + _lcdtype_to_nvdtype_map.update( + { + dtypes.float4_e2m1fn_x2: DataType.Float4_e2m1fn_x2, + dtypes.float4_e2m1fn_x2_: DataType.Float4_e2m1fn_x2, + } + ) + _lcfp8_to_nvfp8_map: dict[dtypes.dtype, DataType] = { dtypes.float8_e5m2: DataType.Float8_e5m2, dtypes.float8_e5m2_: DataType.Float8_e5m2,