Skip to content

Commit 2fed198

Browse files
authored
Bump water LLVM to 478e45fb94e541dfd3a53a23bbc8ed98337b8a77 (#446)
- Increase version requirement for nanobind to reflect MLIR requirements for stubgen.
- Work around the bug in MLIR installation not including the generated stubs by pre-generating them.
- Update successor interfaces.
- Update amdgpu.mfma syntax.
- Account for the arith.andi folder that prevents the operation from being created in the first place.

Signed-off-by: Alex Zinenko <[email protected]>
1 parent 8f88783 commit 2fed198

File tree

7 files changed

+30
-34
lines changed

7 files changed

+30
-34
lines changed

.github/workflows/ci-gpu.yaml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -97,6 +97,8 @@ jobs:
9797
-DLLVM_DISTRIBUTION_COMPONENTS="llvm-headers;llvm-libraries;cmake-exports;FileCheck;count;not;mlir-headers;mlir-libraries;mlir-cmake-exports;mlir-tblgen;mlir-python-sources" \
9898
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
9999
-DCMAKE_INSTALL_PREFIX=${GITHUB_WORKSPACE}/llvm-mlir/_mlir_install
100+
echo "INFO: working around a missing dependency on stubgen"
101+
ninja MLIRPythonModules.extension._mlir.dso._mlir.type_stubs
100102
ninja install-distribution-stripped
101103
popd
102104

water/lib/Dialect/Wave/IR/WaveOps.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -112,15 +112,15 @@ bool wave::IterateOp::areTypesCompatible(mlir::Type lhs, mlir::Type rhs) {
112112
}
113113

114114
mlir::OperandRange
115-
wave::IterateOp::getEntrySuccessorOperands(mlir::RegionBranchPoint point) {
115+
wave::IterateOp::getEntrySuccessorOperands(mlir::RegionSuccessor) {
116116
return getIterArgs();
117117
}
118118

119119
void wave::IterateOp::getSuccessorRegions(
120120
mlir::RegionBranchPoint point,
121121
::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &regions) {
122122
// May branch into the region or bypass it regardless of the source.
123-
regions.emplace_back(mlir::RegionSuccessor(getResults()));
123+
regions.emplace_back(mlir::RegionSuccessor(getOperation(), getResults()));
124124
regions.emplace_back(
125125
mlir::RegionSuccessor(&getBody(), getBody().front().getArguments()));
126126
}
@@ -544,6 +544,6 @@ LogicalResult WriteOp::verify() {
544544
//-----------------------------------------------------------------------------
545545

546546
mlir::MutableOperandRange
547-
wave::YieldOp::getMutableSuccessorOperands(mlir::RegionBranchPoint) {
547+
wave::YieldOp::getMutableSuccessorOperands(mlir::RegionSuccessor) {
548548
return getValuesMutable();
549549
}

water/llvm-sha.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
ec3cf67434ba361124cfbb548e93589acd0d3cf2
1+
478e45fb94e541dfd3a53a23bbc8ed98337b8a77

water/requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
# Development requirements for building Python bindings
2-
nanobind>=2.4, <3.0
2+
nanobind>=2.9, <3.0
33
pybind11>=2.10.0, <=2.13.6
44
numpy
55
lit

water/test/Dialect/Wave/lower-wave-to-mlir.mlir

Lines changed: 17 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -43,8 +43,7 @@ module attributes {wave.normal_form = #wave.normal_form<full_types,memory_only_t
4343
%acc = wave.register %cst_f32 : vector<4xf32>
4444

4545
// CHECK-NOT: wave.mma
46-
// CHECK: amdgpu.mfma %[[LHS]] * %[[RHS]] + %[[ACC]]
47-
// CHECK-SAME: blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32
46+
// CHECK: amdgpu.mfma 16x16x16 %[[LHS]] * %[[RHS]] + %[[ACC]]
4847
// CHECK-SAME: blgp = none
4948
// CHECK-SAME: vector<4xf16>, vector<4xf16>, vector<4xf32>
5049
%res = wave.mma %lhs, %rhs, %acc {kind = #wave.mma_kind<f32_16x16x16_f16>}
@@ -93,88 +92,88 @@ module attributes {wave.normal_form = #wave.normal_form<full_types,memory_only_t
9392
// f16 kinds
9493
// CHECK-NOT: wave.mma
9594
// CHECK: amdgpu.mfma
96-
// CHECK-SAME: k = 16 : i32, m = 16 : i32, n = 16 : i32
95+
// CHECK-SAME: 16x16x16
9796
%0 = wave.mma %lhs_f16, %rhs_f16, %acc_f32_4 {kind = #wave.mma_kind<f32_16x16x16_f16>}
9897
: (vector<4xf16>, vector<4xf16>, vector<4xf32>) -> vector<4xf32>
9998
// CHECK-NOT: wave.mma
10099
// CHECK: amdgpu.mfma
101-
// CHECK-SAME: k = 8 : i32, m = 32 : i32, n = 32 : i32
100+
// CHECK-SAME: 32x32x8
102101
%1 = wave.mma %lhs_f16, %rhs_f16, %acc_f32_16 {kind = #wave.mma_kind<f32_32x32x8_f16>}
103102
: (vector<4xf16>, vector<4xf16>, vector<16xf32>) -> vector<16xf32>
104103
// CHECK-NOT: wave.mma
105104
// CHECK: amdgpu.mfma
106-
// CHECK-SAME: k = 32 : i32, m = 16 : i32, n = 16 : i32
105+
// CHECK-SAME: 16x16x32
107106
%2 = wave.mma %lhs_f16_w8, %rhs_f16_w8, %acc_f32_4 {kind = #wave.mma_kind<f32_16x16x32_k8_f16>}
108107
: (vector<8xf16>, vector<8xf16>, vector<4xf32>) -> vector<4xf32>
109108
// CHECK-NOT: wave.mma
110109
// CHECK: amdgpu.mfma
111-
// CHECK-SAME: k = 16 : i32, m = 32 : i32, n = 32 : i32
110+
// CHECK-SAME: 32x32x16
112111
%3 = wave.mma %lhs_f16_w8, %rhs_f16_w8, %acc_f32_16 {kind = #wave.mma_kind<f32_32x32x16_k8_f16>}
113112
: (vector<8xf16>, vector<8xf16>, vector<16xf32>) -> vector<16xf32>
114113
// CHECK-NOT: wave.mma
115114
// CHECK: amdgpu.mfma
116-
// CHECK-SAME: k = 16 : i32, m = 32 : i32, n = 32 : i32
115+
// CHECK-SAME: 32x32x16
117116
%4 = wave.mma %lhs_f16_w8, %rhs_f16_w8, %acc_f32_16 {kind = #wave.mma_kind<f32_32x32x16_f16>}
118117
: (vector<8xf16>, vector<8xf16>, vector<16xf32>) -> vector<16xf32>
119118
// CHECK-NOT: wave.mma
120119
// CHECK: amdgpu.mfma
121-
// CHECK-SAME: k = 32 : i32, m = 16 : i32, n = 16 : i32
120+
// CHECK-SAME: 16x16x32
122121
%5 = wave.mma %lhs_f16_w8, %rhs_f16_w8, %acc_f32_4 {kind = #wave.mma_kind<f32_16x16x32_f16>}
123122
: (vector<8xf16>, vector<8xf16>, vector<4xf32>) -> vector<4xf32>
124123

125124
// bf16 kinds
126125
// CHECK-NOT: wave.mma
127126
// CHECK: amdgpu.mfma
128-
// CHECK-SAME: k = 16 : i32, m = 32 : i32, n = 32 : i32
127+
// CHECK-SAME: 32x32x16
129128
%6 = wave.mma %lhs_bf16, %rhs_bf16, %acc_f32_16 {kind = #wave.mma_kind<f32_32x32x16_bf16>}
130129
: (vector<8xbf16>, vector<8xbf16>, vector<16xf32>) -> vector<16xf32>
131130
// CHECK-NOT: wave.mma
132131
// CHECK: amdgpu.mfma
133-
// CHECK-SAME: k = 32 : i32, m = 16 : i32, n = 16 : i32
132+
// CHECK-SAME: 16x16x32
134133
%7 = wave.mma %lhs_bf16, %rhs_bf16, %acc_f32_4 {kind = #wave.mma_kind<f32_16x16x32_bf16>}
135134
: (vector<8xbf16>, vector<8xbf16>, vector<4xf32>) -> vector<4xf32>
136135

137136
// f8 kinds
138137
// CHECK-NOT: wave.mma
139138
// CHECK: amdgpu.mfma
140-
// CHECK-SAME: k = 32 : i32, m = 16 : i32, n = 16 : i32
139+
// CHECK-SAME: 16x16x32
141140
%8 = wave.mma %lhs_f8, %rhs_f8, %acc_f32_4 {kind = #wave.mma_kind<f32_16x16x32_f8>}
142141
: (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
143142
// CHECK-NOT: wave.mma
144143
// CHECK: amdgpu.mfma
145-
// CHECK-SAME: k = 16 : i32, m = 32 : i32, n = 32 : i32
144+
// CHECK-SAME: 32x32x16
146145
%9 = wave.mma %lhs_f8, %rhs_f8, %acc_f32_16 {kind = #wave.mma_kind<f32_32x32x16_f8>}
147146
: (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<16xf32>) -> vector<16xf32>
148147
// CHECK-NOT: wave.mma
149148
// CHECK: amdgpu.mfma
150-
// CHECK-SAME: k = 32 : i32, m = 16 : i32, n = 16 : i32
149+
// CHECK-SAME: 16x16x32
151150
%10 = wave.mma %lhs_f8, %rhs_f8, %acc_f32_4 {kind = #wave.mma_kind<f32_16x16x32_k4_f8>}
152151
: (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>) -> vector<4xf32>
153152
// CHECK-NOT: wave.mma
154153
// CHECK: amdgpu.mfma
155-
// CHECK-SAME: k = 16 : i32, m = 32 : i32, n = 32 : i32
154+
// CHECK-SAME: 32x32x16
156155
%11 = wave.mma %lhs_f8, %rhs_f8, %acc_f32_16 {kind = #wave.mma_kind<f32_32x32x16_k4_f8>}
157156
: (vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<16xf32>) -> vector<16xf32>
158157

159158
// i8 kinds
160159
// CHECK-NOT: wave.mma
161160
// CHECK: amdgpu.mfma
162-
// CHECK-SAME: k = 16 : i32, m = 16 : i32, n = 16 : i32
161+
// CHECK-SAME: 16x16x16
163162
%12 = wave.mma %lhs_i8, %rhs_i8, %acc_i32_4 {kind = #wave.mma_kind<i32_16x16x16_i8>}
164163
: (vector<4xi8>, vector<4xi8>, vector<4xi32>) -> vector<4xi32>
165164
// CHECK-NOT: wave.mma
166165
// CHECK: amdgpu.mfma
167-
// CHECK-SAME: k = 8 : i32, m = 32 : i32, n = 32 : i32
166+
// CHECK-SAME: 32x32x8
168167
%13 = wave.mma %lhs_i8, %rhs_i8, %acc_i32_16 {kind = #wave.mma_kind<i32_32x32x8_i8>}
169168
: (vector<4xi8>, vector<4xi8>, vector<16xi32>) -> vector<16xi32>
170169
// CHECK-NOT: wave.mma
171170
// CHECK: amdgpu.mfma
172-
// CHECK-SAME: k = 32 : i32, m = 16 : i32, n = 16 : i32
171+
// CHECK-SAME: 16x16x32
173172
%14 = wave.mma %lhs_i8_w8, %rhs_i8_w8, %acc_i32_4 {kind = #wave.mma_kind<i32_16x16x32_i8>}
174173
: (vector<8xi8>, vector<8xi8>, vector<4xi32>) -> vector<4xi32>
175174
// CHECK-NOT: wave.mma
176175
// CHECK: amdgpu.mfma
177-
// CHECK-SAME: k = 16 : i32, m = 32 : i32, n = 32 : i32
176+
// CHECK-SAME: 32x32x16
178177
%15 = wave.mma %lhs_i8_w8, %rhs_i8_w8, %acc_i32_16 {kind = #wave.mma_kind<i32_32x32x16_i8>}
179178
: (vector<8xi8>, vector<8xi8>, vector<16xi32>) -> vector<16xi32>
180179

water/test/Transforms/assert-in-bounds.mlir

Lines changed: 4 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -51,8 +51,7 @@ func.func @shape_static_one_index_dynamic(%memref: memref<5x2xf32>, %i: index) -
5151
// CHECK: %[[BOUND:.+]] = arith.andi %[[LB]], %[[UB]]
5252
// PERDIM: cf.assert %[[BOUND]], "memref access out of bounds along dimension 1"
5353
//
54-
// COMPOUND: %[[COMPOUND:.+]] = arith.andi %[[BOUND]], %[[TRUE]]
55-
// COMPOUND: cf.assert %[[COMPOUND]], "memref access out of bounds"
54+
// COMPOUND: cf.assert %[[BOUND]], "memref access out of bounds"
5655
// INPLC: memref.load
5756
// SPEC-NOT: memref.load
5857
%1 = memref.load %memref[%0, %i] : memref<5x2xf32>
@@ -77,18 +76,14 @@ func.func @shape_dynamic(%memref: memref<?x?xf32>) -> f32 {
7776
// CHECK: %[[DIM0:.+]] = memref.dim %{{.*}}, %[[ZERO2]]
7877
// Note that folding changed index0 < dim0 into dim0 > index0.
7978
// CHECK: %[[UB0:.+]] = arith.cmpi sgt, %[[DIM0]], %[[INDEX0]]
80-
// PERDIM: %[[BOUND0:.+]] = arith.andi %[[UB0]]
81-
// PERDIM: cf.assert %[[BOUND0]], "memref access out of bounds along dimension 0"
82-
// COMPOUND: %[[PREBOUND0:.+]] = arith.andi %[[UB0]]
83-
// COMPOUND: %[[BOUND0:.+]] = arith.andi %[[PREBOUND0]]
79+
// PERDIM: cf.assert %[[UB0]], "memref access out of bounds along dimension 0"
8480
//
8581
// CHECK: %[[ONE1:.+]] = arith.constant 1 : index
8682
// CHECK: %[[DIM1:.+]] = memref.dim %{{.*}}, %[[ONE1]]
8783
// CHECK: %[[UB1:.+]] = arith.cmpi sgt, %[[DIM1]], %[[INDEX1]]
88-
// CHECK: %[[BOUND1:.+]] = arith.andi %[[UB1]]
89-
// PERDIM: cf.assert %[[BOUND1]], "memref access out of bounds along dimension 1"
84+
// PERDIM: cf.assert %[[UB1]], "memref access out of bounds along dimension 1"
9085
//
91-
// COMPOUND: %[[COMPOUND:.+]] = arith.andi %[[BOUND0]], %[[BOUND1]]
86+
// COMPOUND: %[[COMPOUND:.+]] = arith.andi %[[UB0]], %[[UB1]]
9287
// COMPOUND: cf.assert %[[COMPOUND]], "memref access out of bounds"
9388
//
9489
// INPLC: memref.load

water/test/Transforms/lowered_gemm_pipelined.mlir

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -55,8 +55,8 @@ module attributes {transform.with_named_sequence} {
5555
%58 = vector.load %view[%5, %7] : memref<64x36xbf16, #gpu.address_space<workgroup>>, vector<8xbf16>
5656
%59 = vector.load %view_4[%8, %6] : memref<64x36xbf16, #gpu.address_space<workgroup>>, vector<8xbf16>
5757
%60 = vector.load %view_4[%8, %7] : memref<64x36xbf16, #gpu.address_space<workgroup>>, vector<8xbf16>
58-
%61 = amdgpu.mfma %59 * %57 + %arg4 {blocks = 1 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32} blgp = none : vector<8xbf16>, vector<8xbf16>, vector<16xf32>
59-
%62 = amdgpu.mfma %60 * %58 + %61 {blocks = 1 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32} blgp = none : vector<8xbf16>, vector<8xbf16>, vector<16xf32>
58+
%61 = amdgpu.mfma 32x32x16 %59 * %57 + %arg4 {blocks = 1 : i32} blgp = none : vector<8xbf16>, vector<8xbf16>, vector<16xf32>
59+
%62 = amdgpu.mfma 32x32x16 %60 * %58 + %61 {blocks = 1 : i32} blgp = none : vector<8xbf16>, vector<8xbf16>, vector<16xf32>
6060
scf.yield %62, %arg7, %arg8, %55, %56 : vector<16xf32>, vector<8xbf16>, vector<8xbf16>, vector<8xbf16>, vector<8xbf16>
6161
}
6262
%16 = vector.extract_strided_slice %15#0 {offsets = [0], sizes = [1], strides = [1]} : vector<16xf32> to vector<1xf32>

0 commit comments

Comments (0)