Skip to content

Commit 4b6d70d

Browse files
committed
Add 8-bit float emulation for SPIR-V conversion.
SPIR-V does not support any 8-bit floats. Threfore, 8-bit floats are emulated as 8-bit integers.
1 parent 254b90f commit 4b6d70d

File tree

6 files changed

+120
-5
lines changed

6 files changed

+120
-5
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,9 @@ def ConvertArithToSPIRVPass : Pass<"convert-arith-to-spirv"> {
196196
"bool", /*default=*/"true",
197197
"Emulate narrower scalar types with 32-bit ones if not supported by "
198198
"the target">,
199+
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
200+
"bool", /*default=*/"true",
201+
"Emulate unsupported float types by emulating them with integer types of same bit width">
199202
];
200203
}
201204

@@ -416,7 +419,10 @@ def ConvertControlFlowToSPIRVPass : Pass<"convert-cf-to-spirv"> {
416419
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
417420
"bool", /*default=*/"true",
418421
"Emulate narrower scalar types with 32-bit ones if not supported by"
419-
" the target">
422+
" the target">,
423+
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
424+
"bool", /*default=*/"true",
425+
"Emulate unsupported float types by emulating them with integer types of same bit width">
420426
];
421427
}
422428

@@ -500,7 +506,10 @@ def ConvertFuncToSPIRVPass : Pass<"convert-func-to-spirv"> {
500506
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
501507
"bool", /*default=*/"true",
502508
"Emulate narrower scalar types with 32-bit ones if not supported by"
503-
" the target">
509+
" the target">,
510+
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
511+
"bool", /*default=*/"true",
512+
"Emulate unsupported float types by emulating them with integer types of same bit width">
504513
];
505514
}
506515

@@ -1167,7 +1176,10 @@ def ConvertTensorToSPIRVPass : Pass<"convert-tensor-to-spirv"> {
11671176
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
11681177
"bool", /*default=*/"true",
11691178
"Emulate narrower scalar types with 32-bit ones if not supported by"
1170-
" the target">
1179+
" the target">,
1180+
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
1181+
"bool", /*default=*/"true",
1182+
"Emulate unsupported float types by emulating them with integer types of same bit width">
11711183
];
11721184
}
11731185

mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ struct SPIRVConversionOptions {
3939
/// The number of bits to store a boolean value.
4040
unsigned boolNumBits{8};
4141

42+
/// Whether to emulate unsupported floats with integer types of same bit
43+
/// width.
44+
bool emulateUnsupportedFloatTypes{true};
45+
4246
/// How sub-byte values are storaged in memory.
4347
SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed};
4448

mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
4343

4444
SPIRVConversionOptions options;
4545
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
46+
options.emulateUnsupportedFloatTypes =
47+
this->emulateUnsupportedFloatTypes;
4648
SPIRVTypeConverter typeConverter(targetAttr, options);
4749

4850
// TODO: We should also take care of block argument type conversion.

mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
4242

4343
SPIRVConversionOptions options;
4444
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
45+
options.emulateUnsupportedFloatTypes =
46+
this->emulateUnsupportedFloatTypes;
4547
SPIRVTypeConverter typeConverter(targetAttr, options);
4648

4749
RewritePatternSet patterns(context);

mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ class ConvertTensorToSPIRVPass
4141

4242
SPIRVConversionOptions options;
4343
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
44+
options.emulateUnsupportedFloatTypes =
45+
this->emulateUnsupportedFloatTypes;
4446
SPIRVTypeConverter typeConverter(targetAttr, options);
4547

4648
RewritePatternSet patterns(context);

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ static spirv::ScalarType getIndexType(MLIRContext *ctx,
169169
// SPIR-V dialect. Keeping it local till the use case arises.
170170
static std::optional<int64_t>
171171
getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
172+
172173
if (isa<spirv::ScalarType>(type)) {
173174
auto bitWidth = type.getIntOrFloatBitWidth();
174175
// According to the SPIR-V spec:
@@ -182,6 +183,15 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
182183
return bitWidth / 8;
183184
}
184185

186+
// Handle 8-bit floats.
187+
if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
188+
auto bitWidth = type.getIntOrFloatBitWidth();
189+
if (bitWidth == 8)
190+
return bitWidth / 8;
191+
else
192+
return std::nullopt;
193+
}
194+
185195
if (auto complexType = dyn_cast<ComplexType>(type)) {
186196
auto elementSize = getTypeNumBytes(options, complexType.getElementType());
187197
if (!elementSize)
@@ -318,6 +328,67 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
318328
type.getSignedness());
319329
}
320330

