@@ -182,6 +182,14 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
182182 return bitWidth / 8 ;
183183 }
184184
185+ // Handle 8-bit floats.
186+ if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
187+ auto bitWidth = type.getIntOrFloatBitWidth ();
188+ if (bitWidth == 8 )
189+ return bitWidth / 8 ;
190+ return std::nullopt ;
191+ }
192+
185193 if (auto complexType = dyn_cast<ComplexType>(type)) {
186194 auto elementSize = getTypeNumBytes (options, complexType.getElementType ());
187195 if (!elementSize)
@@ -318,6 +326,44 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
318326 type.getSignedness ());
319327}
320328
329+ // / Converts 8-bit float types to integer types with the same bit width.
330+ // / Returns a nullptr for unsupported 8-bit float types.
331+ static Type convert8BitFloatType (const SPIRVConversionOptions &options,
332+ FloatType type) {
333+ if (!options.emulateUnsupportedFloatTypes )
334+ return nullptr ;
335+ // F8 types are converted to integer types with the same bit width.
336+ if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
337+ Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
338+ Float8E8M0FNUType>(type))
339+ return IntegerType::get (type.getContext (), type.getWidth ());
340+ LLVM_DEBUG (llvm::dbgs () << " unsupported 8-bit float type: " << type << " \n " );
341+ return nullptr ;
342+ }
343+
344+ // / Returns a type with the same shape but with any 8-bit float element type
345+ // / converted to the same bit width integer type. This is a noop when the
346+ // / element type is not the 8-bit float type or emulation flag is set to false.
347+ static ShapedType
348+ convertShaped8BitFloatType (ShapedType type,
349+ const SPIRVConversionOptions &options) {
350+ if (!options.emulateUnsupportedFloatTypes )
351+ return type;
352+ Type srcElementType = type.getElementType ();
353+ Type convertedElementType = nullptr ;
354+ // F8 types are converted to integer types with the same bit width.
355+ if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
356+ Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
357+ Float8E8M0FNUType>(srcElementType))
358+ convertedElementType = IntegerType::get (
359+ type.getContext (), srcElementType.getIntOrFloatBitWidth ());
360+
361+ if (!convertedElementType)
362+ return type;
363+
364+ return type.clone (convertedElementType);
365+ }
366+
321367// / Returns a type with the same shape but with any index element type converted
322368// / to the matching integer type. This is a noop when the element type is not
323369// / the index type.
@@ -337,6 +383,7 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
337383 const SPIRVConversionOptions &options, VectorType type,
338384 std::optional<spirv::StorageClass> storageClass = {}) {
339385 type = cast<VectorType>(convertIndexElementType (type, options));
386+ type = cast<VectorType>(convertShaped8BitFloatType (type, options));
340387 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType ());
341388 if (!scalarType) {
342389 // If this is not a spec allowed scalar type, try to handle sub-byte integer
@@ -433,6 +480,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
433480 }
434481
435482 type = cast<TensorType>(convertIndexElementType (type, options));
483+ type = cast<TensorType>(convertShaped8BitFloatType (type, options));
436484 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType ());
437485 if (!scalarType) {
438486 LLVM_DEBUG (llvm::dbgs ()
@@ -596,6 +644,10 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
596644 } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
597645 type = cast<MemRefType>(convertIndexElementType (type, options));
598646 arrayElemType = type.getElementType ();
647+ } else if (auto floatType = dyn_cast<FloatType>(elementType)) {
648+ // Hnadle 8 bit float types.
649+ type = cast<MemRefType>(convertShaped8BitFloatType (type, options));
650+ arrayElemType = type.getElementType ();
599651 } else {
600652 LLVM_DEBUG (
601653 llvm::dbgs ()
@@ -1444,6 +1496,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
14441496 addConversion ([this ](FloatType floatType) -> std::optional<Type> {
14451497 if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
14461498 return convertScalarType (this ->targetEnv , this ->options , scalarType);
1499+ if (floatType.getWidth () == 8 )
1500+ return convert8BitFloatType (this ->options , floatType);
14471501 return Type ();
14481502 });
14491503
0 commit comments