Skip to content

Commit cc1a7dd

Browse files
authored
[ROCM] Add ukernel descriptor lowering to pipeline (#21634)
Adds tensor ukernel matching and lowering to the GPU pipeline. --------- Signed-off-by: Jorn Tuyls <[email protected]>
1 parent ac3e153 commit cc1a7dd

File tree

9 files changed

+119
-5
lines changed

9 files changed

+119
-5
lines changed

compiler/plugins/target/ROCM/ROCMTarget.cpp

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/Passes.h"
1313
#include "compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_bitcode.h"
1414
#include "iree/compiler/Codegen/Common/Passes.h"
15+
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
1516
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
1617
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
1718
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
@@ -85,6 +86,7 @@ struct ROCMOptions {
8586
bool globalISel = false;
8687

8788
bool specializeDispatches = false;
89+
bool enableTensorUKernels = false;
8890

8991
void bindOptions(OptionsBinder &binder) {
9092
using namespace llvm;
@@ -156,6 +158,9 @@ struct ROCMOptions {
156158
cl::cat(category),
157159
cl::desc(
158160
"Enable runtime specialization of dynamically shaped dispatches."));
161+
binder.opt<bool>("iree-hip-enable-tensor-ukernels", enableTensorUKernels,
162+
cl::cat(category),
163+
cl::desc("Enable MLIR-based ukernels."));
159164
}
160165

161166
LogicalResult verify(mlir::Builder &builder) const {
@@ -329,6 +334,11 @@ class ROCMTargetBackend final : public TargetBackend {
329334
addConfigWavesPerEu(b.getContext(), options.wavesPerEu, configItems);
330335
}
331336

337+
if (options.enableTensorUKernels) {
338+
addConfig(kUKernelProviderName,
339+
IREE::Codegen::SymbolicUKernelProviderAttr::get(context));
340+
}
341+
332342
return b.getAttr<IREE::HAL::ExecutableTargetAttr>(
333343
b.getStringAttr("rocm"), b.getStringAttr(format),
334344
b.getDictionaryAttr(configItems));
@@ -365,7 +375,31 @@ class ROCMTargetBackend final : public TargetBackend {
365375
});
366376
}
367377
}
368-
buildLLVMGPUCodegenConfigurationPassPipeline(passManager);
378+
if (options.enableTensorUKernels) {
379+
if (auto attr = getGPUTargetAttr(targetAttr.getContext(), targetAttr)) {
380+
ROCM::ApplyBuiltinPDLPatternsPassOptions options;
381+
options.enableTensorUKernels = true;
382+
if (IREE::GPU::TargetChipAttr chip = attr.getChip()) {
383+
if (StringAttr sku = chip.getSku()) {
384+
options.targets.push_back(sku.str());
385+
}
386+
}
387+
options.targets.push_back(attr.getArch().str());
388+
OpPassManager &modulePassManager = passManager.nest<ModuleOp>();
389+
FunctionLikeNest(modulePassManager).addPass([&]() {
390+
return ROCM::createApplyBuiltinPDLPatternsPass(options);
391+
});
392+
}
393+
}
394+
buildLLVMGPUCodegenCommonConfigurationPassPipeline(passManager);
395+
OpPassManager &modulePassManager = passManager.nest<ModuleOp>();
396+
if (options.enableTensorUKernels) {
397+
modulePassManager.addPass(
398+
IREE::ROCM::createApplyBuiltinPDLPatternsDriverPass());
399+
}
400+
modulePassManager.addPass(createMaterializeTuningSpecsPass());
401+
modulePassManager.addPass(createMaterializeUserConfigsPass());
402+
modulePassManager.addPass(createLLVMGPUSelectLoweringStrategyPass());
369403
}
370404

