@@ -343,6 +343,29 @@ static Type convert8BitFloatType(const SPIRVConversionOptions &options,
343343 return nullptr ;
344344}
345345
346+ // / Returns a type with the same shape but with any 8-bit float element type
347+ // / converted to the same bit width integer type. This is a noop when the
348+ // / element type is not the 8-bit float type.
349+ static ShapedType
350+ convertShaped8BitFloatType (ShapedType type,
351+ const SPIRVConversionOptions &options) {
352+ if (!options.emulateUnsupportedFloatTypes )
353+ return nullptr ;
354+ auto srcElementType = type.getElementType ();
355+ Type convertedElementType = nullptr ;
356+ // F8 types are converted to integer types with the same bit width.
357+ if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
358+ Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
359+ Float8E8M0FNUType>(srcElementType))
360+ convertedElementType = IntegerType::get (
361+ type.getContext (), srcElementType.getIntOrFloatBitWidth ());
362+
363+ if (!convertedElementType)
364+ return type;
365+
366+ return type.clone (convertedElementType);
367+ }
368+
346369// / Converts a sub-byte float ``type` to i32 regardless of target environment.
347370// / Returns a nullptr for unsupported float types, including non sub-byte
348371// / types.
@@ -408,22 +431,11 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
408431 const SPIRVConversionOptions &options, VectorType type,
409432 std::optional<spirv::StorageClass> storageClass = {}) {
410433 type = cast<VectorType>(convertIndexElementType (type, options));
434+ type = cast<VectorType>(convertShaped8BitFloatType (type, options));
411435 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType ());
412436 if (!scalarType) {
413- // If this is not a spec allowed scalar type, there are 2 scenarios,
414- // 8 bit floats or sub-byte integer types. try to handle them accrodingly.
415-
416- // Hnadle 8 bit float types.
417- auto floatType = dyn_cast<FloatType>(type.getElementType ());
418- if (floatType && floatType.getWidth () == 8 ) {
419- // If this is an 8 bit float type, try to convert it to a supported
420- // integer type.
421- if (auto convertedType = convert8BitFloatType (options, floatType)) {
422- return VectorType::get (type.getShape (), convertedType);
423- }
424- }
425-
426- // Handle sub-byte integer types.
437+ // If this is not a spec allowed scalar type, try to handle sub-byte integer
438+ // types.
427439 auto intType = dyn_cast<IntegerType>(type.getElementType ());
428440 if (!intType) {
429441 LLVM_DEBUG (llvm::dbgs ()
@@ -516,6 +528,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
516528 }
517529
518530 type = cast<TensorType>(convertIndexElementType (type, options));
531+ type = cast<TensorType>(convertShaped8BitFloatType (type, options));
519532 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType ());
520533 if (!scalarType) {
521534 LLVM_DEBUG (llvm::dbgs ()
@@ -681,12 +694,14 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
681694 arrayElemType = type.getElementType ();
682695 } else if (auto floatType = dyn_cast<FloatType>(elementType)) {
683696 // Hnadle 8 bit float types.
684- if (options.emulateUnsupportedFloatTypes && floatType &&
685- floatType.getWidth () == 8 ) {
686- // If this is an 8 bit float type, try to convert it to a supported
687- // integer type.
688- arrayElemType = convert8BitFloatType (options, floatType);
689- }
697+ type = cast<MemRefType>(convertShaped8BitFloatType (type, options));
698+ arrayElemType = type.getElementType ();
699+ // if (options.emulateUnsupportedFloatTypes && floatType &&
700+ // floatType.getWidth() == 8) {
701+ // // If this is an 8 bit float type, try to convert it to a supported
702+ // // integer type.
703+ // arrayElemType = convert8BitFloatType(options, floatType);
704+ // }
690705 } else {
691706 LLVM_DEBUG (
692707 llvm::dbgs ()
0 commit comments