@@ -169,6 +169,7 @@ static spirv::ScalarType getIndexType(MLIRContext *ctx,
169169// SPIR-V dialect. Keeping it local till the use case arises.
170170static std::optional<int64_t >
171171getTypeNumBytes (const SPIRVConversionOptions &options, Type type) {
172+
172173 if (isa<spirv::ScalarType>(type)) {
173174 auto bitWidth = type.getIntOrFloatBitWidth ();
174175 // According to the SPIR-V spec:
@@ -182,6 +183,15 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
182183 return bitWidth / 8 ;
183184 }
184185
186+ // Handle 8-bit floats.
187+ if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
188+ auto bitWidth = type.getIntOrFloatBitWidth ();
189+ if (bitWidth == 8 )
190+ return bitWidth / 8 ;
191+ else
192+ return std::nullopt ;
193+ }
194+
185195 if (auto complexType = dyn_cast<ComplexType>(type)) {
186196 auto elementSize = getTypeNumBytes (options, complexType.getElementType ());
187197 if (!elementSize)
@@ -318,6 +328,67 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
318328 type.getSignedness ());
319329}
320330
331+ // / Converts 8-bit float types to integer types with the same bit width.
332+ // / Returns a nullptr for unsupported 8-bit float types.
333+ static Type convert8BitFloatType (const SPIRVConversionOptions &options,
334+ FloatType type) {
335+ if (!options.emulateUnsupportedFloatTypes )
336+ return nullptr ;
337+ // F8 types are converted to integer types with the same bit width.
338+ if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
339+ Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
340+ Float8E8M0FNUType>(type))
341+ return IntegerType::get (type.getContext (), type.getWidth ());
342+ LLVM_DEBUG (llvm::dbgs () << " unsupported 8-bit float type\n " );
343+ return nullptr ;
344+ }
345+
346+ // / Converts a sub-byte float ``type` to i32 regardless of target environment.
347+ // / Returns a nullptr for unsupported float types, including non sub-byte
348+ // / types.
349+ // /
350+ // / We are treating 8 bit floats as sub-byte types here due to it's similar
351+ // / nature of being used as a packed format.
352+
353+ // / Note that we don't recognize
354+ // / sub-byte types in `spirv::ScalarType` and use the above given that these
355+ // / sub-byte types are not supported at all in SPIR-V; there are no
356+ // / compute/storage capability for them like other supported integer types.
357+
358+ // static Type convertPackedFLoatType(const SPIRVConversionOptions &options,
359+ // FloatType type) {
360+
361+ // // F4, F6, F8 types are converted to integer types with the same bit width.
362+
363+ // if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
364+ // Float8E5M2FNUZType,
365+ // Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
366+ // Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
367+ // Float8E8M0FNUType>(type))
368+ // auto emulatedType = IntegerType::get(type.getContext(), type.getWidth());
369+
370+ // if (type.getWidth() > 8) {
371+ // LLVM_DEBUG(llvm::dbgs() << "not a packed type\n");
372+ // return nullptr;
373+ // }
374+ // if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
375+ // LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
376+ // return nullptr;
377+ // }
378+
379+ // // if (!llvm::isPowerOf2_32(type.getWidth())) {
380+ // // LLVM_DEBUG(llvm::dbgs()
381+ // // << "unsupported non-power-of-two bitwidth in sub-byte" <<
382+ // type
383+ // // << "\n");
384+ // // return nullptr;
385+ // // }
386+
387+ // LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
388+ // return IntegerType::get(type.getContext(), /*width=*/32,
389+ // type.getSignedness());
390+ // }
391+
321392// / Returns a type with the same shape but with any index element type converted
322393// / to the matching integer type. This is a noop when the element type is not
323394// / the index type.
@@ -339,8 +410,20 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
339410 type = cast<VectorType>(convertIndexElementType (type, options));
340411 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType ());
341412 if (!scalarType) {
342- // If this is not a spec allowed scalar type, try to handle sub-byte integer
343- // types.
413+ // If this is not a spec allowed scalar type, there are 2 scenarios,
414+ // 8 bit floats or sub-byte integer types. try to handle them accrodingly.
415+
416+ // Hnadle 8 bit float types.
417+ auto floatType = dyn_cast<FloatType>(type.getElementType ());
418+ if (floatType && floatType.getWidth () == 8 ) {
419+ // If this is an 8 bit float type, try to convert it to a supported
420+ // integer type.
421+ if (auto convertedType = convert8BitFloatType (options, floatType)) {
422+ return VectorType::get (type.getShape (), convertedType);
423+ }
424+ }
425+
426+ // Handle sub-byte integer types.
344427 auto intType = dyn_cast<IntegerType>(type.getElementType ());
345428 if (!intType) {
346429 LLVM_DEBUG (llvm::dbgs ()
@@ -596,6 +679,14 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
596679 } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
597680 type = cast<MemRefType>(convertIndexElementType (type, options));
598681 arrayElemType = type.getElementType ();
682+ } else if (auto floatType = dyn_cast<FloatType>(elementType)) {
683+ // Hnadle 8 bit float types.
684+ if (options.emulateUnsupportedFloatTypes && floatType &&
685+ floatType.getWidth () == 8 ) {
686+ // If this is an 8 bit float type, try to convert it to a supported
687+ // integer type.
688+ arrayElemType = convert8BitFloatType (options, floatType);
689+ }
599690 } else {
600691 LLVM_DEBUG (
601692 llvm::dbgs ()
@@ -1444,6 +1535,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
14441535 addConversion ([this ](FloatType floatType) -> std::optional<Type> {
14451536 if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
14461537 return convertScalarType (this ->targetEnv , this ->options , scalarType);
1538+ if (floatType.getWidth () == 8 )
1539+ return convert8BitFloatType (this ->options , floatType);
14471540 return Type ();
14481541 });
14491542
0 commit comments