371405
void buildTranslationPassPipeline(IREE::HAL::ExecutableTargetAttr targetAttr,

compiler/plugins/target/ROCM/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ iree_lit_test_suite(
1818
"config_ukernel_argmax_gfx942.mlir",
1919
"config_ukernel_data_tiled_mma_gfx942.mlir",
2020
"default_tuning_specs_amdgpu.mlir",
21+
"enable_tensor_ukernels.mlir",
2122
"gpu_encoding_attrs.mlir",
2223
"lowering_strategy_from_tuning_spec.mlir",
2324
"ukernel_pipeline_transform.mlir",

compiler/plugins/target/ROCM/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ iree_lit_test_suite(
1818
"config_ukernel_argmax_gfx942.mlir"
1919
"config_ukernel_data_tiled_mma_gfx942.mlir"
2020
"default_tuning_specs_amdgpu.mlir"
21+
"enable_tensor_ukernels.mlir"
2122
"gpu_encoding_attrs.mlir"
2223
"lowering_strategy_from_tuning_spec.mlir"
2324
"ukernel_pipeline_transform.mlir"
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 \
2+
// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(iree-hal-configure-target-executable-variants{target=rocm})))" \
3+
// RUN: --iree-hip-enable-tensor-ukernels \
4+
// RUN: --verify-diagnostics %s | FileCheck %s
5+
6+
// Make sure we can match and insert a tensor ukernel.
7+
8+
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
9+
#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
10+
#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
11+
#pipeline_layout = #hal.pipeline.layout<bindings = [
12+
#hal.pipeline.binding<storage_buffer>,
13+
#hal.pipeline.binding<storage_buffer>,
14+
#hal.pipeline.binding<storage_buffer>
15+
]>
16+
hal.executable public @main {
17+
hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
18+
hal.executable.export public @matmul_f8 ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) {
19+
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice
20+
hal.return %x, %y, %z : index, index, index
21+
}
22+
builtin.module {
23+
func.func @matmul_f8() {
24+
%cst = arith.constant 0.000000e+00 : f32
25+
%c0 = arith.constant 0 : index
26+
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1x128x4096xf8E4M3FNUZ>>
27+
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1024x4096xf8E4M3FNUZ>>
28+
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<1x128x1024xf32>>
29+
%3 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [1, 128, 4096], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1x128x4096xf8E4M3FNUZ>> -> tensor<1x128x4096xf8E4M3FNUZ>
30+
%4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [1024, 4096], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<1024x4096xf8E4M3FNUZ>> -> tensor<1024x4096xf8E4M3FNUZ>
31+
%5 = tensor.empty() : tensor<1x128x1024xf32>
32+
%6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x128x1024xf32>) -> tensor<1x128x1024xf32>
33+
%7 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%3, %4 : tensor<1x128x4096xf8E4M3FNUZ>, tensor<1024x4096xf8E4M3FNUZ>) outs(%6 : tensor<1x128x1024xf32>) {
34+
^bb0(%in: f8E4M3FNUZ, %in_4: f8E4M3FNUZ, %out: f32):
35+
%12 = arith.extf %in : f8E4M3FNUZ to f32
36+
%13 = arith.extf %in_4 : f8E4M3FNUZ to f32
37+
%14 = arith.mulf %12, %13 : f32
38+
%15 = arith.addf %out, %14 : f32
39+
linalg.yield %15 : f32
40+
} -> tensor<1x128x1024xf32>
41+
iree_tensor_ext.dispatch.tensor.store %7, %2, offsets = [0, 0, 0], sizes = [1, 128, 1024], strides = [1, 1, 1] : tensor<1x128x1024xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<1x128x1024xf32>>
42+
return
43+
}
44+
}
45+
}
46+
}
47+
// CHECK: #[[TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [512, 1, 1] subgroup_size = 64
48+
// CHECK: func.func @matmul_f8
49+
// CHECK-SAME: translation_info = #[[TRANSLATION]]
50+
// CHECK: linalg.generic
51+
// CHECK-SAME: iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"pingpong_medium_f8_expanded", tensor>
52+
// CHECK-SAME: lowering_config = #iree_gpu.lowering_config
53+
// CHECK: util.func private @pingpong_medium_f8_expanded
54+
// CHECK: iree_codegen.inner_tiled