331+
/// Converts 8-bit float types to integer types with the same bit width.
332+
/// Returns a nullptr for unsupported 8-bit float types.
333+
static Type convert8BitFloatType(const SPIRVConversionOptions &options,
334+
FloatType type) {
335+
if (!options.emulateUnsupportedFloatTypes)
336+
return nullptr;
337+
// F8 types are converted to integer types with the same bit width.
338+
if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
339+
Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
340+
Float8E8M0FNUType>(type))
341+
return IntegerType::get(type.getContext(), type.getWidth());
342+
LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type\n");
343+
return nullptr;
344+
}
345+
346+
/// Converts a sub-byte float ``type` to i32 regardless of target environment.
347+
/// Returns a nullptr for unsupported float types, including non sub-byte
348+
/// types.
349+
///
350+
/// We are treating 8 bit floats as sub-byte types here due to it's similar
351+
/// nature of being used as a packed format.
352+
353+
/// Note that we don't recognize
354+
/// sub-byte types in `spirv::ScalarType` and use the above given that these
355+
/// sub-byte types are not supported at all in SPIR-V; there are no
356+
/// compute/storage capability for them like other supported integer types.
357+
358+
// static Type convertPackedFLoatType(const SPIRVConversionOptions &options,
359+
// FloatType type) {
360+
361+
// // F4, F6, F8 types are converted to integer types with the same bit width.
362+
363+
// if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
364+
// Float8E5M2FNUZType,
365+
// Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
366+
// Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
367+
// Float8E8M0FNUType>(type))
368+
// auto emulatedType = IntegerType::get(type.getContext(), type.getWidth());
369+
370+
// if (type.getWidth() > 8) {
371+
// LLVM_DEBUG(llvm::dbgs() << "not a packed type\n");
372+
// return nullptr;
373+
// }
374+
// if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
375+
// LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
376+
// return nullptr;
377+
// }
378+
379+
// // if (!llvm::isPowerOf2_32(type.getWidth())) {
380+
// // LLVM_DEBUG(llvm::dbgs()
381+
// // << "unsupported non-power-of-two bitwidth in sub-byte" <<
382+
// type
383+
// // << "\n");
384+
// // return nullptr;
385+
// // }
386+
387+
// LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
388+
// return IntegerType::get(type.getContext(), /*width=*/32,
389+
// type.getSignedness());
390+
// }
391+
321392
/// Returns a type with the same shape but with any index element type converted
322393
/// to the matching integer type. This is a noop when the element type is not
323394
/// the index type.
@@ -339,8 +410,20 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
339410
type = cast<VectorType>(convertIndexElementType(type, options));
340411
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
341412
if (!scalarType) {
342-
// If this is not a spec allowed scalar type, try to handle sub-byte integer
343-
// types.
413+
// If this is not a spec allowed scalar type, there are 2 scenarios,
414+
// 8 bit floats or sub-byte integer types. try to handle them accrodingly.
415+
416+
// Hnadle 8 bit float types.
417+
auto floatType = dyn_cast<FloatType>(type.getElementType());
418+
if (floatType && floatType.getWidth() == 8) {
419+
// If this is an 8 bit float type, try to convert it to a supported
420+
// integer type.
421+
if (auto convertedType = convert8BitFloatType(options, floatType)) {
422+
return VectorType::get(type.getShape(), convertedType);
423+
}
424+
}
425+
426+
// Handle sub-byte integer types.
344427
auto intType = dyn_cast<IntegerType>(type.getElementType());
345428
if (!intType) {
346429
LLVM_DEBUG(llvm::dbgs()
@@ -596,6 +679,14 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
596679
} else if (auto indexType = dyn_cast<IndexType>(elementType)) {
597680
type = cast<MemRefType>(convertIndexElementType(type, options));
598681
arrayElemType = type.getElementType();
682+
} else if (auto floatType = dyn_cast<FloatType>(elementType)) {
683+
// Hnadle 8 bit float types.
684+
if (options.emulateUnsupportedFloatTypes && floatType &&
685+
floatType.getWidth() == 8) {
686+
// If this is an 8 bit float type, try to convert it to a supported
687+
// integer type.
688+
arrayElemType = convert8BitFloatType(options, floatType);
689+
}
599690
} else {
600691
LLVM_DEBUG(
601692
llvm::dbgs()
@@ -1444,6 +1535,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
14441535
addConversion([this](FloatType floatType) -> std::optional<Type> {
14451536
if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
14461537
return convertScalarType(this->targetEnv, this->options, scalarType);
1538+
if (floatType.getWidth() == 8)
1539+
return convert8BitFloatType(this->options, floatType);
14471540
return Type();
14481541
});
14491542

0 commit comments

Comments
 (0)