Skip to content

Commit 6a3d341

Browse files
committed
Handle all Shaped Type 8-bit floats in a similar way.
This approach minimizes the code modification.
1 parent 4bdd204 commit 6a3d341

File tree

1 file changed

+35
-20
lines changed

1 file changed

+35
-20
lines changed

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,29 @@ static Type convert8BitFloatType(const SPIRVConversionOptions &options,
343343
return nullptr;
344344
}
345345

346+
/// Returns a type with the same shape but with any 8-bit float element type
347+
/// converted to the same bit width integer type. This is a noop when the
348+
/// element type is not the 8-bit float type.
349+
static ShapedType
350+
convertShaped8BitFloatType(ShapedType type,
351+
const SPIRVConversionOptions &options) {
352+
if (!options.emulateUnsupportedFloatTypes)
353+
return nullptr;
354+
auto srcElementType = type.getElementType();
355+
Type convertedElementType = nullptr;
356+
// F8 types are converted to integer types with the same bit width.
357+
if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
358+
Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
359+
Float8E8M0FNUType>(srcElementType))
360+
convertedElementType = IntegerType::get(
361+
type.getContext(), srcElementType.getIntOrFloatBitWidth());
362+
363+
if (!convertedElementType)
364+
return type;
365+
366+
return type.clone(convertedElementType);
367+
}
368+
346369
/// Converts a sub-byte float ``type` to i32 regardless of target environment.
347370
/// Returns a nullptr for unsupported float types, including non sub-byte
348371
/// types.
@@ -408,22 +431,11 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
408431
const SPIRVConversionOptions &options, VectorType type,
409432
std::optional<spirv::StorageClass> storageClass = {}) {
410433
type = cast<VectorType>(convertIndexElementType(type, options));
434+
type = cast<VectorType>(convertShaped8BitFloatType(type, options));
411435
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
412436
if (!scalarType) {
413-
// If this is not a spec allowed scalar type, there are 2 scenarios,
414-
// 8 bit floats or sub-byte integer types. try to handle them accrodingly.
415-
416-
// Hnadle 8 bit float types.
417-
auto floatType = dyn_cast<FloatType>(type.getElementType());
418-
if (floatType && floatType.getWidth() == 8) {
419-
// If this is an 8 bit float type, try to convert it to a supported
420-
// integer type.
421-
if (auto convertedType = convert8BitFloatType(options, floatType)) {
422-
return VectorType::get(type.getShape(), convertedType);
423-
}
424-
}
425-
426-
// Handle sub-byte integer types.
437+
// If this is not a spec allowed scalar type, try to handle sub-byte integer
438+
// types.
427439
auto intType = dyn_cast<IntegerType>(type.getElementType());
428440
if (!intType) {
429441
LLVM_DEBUG(llvm::dbgs()
@@ -516,6 +528,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
516528
}
517529

518530
type = cast<TensorType>(convertIndexElementType(type, options));
531+
type = cast<TensorType>(convertShaped8BitFloatType(type, options));
519532
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
520533
if (!scalarType) {
521534
LLVM_DEBUG(llvm::dbgs()
@@ -681,12 +694,14 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
681694
arrayElemType = type.getElementType();
682695
} else if (auto floatType = dyn_cast<FloatType>(elementType)) {
683696
// Hnadle 8 bit float types.
684-
if (options.emulateUnsupportedFloatTypes && floatType &&
685-
floatType.getWidth() == 8) {
686-
// If this is an 8 bit float type, try to convert it to a supported
687-
// integer type.
688-
arrayElemType = convert8BitFloatType(options, floatType);
689-
}
697+
type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
698+
arrayElemType = type.getElementType();
699+
// if (options.emulateUnsupportedFloatTypes && floatType &&
700+
// floatType.getWidth() == 8) {
701+
// // If this is an 8 bit float type, try to convert it to a supported
702+
// // integer type.
703+
// arrayElemType = convert8BitFloatType(options, floatType);
704+
// }
690705
} else {
691706
LLVM_DEBUG(
692707
llvm::dbgs()

0 commit comments

Comments
 (0)