Skip to content

Commit 7ac1225

Browse files
authored
remove F8E4M3B11FNUZ from frontend (#2920)
This PR is to remove `F8E4M3B11FNUZ ` in our code. But removing the type can lead to the following failures. ``` test_dot_max_num_imprecise_acc[32-float8e4b15-64-64-64-128-256-256] test_dot_max_num_imprecise_acc[0-float8e4b15-128-256-128-128-256-256] test_dot_max_num_imprecise_acc[32-float8e4b15-128-256-128-128-256-256] test_dot_max_num_imprecise_acc[128-float8e4b15-128-256-128-128-256-256] test_dot_max_num_imprecise_acc[64-float8e4b15-128-256-128-128-256-256] test_dot_max_num_imprecise_acc[64-float8e4b15-64-64-64-128-256-256] test_dot_max_num_imprecise_acc[128-float8e4b15-64-64-64-128-256-256] test_typeconvert_upcast[float8e4b15-float16] test_typeconvert_upcast[float8e4b15-float32] ``` Using `Float8E4M3FNUZ` to replace `F8E4M3B11FNUZ`.
1 parent ab32861 commit 7ac1225

File tree

4 files changed

+3
-8
lines changed

4 files changed

+3
-8
lines changed

include/triton/Dialect/Triton/IR/TritonTypes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class TritonTypeDef<string name, string _mnemonic, list<Trait> traits = []>
1515
}
1616

1717
// Floating-point Type
18-
def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E4M3B11FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
18+
def TT_Float : AnyTypeOf<[F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, F16, BF16, F32, F64], "floating-point">;
1919
def TT_FloatTensor : RankedTensorOf<[TT_Float]>;
2020
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;
2121

python/src/ir.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -807,11 +807,6 @@ void init_triton_ir(py::module &&m) {
807807
[](TritonOpBuilder &self) -> Type {
808808
return self.getBuilder().getI8Type();
809809
})
810-
.def("get_fp8e4m3b11fnuz_ty",
811-
[](TritonOpBuilder &self) -> Type {
812-
// TODO: align with upstream code to use i8
813-
return self.getBuilder().getType<Float8E4M3B11FNUZType>();
814-
})
815810
.def("get_fp8e5_ty",
816811
[](TritonOpBuilder &self) -> Type {
817812
return self.getBuilder().getType<Float8E5M2Type>();

third_party/intel/language/intel/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def num_warps(_builder=None):
2525

2626
def convert_fp8e4b15_to_float16(arg, _builder):
2727
# Need to bitcast the source first because it's represented as tensor of i8 in MLIR.
28-
tmp_ty = _builder.get_block_ty(_builder.get_fp8e4m3b11fnuz_ty(), arg.type.shape)
28+
tmp_ty = _builder.get_block_ty(_builder.get_fp8e4b8_ty(), arg.type.shape)
2929
tmp = _builder.create_bitcast(arg.handle, tmp_ty)
3030
# Now generate FpToFp op for upcast.
3131
dst_ty = core.block_type(core.float16, arg.type.get_block_shapes())

third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,7 @@ struct FpToFpOpConversion
974974
std::pair<ConverterT, size_t>
975975
getConversionFunc(Type srcTy, Type dstTy,
976976
std::optional<RoundingMode> roundingMode) const {
977-
auto F8E4M3B15TyID = TypeID::get<Float8E4M3B11FNUZType>();
977+
auto F8E4M3B15TyID = TypeID::get<Float8E4M3FNUZType>();
978978
auto F8E4M3TyID = TypeID::get<Float8E4M3FNType>();
979979
auto F8E5M2TyID = TypeID::get<Float8E5M2Type>();
980980
auto F16TyID = TypeID::get<Float16Type>();

0 commit comments

Comments
 (0)