Skip to content

Commit f14e6b2

Browse files
[ROCM] Update Ukernel infra to allow ROCM-specific bitcode ukernel lowering (iree-org#21681)
-- This commit introduces the following changes: (a) adds `createAndReplaceWithUkernelOp` to `UKernelProviderInterface`; (b) creates a `ROCMUkernelProviderAttr` (`#rocm.ukernel_provider`) implementing `createAndReplaceWithUkernelOp`; (c) updates LowerUKernelDescriptors to make use of `createAndReplaceWithUkernelOp` while lowering to ukernel. Signed-off-by: Abhishek Varma <[email protected]>
1 parent 25d8239 commit f14e6b2

File tree

9 files changed

+206
-13
lines changed

9 files changed

+206
-13
lines changed

compiler/plugins/target/ROCM/Dialect/ROCM/IR/BUILD.bazel

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ iree_td_library(
2828
include = ["*.td"],
2929
),
3030
deps = [
31+
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:td_files",
3132
"//compiler/src/iree/compiler/Dialect/Util/IR:td_files",
3233
"@llvm-project//mlir:DialectUtilsTdFiles",
3334
"@llvm-project//mlir:OpBaseTdFiles",
@@ -64,6 +65,7 @@ iree_compiler_cc_library(
6465
"@llvm-project//llvm:Support",
6566
"@llvm-project//mlir:DialectUtils",
6667
"@llvm-project//mlir:IR",
68+
"@llvm-project//mlir:LinalgDialect",
6769
"@llvm-project//mlir:Parser",
6870
"@llvm-project//mlir:Support",
6971
],
@@ -102,5 +104,6 @@ iree_gentbl_cc_library(
102104
td_file = "ROCMAttrs.td",
103105
deps = [
104106
":td_files",
107+
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:td_files",
105108
],
106109
)

compiler/plugins/target/ROCM/Dialect/ROCM/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ iree_cc_library(
3030
::ROCMDialectGen
3131
LLVMSupport
3232
MLIRIR
33+
MLIRLinalgDialect
3334
MLIRParser
3435
MLIRSupport
3536
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect

compiler/plugins/target/ROCM/Dialect/ROCM/IR/ROCMAttrs.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
#include "compiler/plugins/target/ROCM/Dialect/ROCM/IR/ROCMAttrs.h"
88
#include "compiler/plugins/target/ROCM/Dialect/ROCM/IR/ROCMDialect.h"
9+
#include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h"
910
#include "llvm/ADT/TypeSwitch.h"
11+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1012
#include "mlir/IR/Attributes.h"
1113
#include "mlir/IR/DialectImplementation.h"
1214
#include "mlir/IR/OpDefinition.h"
@@ -27,6 +29,69 @@ BuiltinTuningModuleAttr::getModule(Operation * /*annotationSite*/) const {
2729
return rocmDialect.getOrLoadBuiltinModule(getBuiltinFilename());
2830
}
2931

32+
//===----------------------------------------------------------------------===//
33+
// UKernelProviderAttr
34+
//===----------------------------------------------------------------------===//
35+
36+
/// Utility function to help create and replace argmax linalg with a ukernel.
37+
static LogicalResult handleArgmaxUkernel(
38+
RewriterBase &rewriter, StringRef name, DictionaryAttr targetConfiguration,
39+
Operation *contextualOp, SmallVectorImpl<Value> &inputs,
40+
SmallVectorImpl<Value> &outputs, SmallVectorImpl<Value> &otherOperands) {
41+
auto genericOp = dyn_cast<linalg::GenericOp>(contextualOp);
42+
if (!genericOp) {
43+
return rewriter.notifyMatchFailure(
44+
genericOp, "expected a linalg.generic op for argmax");
45+
}
46+
// Currently only support 1D reduction, where reduction is on fastest dim.
47+
// Tiling argmax ukernel is also set to enforce this structure.
48+
const int kReductionDim = genericOp.getNumLoops() - 1;
49+
Location loc = genericOp.getLoc();
50+
Value reductionDimSize = rewriter.create<tensor::DimOp>(
51+
loc, genericOp.getDpsInputOperand(0)->get(), kReductionDim);
52+
// `returnsMaxValue` differentiates between the two argmax versions :-
53+
// 1. Returns only the index of the max value (returnsMaxValue == true)
54+
// 2. Returns both the max value as well as the corresponding index.
55+
bool returnsMaxValue = genericOp.getResults()[0].use_empty();
56+
Value writeMaxValueFlag = rewriter.create<arith::ConstantOp>(
57+
loc, rewriter.getI1Type(), rewriter.getBoolAttr(!returnsMaxValue));
58+
llvm::append_values(otherOperands, reductionDimSize, writeMaxValueFlag);
59+
MLIRContext *context = rewriter.getContext();
60+
auto fnDefAttrs = DictionaryAttr::get(
61+
context, {{"vm.import.module", StringAttr::get(context, "rocm")}});
62+
auto ukernelOp = rewriter.create<IREE::Codegen::UKernelGenericOp>(
63+
loc, contextualOp->getResultTypes(), name, inputs, outputs, otherOperands,
64+
fnDefAttrs, /*num_strided_outer_dims=*/0);
65+
if (returnsMaxValue) {
66+
rewriter.replaceAllUsesWith(genericOp.getResults()[1],
67+
ukernelOp.getResults()[1]);
68+
return success();
69+
}
70+
ResultRange origResults = genericOp.getResults();
71+
ResultRange newResults = ukernelOp.getResults();
72+
if (origResults.size() != newResults.size()) {
73+
return rewriter.notifyMatchFailure(genericOp, "result count mismatch");
74+
}
75+
rewriter.replaceAllUsesWith(genericOp.getResults()[0],
76+
ukernelOp.getResults()[0]);
77+
rewriter.replaceAllUsesWith(genericOp.getResults()[1],
78+
ukernelOp.getResults()[1]);
79+
return success();
80+
}
81+
82+
std::optional<LogicalResult> UKernelProviderAttr::createAndReplaceWithUkernelOp(
83+
RewriterBase &rewriter, StringRef name, DictionaryAttr targetConfiguration,
84+
Operation *contextualOp, SmallVectorImpl<Value> &inputs,
85+
SmallVectorImpl<Value> &outputs,
86+
SmallVectorImpl<Value> &otherOperands) const {
87+
if (name.contains("argmax")) {
88+
return handleArgmaxUkernel(rewriter, name, targetConfiguration,
89+
contextualOp, inputs, outputs, otherOperands);
90+
}
91+
// TODO(avarma): Add multi_mfma ukernel support via descriptors.
92+
return std::nullopt;
93+
}
94+
3095
//===----------------------------------------------------------------------===//
3196
// Attribute Registration
3297
//===----------------------------------------------------------------------===//

compiler/plugins/target/ROCM/Dialect/ROCM/IR/ROCMAttrs.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#define IREE_PLUGINS_TARGET_ROCM_DIALECT_ROCMATTRS
99

1010
include "ROCMDialect.td"
11+
include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td"
1112
include "iree/compiler/Dialect/Util/IR/UtilInterfaces.td"
1213
include "mlir/IR/AttrTypeBase.td"
1314
include "mlir/IR/BuiltinTypeInterfaces.td"
@@ -38,4 +39,18 @@ def ROCM_BuiltinTuningModuleAttr :
3839
let assemblyFormat = "`<` $builtin_filename `>`";
3940
}
4041

42+
def ROCM_UKernelProviderAttr :
43+
AttrDef<ROCM_Dialect, "UKernelProvider", [
44+
DeclareAttrInterfaceMethods<IREECodegen_UKernelProviderInterface, [
45+
"createAndReplaceWithUkernelOp"
46+
]>
47+
]> {
48+
let mnemonic = "ukernel_provider";
49+
let summary = [{
50+
An attribute that provides context specific ukernel implementations for ROCM.
51+
}];
52+
let parameters = (ins);
53+
let assemblyFormat = [{}];
54+
}
55+
4156
#endif // IREE_PLUGINS_TARGET_ROCM_DIALECT_ROCMATTRS
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
// RUN: iree-opt %s | iree-opt | FileCheck %s
22

33
module @tuning_module attributes {
4+
iree_codegen.ukernel_provider = #rocm.ukernel_provider,
45
rocm.spec = #rocm.builtin.tuning_module<"iree_default_tuning_spec_gfx942.mlir"> } {
56
}
67

78
// CHECK-LABEL: module @tuning_module
9+
// CHECK-SAME: iree_codegen.ukernel_provider = #rocm.ukernel_provider
810
// CHECK-SAME: #rocm.builtin.tuning_module<"iree_default_tuning_spec_gfx942.mlir">

compiler/plugins/target/ROCM/ROCMTarget.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,10 @@ class ROCMTargetBackend final : public TargetBackend {
330330
}
331331

332332
addConfig("ukernels", b.getStringAttr(options.enableROCMUkernels));
333+
if (options.enableROCMUkernels != "none") {
334+
addConfig("iree_codegen.ukernel_provider",
335+
IREE::ROCM::UKernelProviderAttr::get(context));
336+
}
333337
if (options.wavesPerEu > 0) {
334338
addConfigWavesPerEu(b.getContext(), options.wavesPerEu, configItems);
335339
}

compiler/src/iree/compiler/Codegen/Common/LowerUKernelDescriptors.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,10 @@ static void populateCastConversions(TypeConverter &converter) {
7878
//===----------------------------------------------------------------------===//
7979

8080
/// Converts an operation to `iree_codegen.ukernel.generic`.
81-
///
82-
/// NOTE: This is primarily an example implementation with inherent limitations.
83-
/// The generic approach used here cannot fulfill the requirements of all
84-
/// ukernel implementations. Real ukernels often need additional
85-
/// context-specific operands (e.g., runtime shapes or algorithm-specific
86-
/// parameters) that cannot be generically inferred from the source operation
87-
/// alone.
88-
static LogicalResult convertToUKernelGeneric(RewriterBase &rewriter,
89-
Operation *op, StringRef name) {
81+
static LogicalResult
82+
convertToUKernelGeneric(RewriterBase &rewriter, Operation *op, StringRef name,
83+
IREE::Codegen::UKernelProviderInterface &provider,
84+
DictionaryAttr targetConfiguration) {
9085
SmallVector<Value> tensorInputs;
9186
SmallVector<Value> tensorOutputs;
9287
SmallVector<Value> otherOperands;
@@ -113,7 +108,18 @@ static LogicalResult convertToUKernelGeneric(RewriterBase &rewriter,
113108
}
114109
}
115110
}
111+
116112
rewriter.setInsertionPoint(op);
113+
if (provider) {
114+
std::optional<LogicalResult> retVal =
115+
provider.createAndReplaceWithUkernelOp(
116+
rewriter, name, targetConfiguration, op, tensorInputs,
117+
tensorOutputs, otherOperands);
118+
if (retVal)
119+
return retVal.value();
120+
}
121+
// Default ukernel generic op is created when a provider doesn't exist or when
122+
// the provider doesn't implement the replacement method.
117123
rewriter.replaceOpWithNewOp<IREE::Codegen::UKernelGenericOp>(
118124
op, op->getResults().getTypes(), name, tensorInputs, tensorOutputs,
119125
otherOperands, DictionaryAttr(),
@@ -249,7 +255,8 @@ processUKernelKind(Operation *root, IREE::Codegen::UKernelArgumentKind kind) {
249255
for (auto [op, name] : opsToConvert) {
250256
switch (kind) {
251257
case IREE::Codegen::UKernelArgumentKind::Bitcode: {
252-
if (failed(convertToUKernelGeneric(rewriter, op, name))) {
258+
if (failed(convertToUKernelGeneric(rewriter, op, name, provider,
259+
targetAttr.getConfiguration()))) {
253260
return op->emitOpError()
254261
<< "failed to convert to ukernel.generic with name " << name;
255262
}

compiler/src/iree/compiler/Codegen/Common/test/lower_ukernel_bitcode_descriptor.mlir

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
// RUN: iree-opt --iree-codegen-lower-bitcode-ukernels %s | FileCheck %s
1+
// RUN: iree-opt --iree-codegen-lower-bitcode-ukernels --split-input-file %s | FileCheck %s
22

3-
// CHECK-LABEL: @ukernel_test
3+
// CHECK-LABEL: @ukernel_test_without_provider
44
// CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<16x32xf32>
55
// CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<16x32xf32>
66
// CHECK-NOT: linalg.generic
@@ -14,7 +14,7 @@
1414
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
1515
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
1616
module attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {
17-
func.func @ukernel_test(%arg0: tensor<16x32xf32>, %arg1: tensor<16x32xf32>) -> tensor<16x16xf32> {
17+
func.func @ukernel_test_without_provider(%arg0: tensor<16x32xf32>, %arg1: tensor<16x32xf32>) -> tensor<16x16xf32> {
1818
%cst = arith.constant 0.000000e+00 : f32
1919
%0 = tensor.empty() : tensor<16x16xf32>
2020
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x16xf32>) -> tensor<16x16xf32>
@@ -27,3 +27,73 @@ module attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {
2727
return %2 : tensor<16x16xf32>
2828
}
2929
}
30+
31+
// -----
32+
33+
// CHECK-LABEL: @pure_argmax_ukernel_test_with_provider
34+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?xf32>
35+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<f32>
36+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<i64>
37+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
38+
// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?xf32>
39+
// CHECK: %[[FALSE:.*]] = arith.constant false
40+
// CHECK: %[[MICRO_KERNEL:.+]]:2 = iree_codegen.ukernel.generic "iree_uk_amdgpu_argmax_f32i64"
41+
// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
42+
// CHECK-SAME: outs(%[[ARG1]], %[[ARG2]] : tensor<f32>, tensor<i64>)
43+
// CHECK-SAME: (%[[DIM]], %[[FALSE]] : index, i1)
44+
// CHECK-SAME: fn_def_attrs {vm.import.module = "rocm"}
45+
// CHECK-SAME{LITERAL}: strided_dims([[], [], []])
46+
// CHECK: return %[[MICRO_KERNEL]]#1
47+
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree_codegen.ukernel_provider = #rocm.ukernel_provider}>
48+
#map = affine_map<(d0) -> (d0)>
49+
#map1 = affine_map<(d0) -> ()>
50+
module attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {
51+
func.func @pure_argmax_ukernel_test_with_provider(%arg0: tensor<?xf32>, %arg1: tensor<f32>, %arg2: tensor<i64>) -> tensor<i64> {
52+
%cst = arith.constant 0.000000e+00 : f32
53+
%0:2 = linalg.generic {indexing_maps = [#map, #map1, #map1], iterator_types = ["reduction"]} ins(%arg0 : tensor<?xf32>) outs(%arg1, %arg2 : tensor<f32>, tensor<i64>) attrs = {iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"iree_uk_amdgpu_argmax_f32i64", bitcode>} {
54+
^bb0(%in: f32, %out: f32, %out_0: i64):
55+
%1 = linalg.index 0 : index
56+
%2 = arith.index_cast %1 : index to i64
57+
%3 = arith.maximumf %in, %out : f32
58+
%4 = arith.cmpf ogt, %in, %out : f32
59+
%5 = arith.select %4, %2, %out_0 : i64
60+
linalg.yield %3, %5 : f32, i64
61+
} -> (tensor<f32>, tensor<i64>)
62+
return %0#1 : tensor<i64>
63+
}
64+
}
65+
66+
// -----
67+
68+
// CHECK-LABEL: @argmax_ukernel_test_with_provider
69+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?xf32>
70+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<f32>
71+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<i64>
72+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
73+
// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?xf32>
74+
// CHECK: %[[TRUE:.*]] = arith.constant true
75+
// CHECK: %[[MICRO_KERNEL:.+]]:2 = iree_codegen.ukernel.generic "iree_uk_amdgpu_argmax_f32i64"
76+
// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
77+
// CHECK-SAME: outs(%[[ARG1]], %[[ARG2]] : tensor<f32>, tensor<i64>)
78+
// CHECK-SAME: (%[[DIM]], %[[TRUE]] : index, i1)
79+
// CHECK-SAME: fn_def_attrs {vm.import.module = "rocm"}
80+
// CHECK-SAME{LITERAL}: strided_dims([[], [], []])
81+
// CHECK: return %[[MICRO_KERNEL]]#0, %[[MICRO_KERNEL]]#1
82+
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree_codegen.ukernel_provider = #rocm.ukernel_provider}>
83+
#map = affine_map<(d0) -> (d0)>
84+
#map1 = affine_map<(d0) -> ()>
85+
module attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {
86+
func.func @argmax_ukernel_test_with_provider(%arg0: tensor<?xf32>, %arg1: tensor<f32>, %arg2: tensor<i64>) -> (tensor<f32>, tensor<i64>) {
87+
%cst = arith.constant 0.000000e+00 : f32
88+
%0:2 = linalg.generic {indexing_maps = [#map, #map1, #map1], iterator_types = ["reduction"]} ins(%arg0 : tensor<?xf32>) outs(%arg1, %arg2 : tensor<f32>, tensor<i64>) attrs = {iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"iree_uk_amdgpu_argmax_f32i64", bitcode>} {
89+
^bb0(%in: f32, %out: f32, %out_0: i64):
90+
%1 = linalg.index 0 : index
91+
%2 = arith.index_cast %1 : index to i64
92+
%3 = arith.maximumf %in, %out : f32
93+
%4 = arith.cmpf ogt, %in, %out : f32
94+
%5 = arith.select %4, %2, %out_0 : i64
95+
linalg.yield %3, %5 : f32, i64
96+
} -> (tensor<f32>, tensor<i64>)
97+
return %0#0, %0#1 : tensor<f32>, tensor<i64>
98+
}
99+
}

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,32 @@ def IREECodegen_UKernelProviderInterface :
549549
return failure();
550550
}]
551551
>,
552+
InterfaceMethod<
  /*desc=*/[{
    Creates and replaces the |contextual_op| with a ukernel referenced by
    |name| backed by this attribute. Takes |target_configuration| in
    case the provider wants to make any feature specific decisions for
    which implementation to provide. It takes in |inputs|, |outputs| and
    |other_operands|, and may append context-specific as well as
    backend-specific operands to them before creating the ukernel op.

    Returning std::nullopt indicates falling back to the default
    implementation; this is also what providers that do not override this
    method return. Otherwise success/failure indicates the status of the
    custom replacement implementation.
  }],
  /*retTy=*/"std::optional<::mlir::LogicalResult>",
  /*methodName=*/"createAndReplaceWithUkernelOp",
  /*args=*/(ins "::mlir::RewriterBase&":$b,
                "::mlir::StringRef":$name,
                "::mlir::DictionaryAttr":$target_configuration,
                "::mlir::Operation *":$contextual_op,
                "::llvm::SmallVectorImpl<::mlir::Value>&":$inputs,
                "::llvm::SmallVectorImpl<::mlir::Value>&":$outputs,
                "::llvm::SmallVectorImpl<::mlir::Value>&":$other_operands),
  /*methodBody=*/"",
  /*defaultImplementation=*/[{
    // No custom lowering by default: let the caller fall back to the
    // generic ukernel op creation instead of reporting a failure.
    return std::nullopt;
  }]
>,
552578
];
553579
}
554580

0 commit comments

Comments
 (0)