Skip to content

Commit a433d4c

Browse files
committed
choose Float8E4M3FNUZType as an equivalent
1 parent a4d8c5d commit a433d4c

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

third_party/intel/language/intel/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def num_warps(_builder=None):
2525

2626
def convert_fp8e4b15_to_float16(arg, _builder):
2727
# Need to bitcast the source first because it's represented as tensor of i8 in MLIR.
28-
tmp_ty = _builder.get_block_ty(_builder.get_fp8e4nv_ty(), arg.type.shape)
28+
tmp_ty = _builder.get_block_ty(_builder.get_fp8e4b8_ty(), arg.type.shape)
2929
tmp = _builder.create_bitcast(arg.handle, tmp_ty)
3030
# Now generate FpToFp op for upcast.
3131
dst_ty = core.block_type(core.float16, arg.type.get_block_shapes())

third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,7 @@ struct FpToFpOpConversion
974974
std::pair<ConverterT, size_t>
975975
getConversionFunc(Type srcTy, Type dstTy,
976976
std::optional<RoundingMode> roundingMode) const {
977-
auto F8E4M3B15TyID = TypeID::get<Float8E4M3B11FNUZType>();
977+
auto F8E4M3B15TyID = TypeID::get<Float8E4M3FNUZType>();
978978
auto F8E4M3TyID = TypeID::get<Float8E4M3FNType>();
979979
auto F8E5M2TyID = TypeID::get<Float8E5M2Type>();
980980
auto F16TyID = TypeID::get<Float16Type>();

0 commit comments

Comments
 (0)