Skip to content

Commit ee45972

Browse files
author
Menooker
authored
Merge pull request #4 from dchigarev/llvm_upd
Update LLVM
2 parents f1140b7 + 2a9013d commit ee45972

File tree

8 files changed

+54
-56
lines changed

8 files changed

+54
-56
lines changed

build_tools/llvm_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
08a61eb01172054fc5f8c78ff527f01d9768569b
1+
f06563a5c0d239a6b98f74db522681613254ad08

build_tools/patches/0001-Add-support-for-VectorAnyINTEL-capability.patch

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ requirement initially, then do the check for capability inferred extension.
4545
17 files changed, 319 insertions(+), 68 deletions(-)
4646

4747
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
48-
index 6ec97e17c5dc..75e42c024553 100644
48+
index af0b2624feb3..b6e80c3f9516 100644
4949
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
5050
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
51-
@@ -4138,7 +4138,12 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
51+
@@ -4142,7 +4142,12 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
5252
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
5353
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
5454
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
@@ -62,8 +62,8 @@ index 6ec97e17c5dc..75e42c024553 100644
6262
[SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
6363
// Component type check is done in the type parser for the following SPIR-V
6464
// dialect-specific types so we use "Any" here.
65-
@@ -4189,7 +4194,7 @@ class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
66-
"Joint Matrix">;
65+
@@ -4185,7 +4190,7 @@ class SPIRV_CoopMatrixOfType<list<Type> allowedTypes> :
66+
"Cooperative Matrix">;
6767

6868
class SPIRV_VectorOf<Type type> :
6969
- VectorOfLengthAndType<[2, 3, 4, 8,16], [type]>;
@@ -764,4 +764,4 @@ index 9a2e4cf62e37..31a7f616d648 100644
764764
// CHECK: spirv.CL.fma {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32
765765
%13 = spirv.CL.fma %arg0, %arg1, %arg2 : f32
766766
--
767-
2.34.1
767+
2.34.1

build_tools/patches/0009-SPIR-V-Enable-native-bf16-support-in-SPIR-V-dialect.patch

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -21,63 +21,63 @@ index 22d5afcd7738..de9e11493793 100644
2121
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
2222
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
2323
@@ -82,7 +82,7 @@ class SPIRV_ArithmeticExtendedBinaryOp<string mnemonic,
24-
24+
2525
// -----
26-
26+
2727
-def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]> {
2828
+def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_AnyFloat, [Commutative]> {
2929
let summary = "Floating-point addition of Operand 1 and Operand 2.";
30-
30+
3131
let description = [{
3232
@@ -104,7 +104,7 @@ def SPIRV_FAddOp : SPIRV_ArithmeticBinaryOp<"FAdd", SPIRV_Float, [Commutative]>
33-
33+
3434
// -----
35-
35+
3636
-def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOp<"FDiv", SPIRV_Float, []> {
3737
+def SPIRV_FDivOp : SPIRV_ArithmeticBinaryOp<"FDiv", SPIRV_AnyFloat, []> {
3838
let summary = "Floating-point division of Operand 1 divided by Operand 2.";
39-
39+
4040
let description = [{
4141
@@ -154,7 +154,7 @@ def SPIRV_FModOp : SPIRV_ArithmeticBinaryOp<"FMod", SPIRV_Float, []> {
42-
42+
4343
// -----
44-
44+
4545
-def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOp<"FMul", SPIRV_Float, [Commutative]> {
4646
+def SPIRV_FMulOp : SPIRV_ArithmeticBinaryOp<"FMul", SPIRV_AnyFloat, [Commutative]> {
4747
let summary = "Floating-point multiplication of Operand 1 and Operand 2.";
48-
48+
4949
let description = [{
5050
@@ -229,7 +229,7 @@ def SPIRV_FRemOp : SPIRV_ArithmeticBinaryOp<"FRem", SPIRV_Float, []> {
51-
51+
5252
// -----
53-
53+
5454
-def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_Float, []> {
5555
+def SPIRV_FSubOp : SPIRV_ArithmeticBinaryOp<"FSub", SPIRV_AnyFloat, []> {
5656
let summary = "Floating-point subtraction of Operand 2 from Operand 1.";
57-
57+
5858
let description = [{
5959
@@ -450,7 +450,7 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
6060
);
61-
61+
6262
let results = (outs
6363
- SPIRV_Float:$result
6464
+ SPIRV_AnyFloat:$result
6565
);
66-
66+
6767
let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";
6868
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
69-
index 04952dd1dc61..6c9c348490ab 100644
69+
index ddaeb13ef253..336bdcfb7a48 100644
7070
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
7171
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
7272
@@ -343,6 +343,7 @@ def SPV_KHR_subgroup_rotate : I32EnumAttrCase<"SPV_KHR_subgroup
7373
def SPV_KHR_non_semantic_info : I32EnumAttrCase<"SPV_KHR_non_semantic_info", 29>;
7474
def SPV_KHR_terminate_invocation : I32EnumAttrCase<"SPV_KHR_terminate_invocation", 30>;
7575
def SPV_KHR_cooperative_matrix : I32EnumAttrCase<"SPV_KHR_cooperative_matrix", 31>;
7676
+def SPV_KHR_bfloat16 : I32EnumAttrCase<"SPV_KHR_bfloat16", 32>;
77-
77+
7878
def SPV_EXT_demote_to_helper_invocation : I32EnumAttrCase<"SPV_EXT_demote_to_helper_invocation", 1000>;
7979
def SPV_EXT_descriptor_indexing : I32EnumAttrCase<"SPV_EXT_descriptor_indexing", 1001>;
80-
@@ -435,7 +436,7 @@ def SPIRV_ExtensionAttr :
80+
@@ -434,7 +435,7 @@ def SPIRV_ExtensionAttr :
8181
SPV_KHR_fragment_shader_barycentric, SPV_KHR_ray_cull_mask,
8282
SPV_KHR_uniform_group_instructions, SPV_KHR_subgroup_rotate,
8383
SPV_KHR_non_semantic_info, SPV_KHR_terminate_invocation,
@@ -86,7 +86,7 @@ index 04952dd1dc61..6c9c348490ab 100644
8686
SPV_EXT_demote_to_helper_invocation, SPV_EXT_descriptor_indexing,
8787
SPV_EXT_fragment_fully_covered, SPV_EXT_fragment_invocation_density,
8888
SPV_EXT_fragment_shader_interlock, SPV_EXT_physical_storage_buffer,
89-
@@ -1193,6 +1194,22 @@ def SPIRV_C_ShaderClockKHR : I32EnumAttrCase<"Shade
89+
@@ -1192,6 +1193,24 @@ def SPIRV_C_ShaderClockKHR : I32EnumAttrCase<"Shade
9090
Extension<[SPV_KHR_shader_clock]>
9191
];
9292
}
@@ -95,11 +95,13 @@ index 04952dd1dc61..6c9c348490ab 100644
9595
+ Extension<[SPV_KHR_bfloat16]>
9696
+ ];
9797
+}
98+
+
9899
+def SPIRV_C_BFloat16DotProductKHR : I32EnumAttrCase<"BFloat16DotProductKHR", 5117> {
99100
+ list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR];
100101
+ list<Availability> availability = [
101102
+ Extension<[SPV_KHR_bfloat16]> ];
102103
+}
104+
+
103105
+def SPIRV_C_BFloat16CooperativeMatrixKHR : I32EnumAttrCase<"BFloat16CooperativeMatrixKHR", 5118> {
104106
+ list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR, SPIRV_C_CooperativeMatrixKHR];
105107
+ list<Availability> availability = [
@@ -109,15 +111,15 @@ index 04952dd1dc61..6c9c348490ab 100644
109111
def SPIRV_C_FragmentFullyCoveredEXT : I32EnumAttrCase<"FragmentFullyCoveredEXT", 5265> {
110112
list<I32EnumAttrCase> implies = [SPIRV_C_Shader];
111113
list<Availability> availability = [
112-
@@ -1491,6 +1508,7 @@ def SPIRV_CapabilityAttr :
114+
@@ -1484,6 +1503,7 @@ def SPIRV_CapabilityAttr :
113115
SPIRV_C_RayQueryKHR, SPIRV_C_RayTracingKHR, SPIRV_C_Float16ImageAMD,
114116
SPIRV_C_ImageGatherBiasLodAMD, SPIRV_C_FragmentMaskAMD, SPIRV_C_StencilExportEXT,
115117
SPIRV_C_ImageReadWriteLodAMD, SPIRV_C_Int64ImageEXT, SPIRV_C_ShaderClockKHR,
116118
+ SPIRV_C_BFloat16TypeKHR, SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
117119
SPIRV_C_FragmentFullyCoveredEXT, SPIRV_C_MeshShadingNV, SPIRV_C_FragmentDensityEXT,
118120
SPIRV_C_ShaderNonUniform, SPIRV_C_RuntimeDescriptorArray,
119121
SPIRV_C_StorageTexelBufferArrayDynamicIndexing, SPIRV_C_RayTracingNV,
120-
@@ -4148,16 +4166,21 @@ def SPIRV_Bool : TypeAlias<I1, "bool">;
122+
@@ -4139,16 +4159,21 @@ def SPIRV_Bool : TypeAlias<I1, "bool">;
121123
def SPIRV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
122124
def SPIRV_Int16 : TypeAlias<I16, "Int16">;
123125
def SPIRV_Int32 : TypeAlias<I32, "Int32">;
@@ -142,36 +144,34 @@ index 04952dd1dc61..6c9c348490ab 100644
142144
// Component type check is done in the type parser for the following SPIR-V
143145
// dialect-specific types so we use "Any" here.
144146
def SPIRV_AnyPtr : DialectType<SPIRV_Dialect, SPIRV_IsPtrType,
145-
@@ -4180,14 +4203,14 @@ def SPIRV_AnyStruct : DialectType<SPIRV_Dialect, SPIRV_IsStructType,
147+
@@ -4169,14 +4194,14 @@ def SPIRV_AnyStruct : DialectType<SPIRV_Dialect, SPIRV_IsStructType,
146148
def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
147149
"any SPIR-V sampled image type">;
148-
150+
149151
-def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_Float]>;
150152
+def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_Float, SPIRV_BFloat16KHR]>;
151153
def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
152154
def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
153155
def SPIRV_Composite :
154156
AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
155-
SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>;
157+
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix]>;
156158
def SPIRV_Type : AnyTypeOf<[
157159
- SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
158160
+ SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_BFloat16KHR, SPIRV_Vector,
159161
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
160-
SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix,
161-
SPIRV_AnySampledImage
162-
@@ -4764,6 +4787,12 @@ def SPIRV_FPFMM_AllowReassocINTEL : I32BitEnumAttrCaseBit<"AllowReassocINTEL", 1
163-
];
164-
}
165-
162+
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
163+
]>;
164+
@@ -4745,4 +4770,10 @@ def SPIRV_FPFastMathModeAttr :
165+
SPIRV_FPFMM_AllowReassocINTEL
166+
]>;
167+
166168
+def SPIRV_FPE_BFloat16KHR : I32EnumAttrCase<"BFloat16KHR", 0>;
167169
+def SPIRV_FP_Encoding :
168170
+ SPIRV_I32EnumAttr<"FPEncoding", "Valid floating-point encoding", "fpEncoding", [
169171
+ SPIRV_FPE_BFloat16KHR
170172
+ ]>;
171173
+
172-
def SPIRV_FPFastMathModeAttr :
173-
SPIRV_BitEnumAttr<"FPFastMathMode", "Indicates a floating-point fast math flag", "fastmath_mode", [
174-
SPIRV_FPFMM_None, SPIRV_FPFMM_NotNaN, SPIRV_FPFMM_NotInf, SPIRV_FPFMM_NSZ,
174+
#endif // MLIR_DIALECT_SPIRV_IR_BASE
175175
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
176176
index b5ca27d7d753..703920e42c60 100644
177177
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td

lib/Conversion/XeTileToXeGPU/ArithOpConversion.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,7 @@ class SgVectorMultiDimReductionOpPattern
311311

312312
rewriter.setInsertionPoint(op);
313313
// doing reduction on outer dimension
314-
if (mlir::isConstantIntValue(dims[0], 0) &&
315-
mlir::isConstantIntValue(dims[1], 2)) {
314+
if (dims[0] == 0 && dims[1] == 2) {
316315
auto intermediates = lowerOuterReduction(sources, shape, op.getKind(),
317316
loc, elemTy, rewriter);
318317
{
@@ -330,8 +329,7 @@ class SgVectorMultiDimReductionOpPattern
330329
}
331330

332331
// doing reduction on inner dimension
333-
if (mlir::isConstantIntValue(dims[0], 1) &&
334-
mlir::isConstantIntValue(dims[1], 3)) {
332+
if (dims[0] == 1 && dims[1] == 3) {
335333
auto intermediates = lowerInnerReductionWithIntraVectorShuffles(
336334
sources, shape, op.getKind(), loc, elemTy, rewriter);
337335

lib/Dialect/XeTile/Transforms/Blocking.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -458,8 +458,8 @@ struct VectorMultiDimReductionOpPattern
458458
// will be transformed to
459459
// multi_reduction<add>, %e, %a[1, 3]: vector<16x2x1x16xf16> to
460460
// vector<16x1xf16>
461-
auto dim = mlir::cast<mlir::IntegerAttr>(reductionDims[0]).getInt();
462-
auto newReductionDims = rewriter.getI64ArrayAttr({dim, dim + 2});
461+
auto dim = reductionDims[0];
462+
auto newReductionDims = rewriter.getDenseI64ArrayAttr({dim, dim + 2});
463463

464464
auto newDestShape =
465465
(dim == 0)

lib/Dialect/XeTile/Transforms/Canonicalization.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -278,14 +278,12 @@ struct VectorMultiReductionToXeTileReduce
278278
return mlir::failure();
279279
// If result is not 1D, we can not convert it to xetile.reduce. This
280280
// requires that the reduction dimensions has rank 1.
281-
auto reductionDims = op.getReductionDims().getValue();
281+
auto reductionDims = op.getReductionDims();
282282
if (reductionDims.size() != 1)
283283
return mlir::failure();
284284

285285
// Create an equivalent XeTileReduceOp
286-
int64_t reduceDim = llvm::cast<mlir::IntegerAttr>(reductionDims[0])
287-
.getValue()
288-
.getSExtValue();
286+
int64_t reduceDim = reductionDims[0];
289287
auto resultTy = llvm::cast<mlir::VectorType>(op.getType());
290288
auto xetileResultTy = mlir::VectorType::get(
291289
(reduceDim == 0 ? llvm::ArrayRef<int64_t>({1, resultTy.getDimSize(0)})

lib/Transforms/PropagatePackedLayout.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ class LayoutAnalysisImpl
160160
public:
161161
using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
162162

163-
void visitOperation(mlir::Operation *op,
163+
mlir::LogicalResult visitOperation(mlir::Operation *op,
164164
mlir::ArrayRef<LayoutLattice *> operands,
165165
mlir::ArrayRef<const LayoutLattice *> results) override {
166166
if (mlir::OpTrait::hasElementwiseMappableTraits(op)) {
@@ -182,7 +182,7 @@ class LayoutAnalysisImpl
182182
propagateIfChanged(argLattice, argLattice->meet(tmpLayout));
183183
}
184184

185-
return;
185+
return mlir::success();
186186
}
187187

188188
if (auto dpas = mlir::dyn_cast<mlir::xegpu::DpasOp>(op)) {
@@ -193,12 +193,13 @@ class LayoutAnalysisImpl
193193
propagateIfChanged(operand, operand->meet(std::nullopt));
194194
}
195195
}
196-
return;
196+
return mlir::success();
197197
}
198198

199199
// Unknown ops: mark all args as invalid layout (no layout change).
200200
for (auto operand : operands)
201201
propagateIfChanged(operand, operand->meet(std::nullopt));
202+
return mlir::success();
202203
}
203204

204205
void visitBranchOperand(mlir::OpOperand &operand) override {}

lib/Transforms/VnniTransformation.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ class LayoutAnalysisImpl
127127
public:
128128
using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
129129

130-
void visitOperation(mlir::Operation *op,
130+
mlir::LogicalResult visitOperation(mlir::Operation *op,
131131
mlir::ArrayRef<LayoutLattice *> operands,
132132
mlir::ArrayRef<const LayoutLattice *> results) override {
133133
// the B operand of a dpas operation is always in vnni layout
@@ -144,7 +144,7 @@ class LayoutAnalysisImpl
144144
// for C operand, it cannot be in vnni format
145145
propagateIfChanged(operands[2], operands[2]->meet(Layout(false)));
146146
}
147-
return;
147+
return mlir::success();
148148
}
149149

150150
if (mlir::OpTrait::hasElementwiseMappableTraits(op)) {
@@ -175,7 +175,7 @@ class LayoutAnalysisImpl
175175
for (auto &&lattice : operands)
176176
propagateIfChanged(lattice, lattice->meet(layout));
177177
}
178-
return;
178+
return mlir::success();
179179
}
180180

181181
if (auto extractStrideSliceOp =
@@ -186,7 +186,7 @@ class LayoutAnalysisImpl
186186
layout = Layout::meet(layout, Layout(isVNNIApplicable(srcTy)));
187187
propagateIfChanged(operands[0], operands[0]->meet(layout));
188188
}
189-
return;
189+
return mlir::success();
190190
}
191191

192192
if (auto extractOp = mlir::dyn_cast<mlir::vector::ExtractOp>(op)) {
@@ -201,12 +201,13 @@ class LayoutAnalysisImpl
201201
layout = Layout::meet(layout, Layout(isVNNIApplicable(vecTy)));
202202
propagateIfChanged(operands[0], operands[0]->meet(layout));
203203
}
204-
return;
204+
return mlir::success();
205205
}
206206

207207
// Unknown ops: mark all args as non-vnni layout (no layout change).
208208
for (auto operand : operands)
209209
propagateIfChanged(operand, operand->join(Layout(false)));
210+
return mlir::success();
210211
}
211212

212213
void visitBranchOperand(mlir::OpOperand &operand) override {}

0 commit comments

Comments
 (0)