Skip to content

Commit a4d8c5d

Browse files
committed
remove F8E4M3B11FNUZ from frontend
1 parent 78c13a5 commit a4d8c5d

File tree

3 files changed

+2
-7
lines changed

3 files changed

+2
-7
lines changed

include/triton/Dialect/Triton/IR/TritonTypes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class TritonTypeDef<string name, string _mnemonic, list<Trait> traits = []>
1515
}
1616

1717
// Floating-point Type
18-
def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E4M3B11FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
18+
def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
1919
def TT_FloatTensor : RankedTensorOf<[TT_Float]>;
2020
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
2121

python/src/ir.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -818,11 +818,6 @@ void init_triton_ir(py::module &&m) {
818818
[](TritonOpBuilder &self) -> Type {
819819
return self.getBuilder().getI8Type();
820820
})
821-
.def("get_fp8e4m3b11fnuz_ty",
822-
[](TritonOpBuilder &self) -> Type {
823-
// TODO: align with upstream code to use i8
824-
return self.getBuilder().getType<Float8E4M3B11FNUZType>();
825-
})
826821
.def("get_fp8e5_ty",
827822
[](TritonOpBuilder &self) -> Type {
828823
return self.getBuilder().getType<Float8E5M2Type>();

third_party/intel/language/intel/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def num_warps(_builder=None):
2525

2626
def convert_fp8e4b15_to_float16(arg, _builder):
2727
# Need to bitcast the source first because it's represented as tensor of i8 in MLIR.
28-
tmp_ty = _builder.get_block_ty(_builder.get_fp8e4m3b11fnuz_ty(), arg.type.shape)
28+
tmp_ty = _builder.get_block_ty(_builder.get_fp8e4nv_ty(), arg.type.shape)
2929
tmp = _builder.create_bitcast(arg.handle, tmp_ty)
3030
# Now generate FpToFp op for upcast.
3131
dst_ty = core.block_type(core.float16, arg.type.get_block_shapes())

0 commit comments

Comments
 (0)