
Commit 9bb1a2b

jtuyls and qedawkins authored
[ROCM] Port mlir ukernels to ukernel descriptor lowering flow (iree-org#21683)
This copies all ukernels from the tuning spec to the ukernel descriptor based lowering and PDL patterns. This doesn't remove the ukernels in the spec yet, as that requires the usage of `--iree-codegen-enable-default-tuning-specs=true` to be updated to `--iree-hip-enable-tensor-ukernels` everywhere, which in my opinion is better done in a separate PR.

The ukernels and matching patterns being copied in this PR:

- pingpong_large_f8_expanded
- pingpong_large_f16
- pingpong_medium_f16_expanded
- pingpong_large_f16_expanded
- pingpong_large_bf16
- pingpong_medium_bf16_expanded
- pingpong_large_bf16_expanded

Note that the mmt_2048x1280x5120_f16_f16_f32 matching and annotation is not ported, as it should not be reachable: pingpong_large_f16 matches the same case and takes precedence.

---------

Signed-off-by: Jorn Tuyls <[email protected]>
Co-authored-by: Quinn Dawkins <[email protected]>
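For context, the ported PDL patterns annotate matching contraction ops with an `iree_codegen.ukernel` descriptor attribute (alongside a `compilation_info` attribute, omitted here). Below is a minimal sketch of what an annotated op might look like, assuming the 2048x1280x5120 f16 case mentioned above is matched by `pingpong_large_f16`; the function name and exact attribute placement are illustrative, not taken from the patch.

```mlir
// Hedged sketch (not from the patch): the kind of annotation the ported PDL
// patterns attach to a matching contraction. Shapes follow the
// mmt_2048x1280x5120_f16_f16_f32 case mentioned in the commit message.
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @annotated_matmul_f16(%arg0: tensor<2048x5120xf16>, %arg1: tensor<1280x5120xf16>) -> tensor<2048x1280xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<2048x1280xf32>
  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>
  // The ukernel descriptor attribute is what the lowering flow later keys on.
  %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%arg0, %arg1 : tensor<2048x5120xf16>, tensor<1280x5120xf16>) outs(%1 : tensor<2048x1280xf32>)
      attrs = {iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_large_f16", tensor>} {
  ^bb0(%in: f16, %in_0: f16, %out: f32):
    %3 = arith.extf %in : f16 to f32
    %4 = arith.extf %in_0 : f16 to f32
    %5 = arith.mulf %3, %4 : f32
    %6 = arith.addf %out, %5 : f32
    linalg.yield %6 : f32
  } -> tensor<2048x1280xf32>
  return %2 : tensor<2048x1280xf32>
}
```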
1 parent 46de78a commit 9bb1a2b

File tree

13 files changed: +2579 -33 lines changed


compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/ApplyBuiltinPDLPatterns.cpp

Lines changed: 9 additions & 2 deletions
@@ -16,6 +16,7 @@
 #include "iree/compiler/Dialect/HAL/IR/HALDialect.h"
 #include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
 #include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
+#include "iree/compiler/Utils/ShapeUtils.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -61,7 +62,8 @@ static LogicalResult annotateOperation(PatternRewriter &rewriter,
     return rewriter.notifyMatchFailure(rootOp,
                                        "expected StringAttr for attr name.");
   }
-  rootOp->setAttr(strName.strref(), annotation);
+  rewriter.modifyOpInPlace(
+      rootOp, [&]() { rootOp->setAttr(strName.strref(), annotation); });
   return success();
 }
 
@@ -75,7 +77,6 @@ static LogicalResult matchContraction(PatternRewriter &rewriter,
     return rewriter.notifyMatchFailure(rootOp,
                                        "not a contraction like linalg op");
   }
-
   if (linalgOp.getIndexingMaps() != indexingMaps) {
     return rewriter.notifyMatchFailure(rootOp, "indexing maps mismatch");
   }
@@ -102,6 +103,9 @@ static LogicalResult dimIsMultipleOf(PatternRewriter &rewriter, Value value,
   if (!dim) {
     return failure();
   }
+  if (dimInt.getInt() >= shapedType.getRank()) {
+    return failure();
+  }
   auto divisorInt = dyn_cast<IntegerAttr>(divisor);
   if (!divisor) {
     return failure();
@@ -140,6 +144,9 @@ static LogicalResult dimIsBound(PatternRewriter &rewriter, Value value,
   if (!dimInt) {
     return failure();
   }
+  if (dimInt.getInt() >= shapedType.getRank()) {
+    return failure();
+  }
   if (auto lowerBoundInt = dyn_cast<IntegerAttr>(lowerBound)) {
     FailureOr<int64_t> constantLb =
         ValueBoundsConstraintSet::computeConstantBound(

compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns.mlir

Lines changed: 74 additions & 0 deletions
@@ -133,3 +133,77 @@ func.func @negative_matmul_f8_dynamic_lower_bound(%arg0: index) -> tensor<1x128x
 // CHECK-LABEL: @negative_matmul_f8_dynamic_lower_bound
 // CHECK-NOT: compilation_info = #iree_codegen.compilation_info
 // CHECK-NOT: iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_medium_f8_expanded", tensor>
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_matmul_f16(%arg0: tensor<256x4096xf16>, %arg1: tensor<1024x4096xf16>) -> tensor<256x1024xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<256x1024xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x1024xf32>) -> tensor<256x1024xf32>
+  %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<256x4096xf16>, tensor<1024x4096xf16>) outs(%1 : tensor<256x1024xf32>) {
+  ^bb0(%in: f16, %in_4: f16, %out: f32):
+    %12 = arith.extf %in : f16 to f32
+    %13 = arith.extf %in_4 : f16 to f32
+    %14 = arith.mulf %12, %13 : f32
+    %15 = arith.addf %out, %14 : f32
+    linalg.yield %15 : f32
+  } -> tensor<256x1024xf32>
+  return %2 : tensor<256x1024xf32>
+}
+// CHECK-LABEL: @negative_matmul_f16
+// CHECK-NOT: compilation_info = #iree_codegen.compilation_info
+// CHECK-NOT: iree_codegen.ukernel = #iree_codegen.ukernel_descriptor
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @negative_matmul_bf16(%arg0: tensor<256x4096xbf16>, %arg1: tensor<1024x4096xbf16>) -> tensor<256x1024xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<256x1024xf32>
+  %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x1024xf32>) -> tensor<256x1024xf32>
+  %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<256x4096xbf16>, tensor<1024x4096xbf16>) outs(%1 : tensor<256x1024xf32>) {
+  ^bb0(%in: bf16, %in_4: bf16, %out: f32):
+    %12 = arith.extf %in : bf16 to f32
+    %13 = arith.extf %in_4 : bf16 to f32
+    %14 = arith.mulf %12, %13 : f32
+    %15 = arith.addf %out, %14 : f32
+    linalg.yield %15 : f32
+  } -> tensor<256x1024xf32>
+  return %2 : tensor<256x1024xf32>
+}
+// CHECK-LABEL: @negative_matmul_bf16
+// CHECK-NOT: compilation_info = #iree_codegen.compilation_info
+// CHECK-NOT: iree_codegen.ukernel = #iree_codegen.ukernel_descriptor
+
+// -----
+
+// The dynamic dimension is a multiple of 256, but doesn't have a lower bound of 256.
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+func.func @negative_matmul_bf16_dynamic_lower_bound(%arg0: index) -> tensor<1x256x1024xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = util.assume.int %arg0<umin = 128, udiv = 256> : index
+  %1 = tensor.empty(%0) : tensor<1x256x?xbf16>
+  %2 = tensor.empty(%0) : tensor<1024x?xbf16>
+  %3 = tensor.empty() : tensor<1x256x1024xf32>
+  %4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<1x256x1024xf32>) -> tensor<1x256x1024xf32>
+  %5 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%1, %2 : tensor<1x256x?xbf16>, tensor<1024x?xbf16>) outs(%4 : tensor<1x256x1024xf32>) {
+  ^bb0(%in: bf16, %in_0: bf16, %out: f32):
+    %6 = arith.extf %in : bf16 to f32
+    %7 = arith.extf %in_0 : bf16 to f32
+    %8 = arith.mulf %6, %7 : f32
+    %9 = arith.addf %out, %8 : f32
+    linalg.yield %9 : f32
+  } -> tensor<1x256x1024xf32>
+  return %5 : tensor<1x256x1024xf32>
+}
+// CHECK-LABEL: @negative_matmul_bf16_dynamic_lower_bound
+// CHECK-NOT: compilation_info = #iree_codegen.compilation_info
+// CHECK-NOT: iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_large_bf16_expanded", tensor>
