Skip to content

Commit 6930bf2

Browse files
authored
[FxImporter] Add float8 dtype support to FX importer (#4276)
Adds support in fx_importer for: - torch.float8_e5m2 - torch.float8_e4m3fn - torch.float8_e5m2fnuz - torch.float8_e4m3fnuz
1 parent 2c989a2 commit 6930bf2

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

python/torch_mlir/extras/fx_importer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,10 @@
220220
}
221221
if ml_dtypes is not None:
222222
TORCH_DTYPE_TO_NPY_TYPE[torch.bfloat16] = ml_dtypes.bfloat16
223+
TORCH_DTYPE_TO_NPY_TYPE[torch.float8_e5m2] = ml_dtypes.float8_e5m2
224+
TORCH_DTYPE_TO_NPY_TYPE[torch.float8_e4m3fn] = ml_dtypes.float8_e4m3fn
225+
TORCH_DTYPE_TO_NPY_TYPE[torch.float8_e5m2fnuz] = ml_dtypes.float8_e5m2fnuz
226+
TORCH_DTYPE_TO_NPY_TYPE[torch.float8_e4m3fnuz] = ml_dtypes.float8_e4m3fnuz
223227

224228
TORCH_DTYPE_TO_INT = {
225229
torch.uint8: 0,

0 commit comments

Comments
 (0)