Skip to content

Commit 4248d7c

Browse files
committed
Add Intel downstream changes and address review comments
1 parent 3ed46d2 commit 4248d7c

File tree

5 files changed

+39
-15
lines changed

5 files changed

+39
-15
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4190,11 +4190,12 @@ def SPIRV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
41904190
def SPIRV_Int16 : TypeAlias<I16, "Int16">;
41914191
def SPIRV_Int32 : TypeAlias<I32, "Int32">;
41924192
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
4193+
def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
41934194
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
41944195
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
4195-
def SPIRV_FloatOrBFloat16 : AnyTypeOf<[SPIRV_Float, BF16]>;
4196+
def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_Float, SPIRV_BFloat16KHR]>;
41964197
def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
4197-
[SPIRV_Bool, SPIRV_Integer, SPIRV_FloatOrBFloat16]>;
4198+
[SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat]>;
41984199
// Component type check is done in the type parser for the following SPIR-V
41994200
// dialect-specific types so we use "Any" here.
42004201
def SPIRV_AnyPtr : DialectType<SPIRV_Dialect, SPIRV_IsPtrType,
@@ -4217,14 +4218,14 @@ def SPIRV_AnyStruct : DialectType<SPIRV_Dialect, SPIRV_IsStructType,
42174218
def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
42184219
"any SPIR-V sampled image type">;
42194220

4220-
def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_Float]>;
4221+
def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
42214222
def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
42224223
def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
42234224
def SPIRV_Composite :
42244225
AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
42254226
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix]>;
42264227
def SPIRV_Type : AnyTypeOf<[
4227-
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_FloatOrBFloat16, SPIRV_Vector,
4228+
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
42284229
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
42294230
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
42304231
]>;

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {
8686

8787
// -----
8888

89-
def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_FloatOrBFloat16, []> {
89+
def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_AnyFloat, []> {
9090
let summary = [{
9191
Convert value numerically from floating point to signed integer, with
9292
round toward 0.0.
@@ -111,7 +111,7 @@ def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float
111111

112112
// -----
113113

114-
def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_FloatOrBFloat16, []> {
114+
def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_AnyFloat, []> {
115115
let summary = [{
116116
Convert value numerically from floating point to unsigned integer, with
117117
round toward 0.0.
@@ -138,7 +138,7 @@ def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float
138138
// -----
139139

140140
def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
141-
SPIRV_FloatOrBFloat16,
141+
SPIRV_AnyFloat,
142142
SPIRV_Integer,
143143
[SignedOp]> {
144144
let summary = [{
@@ -165,7 +165,7 @@ def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
165165
// -----
166166

167167
def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
168-
SPIRV_FloatOrBFloat16,
168+
SPIRV_AnyFloat,
169169
SPIRV_Integer,
170170
[UnsignedOp]> {
171171
let summary = [{
@@ -192,8 +192,8 @@ def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
192192
// -----
193193

194194
def SPIRV_FConvertOp : SPIRV_CastOp<"FConvert",
195-
SPIRV_FloatOrBFloat16,
196-
SPIRV_FloatOrBFloat16,
195+
SPIRV_AnyFloat,
196+
SPIRV_AnyFloat,
197197
[UsableInSpecConstantOp]> {
198198
let summary = [{
199199
Convert value numerically from one floating-point width to another

mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,12 @@ bool ScalarType::isValid(IntegerType type) {
514514

515515
void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
516516
std::optional<StorageClass> storage) {
517+
if (llvm::isa<BFloat16Type>(*this)) {
518+
static const Extension exts[] = {Extension::SPV_KHR_bfloat16};
519+
ArrayRef<Extension> ref(exts, std::size(exts));
520+
extensions.push_back(ref);
521+
}
522+
517523
// 8- or 16-bit integer/floating-point numbers will require extra extensions
518524
// to appear in interface storage classes. See SPV_KHR_16bit_storage and
519525
// SPV_KHR_8bit_storage for more details.
@@ -532,7 +538,7 @@ void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
532538
[[fallthrough]];
533539
case StorageClass::Input:
534540
case StorageClass::Output:
535-
if (getIntOrFloatBitWidth() == 16) {
541+
if (getIntOrFloatBitWidth() == 16 && !llvm::isa<BFloat16Type>(*this)) {
536542
static const Extension exts[] = {Extension::SPV_KHR_16bit_storage};
537543
ArrayRef<Extension> ref(exts, std::size(exts));
538544
extensions.push_back(ref);
@@ -619,7 +625,19 @@ void ScalarType::getCapabilities(
619625
} else {
620626
assert(llvm::isa<FloatType>(*this));
621627
switch (bitwidth) {
622-
WIDTH_CASE(Float, 16);
628+
case 16: {
629+
if (llvm::isa<BFloat16Type>(*this)) {
630+
static const Capability caps[] = {Capability::BFloat16TypeKHR};
631+
ArrayRef<Capability> ref(caps, std::size(caps));
632+
capabilities.push_back(ref);
633+
634+
} else {
635+
static const Capability caps[] = {Capability::Float16};
636+
ArrayRef<Capability> ref(caps, std::size(caps));
637+
capabilities.push_back(ref);
638+
}
639+
break;
640+
}
623641
WIDTH_CASE(Float, 64);
624642
case 32:
625643
break;

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1413,6 +1413,9 @@ LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
14131413
} else if (floatType.isF16()) {
14141414
APInt data(16, operands[2]);
14151415
value = APFloat(APFloat::IEEEhalf(), data);
1416+
} else if (floatType.isBF16()) {
1417+
APInt data(16, operands[2]);
1418+
value = APFloat(APFloat::BFloat(), data);
14161419
}
14171420

14181421
auto attr = opBuilder.getFloatAttr(floatType, value);

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -999,21 +999,23 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
999999

10001000
auto resultID = getNextID();
10011001
APFloat value = floatAttr.getValue();
1002+
const llvm::fltSemantics *semantics = &value.getSemantics();
10021003

10031004
auto opcode =
10041005
isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
10051006

1006-
if (&value.getSemantics() == &APFloat::IEEEsingle()) {
1007+
if (semantics == &APFloat::IEEEsingle()) {
10071008
uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
10081009
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1009-
} else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
1010+
} else if (semantics == &APFloat::IEEEdouble()) {
10101011
struct DoubleWord {
10111012
uint32_t word1;
10121013
uint32_t word2;
10131014
} words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
10141015
encodeInstructionInto(typesGlobalValues, opcode,
10151016
{typeID, resultID, words.word1, words.word2});
1016-
} else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
1017+
} else if (semantics == &APFloat::IEEEhalf() ||
1018+
semantics == &APFloat::BFloat()) {
10171019
uint32_t word =
10181020
static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
10191021
encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});

0 commit comments

Comments
 (0)