compiler/src/iree/compiler/Codegen/Common/LowerUKernelDescriptors.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,15 @@ processUKernelKind(Operation *root, IREE::Codegen::UKernelArgumentKind kind) {
261261
FailureOr<Operation *> maybeTargetFunction = provider.getMLIRUKernel(
262262
name, targetAttr.getConfiguration(), annotationSite);
263263
if (failed(maybeTargetFunction) || !*maybeTargetFunction) {
264-
return op->emitOpError()
265-
<< "failed to retrieve a uKernel with name " << name;
264+
// If not found at the annotation site, look in the first ModuleOp
265+
// parent as well.
266+
auto moduleParent = op->getParentOfType<ModuleOp>();
267+
maybeTargetFunction = provider.getMLIRUKernel(
268+
name, targetAttr.getConfiguration(), moduleParent);
269+
if (failed(maybeTargetFunction) || !*maybeTargetFunction) {
270+
return op->emitOpError()
271+
<< "failed to retrieve a uKernel with name " << name;
272+
}
266273
}
267274
auto targetFunction = dyn_cast<FunctionOpInterface>(*maybeTargetFunction);
268275
if (!targetFunction) {

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
static const char kTranslationInfoAttrName[] = "translation_info";
2727
static const char kCompilationInfoAttrName[] = "compilation_info";
2828
static const char kRootOpInfoAttrName[] = "root_op";
29-
static const char kUKernelProviderName[] = "iree_codegen.ukernel_provider";
3029
static const char kUKernelDescriptorName[] = "iree_codegen.ukernel";
3130

3231
namespace mlir::iree_compiler {

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ constexpr StringLiteral kTuningSpecEntrypointAttrName =
5151
constexpr StringLiteral kSerializedTuningSpecAttrName =
5252
"iree_codegen.tuning_spec_mlirbc";
5353
constexpr StringLiteral kKernelConfigSpecName = "__kernel_config";
54+
constexpr StringLiteral kUKernelProviderName = "iree_codegen.ukernel_provider";
5455

5556
//===----------------------------------------------------------------------===//
5657
// Helpers for getting/setting iree_codegen.translation_info attribute on a

compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
430430
//
431431
// In the future there may be cases where we want the custom strategy run at
432432
// later points in the pipeline.
433+
funcPassManager.addPass(createLowerTensorUKernelsPass());
433434
funcPassManager.addPass(createLoweringConfigInterpreterPass());
434435
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
435436
funcPassManager.addPass(createCSEPass());
@@ -1276,7 +1277,7 @@ void addGPUTransformDialectPasses(OpPassManager &funcPassManager,
12761277
// Common Pass Pipelines
12771278
//===----------------------------------------------------------------------===//
12781279

1279-
static void buildLLVMGPUCodegenConfigurationPassPipelineImpl(
1280+
static void buildLLVMGPUCodegenCommonConfigurationPassPipelineImpl(
12801281
OpPassManager &modulePassManager) {
12811282
{
12821283
FunctionLikeNest funcPassManager(modulePassManager);
@@ -1302,6 +1303,17 @@ static void buildLLVMGPUCodegenConfigurationPassPipelineImpl(
13021303
funcPassManager.addPass(createConfigTrackingCanonicalizerPass);
13031304
funcPassManager.addPass(createCSEPass);
13041305
}
1306+
}
1307+
1308+
void buildLLVMGPUCodegenCommonConfigurationPassPipeline(
1309+
OpPassManager &variantPassManager) {
1310+
buildLLVMGPUCodegenCommonConfigurationPassPipelineImpl(
1311+
variantPassManager.nest<ModuleOp>());
1312+
}
1313+
1314+
static void buildLLVMGPUCodegenConfigurationPassPipelineImpl(
1315+
OpPassManager &modulePassManager) {
1316+
buildLLVMGPUCodegenCommonConfigurationPassPipelineImpl(modulePassManager);
13051317
modulePassManager.addPass(createMaterializeTuningSpecsPass());
13061318
modulePassManager.addPass(createMaterializeUserConfigsPass());
13071319
modulePassManager.addPass(createLLVMGPUSelectLoweringStrategyPass());

compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ void addGPUDefaultPassPipeline(OpPassManager &funcPassManager,
8181
/// Pass pipeline to lower IREE HAL executables without tiling and distribution.
8282
void addGPUBaseLoweringPassPipeline(OpPassManager &pm);
8383

84+
/// Populates the common passes needed to preprocess and select the translation
85+
/// strategy.
86+
void buildLLVMGPUCodegenCommonConfigurationPassPipeline(
87+
OpPassManager &variantPassManagery);
88+
8489
/// Populates passes needed to preprocess and select the translation strategy.
8590
void buildLLVMGPUCodegenConfigurationPassPipeline(
8691
OpPassManager &variantPassManagery);

0 commit comments

Comments
 (0)