Skip to content

Commit 3aa9c80

Browse files
authored
[ROCM][DT] Add encoding specialization infra for data-tiled ukernels (iree-org#21914)
Passes the `UKernelProvider` to the encoding resolver so it can be used to choose the data layouts for specialization and materialization. The `UKernelProviderInterface` gains a new `getDataLayoutForUKernel` method, which is responsible for returning a data layout attribute based on the encoding and target configuration. Signed-off-by: Jorn Tuyls <[email protected]>
1 parent 503621f commit 3aa9c80

File tree

11 files changed

+246
-18
lines changed

11 files changed

+246
-18
lines changed

compiler/plugins/target/ROCM/Dialect/ROCM/IR/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ iree_compiler_cc_library(
6464
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
6565
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms:GPUTransforms",
6666
"//compiler/src/iree/compiler/Codegen/Utils",
67+
"//compiler/src/iree/compiler/Dialect/Encoding/IR",
6768
"//compiler/src/iree/compiler/Dialect/HAL/IR",
6869
"//compiler/src/iree/compiler/Dialect/Util/IR",
6970
"//compiler/src/iree/compiler/Utils",
@@ -78,6 +79,7 @@ iree_compiler_cc_library(
7879
"@llvm-project//mlir:GPUUtils",
7980
"@llvm-project//mlir:IR",
8081
"@llvm-project//mlir:LinalgDialect",
82+
"@llvm-project//mlir:LinalgInterfaces",
8183
"@llvm-project//mlir:Parser",
8284
"@llvm-project//mlir:Support",
8385
],

compiler/plugins/target/ROCM/Dialect/ROCM/IR/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@ iree_cc_library(
3838
MLIRGPUUtils
3939
MLIRIR
4040
MLIRLinalgDialect
41+
MLIRLinalgInterfacesIncGenLib
4142
MLIRParser
4243
MLIRSupport
4344
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
4445
iree::compiler::Codegen::Dialect::Codegen::Utils
4546
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
4647
iree::compiler::Codegen::Dialect::GPU::Transforms::GPUTransforms
4748
iree::compiler::Codegen::Utils
49+
iree::compiler::Dialect::Encoding::IR
4850
iree::compiler::Dialect::HAL::IR
4951
iree::compiler::Dialect::Util::IR
5052
iree::compiler::Utils

compiler/plugins/target/ROCM/Dialect/ROCM/IR/ROCMAttrs.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h"
1212
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
1313
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
14+
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
1415
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
1516
#include "llvm/ADT/TypeSwitch.h"
1617
#include "llvm/ExecutionEngine/ExecutionEngine.h"
@@ -20,6 +21,7 @@
2021
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
2122
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
2223
#include "mlir/Dialect/Linalg/IR/Linalg.h"
24+
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
2325
#include "mlir/Dialect/Utils/IndexingUtils.h"
2426
#include "mlir/IR/AsmState.h"
2527
#include "mlir/IR/Attributes.h"
@@ -463,6 +465,48 @@ std::optional<LogicalResult> UKernelProviderAttr::createAndReplaceWithUkernelOp(
463465
return std::nullopt;
464466
}
465467

468+
//===---------------------------------------------------------------------===//
469+
// rocm.tensor_ukernel_provider
470+
//===---------------------------------------------------------------------===//
471+
472+
/// Resolves the ukernel symbol `name` by looking it up in the symbol table
/// nearest to `annotationSite` and returns whatever operation the lookup
/// yields. The DictionaryAttr parameter is unnamed and is not consulted.
FailureOr<Operation *>
473+
TensorUKernelProviderAttr::getMLIRUKernel(StringRef name, DictionaryAttr,
474+
Operation *annotationSite) const {
475+
// NOTE(review): the nearest-symbol-table result is used without a null
// check; presumably callers guarantee `annotationSite` is nested inside a
// symbol table op - confirm.
auto *symbolTableOp = SymbolTable::getNearestSymbolTable(annotationSite);
476+
SymbolTable symbolTable(symbolTableOp);
477+
// NOTE(review): the lookup result is returned directly, so a missing symbol
// presumably reaches the caller as a null Operation* inside a successful
// FailureOr rather than as failure() - confirm callers handle that.
return symbolTable.lookup(name);
478+
}
479+
480+
/// Chooses the data layout attribute for a data-tiled ukernel from `encoding`
/// and `targetConfiguration`.
///
/// Returns a default-constructed (null) Attribute unless every guard below
/// passes: `encoding` is an IREE::Encoding::EncodingAttr, the GPU target arch
/// is "gfx942", user indexing maps are present, the root maps infer as a
/// contraction, and the element types are exactly [f16, f16, f32]. When all
/// guards pass, a fixed IREE::GPU::DataTiledMMAAttr is returned.
Attribute TensorUKernelProviderAttr::getDataLayoutForUKernel(
481+
Attribute encoding, DictionaryAttr targetConfiguration) const {
482+
// The *_if_present cast tolerates a null `encoding`.
auto encodingAttr =
483+
dyn_cast_if_present<IREE::Encoding::EncodingAttr>(encoding);
484+
if (!encodingAttr) {
485+
return {};
486+
}
487+
// Only a target whose arch string is exactly "gfx942" is handled here.
IREE::GPU::TargetAttr targetAttr = getGPUTargetAttr(targetConfiguration);
488+
if (!targetAttr || targetAttr.getArch() != "gfx942") {
489+
return {};
490+
}
491+
// `indexingMapsAttr` is only tested for presence; the contraction check
// below reads the maps through getRootMaps() instead.
ArrayAttr indexingMapsAttr = encodingAttr.getUserIndexingMaps();
492+
if (!indexingMapsAttr) {
493+
return {};
494+
}
495+
// Bail out when the root maps do not describe a contraction.
if (failed(linalg::inferContractionDims(encodingAttr.getRootMaps()))) {
496+
return {};
497+
}
498+
// Require exactly three element types: f16, f16, f32 (presumably
// LHS, RHS, accumulator in that order - confirm).
SmallVector<Type> types = encodingAttr.getElementTypesArray();
499+
Type f16 = Float16Type::get(encoding.getContext());
500+
Type f32 = Float32Type::get(encoding.getContext());
501+
if (types.size() != 3 || types[0] != f16 || types[1] != f16 ||
502+
types[2] != f32) {
503+
return {};
504+
}
505+
// Hard-coded layout. The accompanying lit test expects this to print as
// intrinsics_m = 8, subgroups_m = 2, intrinsics_n = 4, subgroups_n = 4; the
// trailing 1 is presumably intrinsics_k - confirm against
// DataTiledMMAAttr::get.
return IREE::GPU::DataTiledMMAAttr::get(
506+
encoding.getContext(), IREE::GPU::MMAIntrinsic::MFMA_F32_16x16x16_F16, 8,
507+
2, 4, 4, 1);
508+
}
509+
466510
//===----------------------------------------------------------------------===//
467511
// Attribute Registration
468512
//===----------------------------------------------------------------------===//

compiler/plugins/target/ROCM/Dialect/ROCM/IR/ROCMAttrs.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,23 @@ def ROCM_UKernelProviderAttr :
5353
let assemblyFormat = [{}];
5454
}
5555

56+
//===---------------------------------------------------------------------===//
57+
// rocm.tensor_ukernel_provider
58+
//===---------------------------------------------------------------------===//
59+
60+
// ROCM attribute that implements IREECodegen_UKernelProviderInterface.
// DeclareAttrInterfaceMethods lists getDataLayoutForUKernel and
// getMLIRUKernel, so both are declared on the attribute class and are expected
// to be defined by hand in C++.
def ROCM_TensorUKernelProviderAttr :
61+
AttrDef<ROCM_Dialect, "TensorUKernelProvider", [
62+
DeclareAttrInterfaceMethods<IREECodegen_UKernelProviderInterface, [
63+
"getDataLayoutForUKernel",
64+
"getMLIRUKernel",
65+
]>
66+
]> {
67+
let mnemonic = "tensor_ukernel_provider";
68+
let summary = [{
69+
An attribute that provides context specific tensor ukernel implementations for ROCM.
70+
}];
71+
// No parameters and an empty assembly format: the attribute is written as a
// bare `#rocm.tensor_ukernel_provider`.
let parameters = (ins);
72+
let assemblyFormat = [{}];
73+
}
74+
5675
#endif // IREE_PLUGINS_TARGET_ROCM_DIALECT_ROCMATTRS

compiler/plugins/target/ROCM/ROCMTarget.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ class ROCMTargetBackend final : public TargetBackend {
354354

355355
if (options.enableTensorUKernels) {
356356
addConfig(kUKernelProviderName,
357-
IREE::Codegen::SymbolicUKernelProviderAttr::get(context));
357+
IREE::ROCM::TensorUKernelProviderAttr::get(context));
358358
}
359359

360360
return b.getAttr<IREE::HAL::ExecutableTargetAttr>(

compiler/plugins/target/ROCM/test/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ iree_lit_test_suite(
2020
"default_tuning_specs_amdgpu.mlir",
2121
"enable_tensor_ukernels.mlir",
2222
"gpu_encoding_attrs.mlir",
23+
"lower_rocm_tensor_ukernel_descriptor.mlir",
2324
"lower_rocm_ukernel_descriptor.mlir",
2425
"lowering_strategy_from_tuning_spec.mlir",
26+
"materialize_encoding_ukernel_gfx942.mlir",
2527
"ukernel_pipeline_transform.mlir",
2628
],
2729
cfg = "//compiler:lit.cfg.py",

compiler/plugins/target/ROCM/test/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ iree_lit_test_suite(
2020
"default_tuning_specs_amdgpu.mlir"
2121
"enable_tensor_ukernels.mlir"
2222
"gpu_encoding_attrs.mlir"
23+
"lower_rocm_tensor_ukernel_descriptor.mlir"
2324
"lower_rocm_ukernel_descriptor.mlir"
2425
"lowering_strategy_from_tuning_spec.mlir"
26+
"materialize_encoding_ukernel_gfx942.mlir"
2527
"ukernel_pipeline_transform.mlir"
2628
TOOLS
2729
FileCheck
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: iree-opt --iree-codegen-lower-tensor-ukernels --split-input-file --verify-diagnostics %s | FileCheck %s
2+
3+
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree_codegen.ukernel_provider = #rocm.tensor_ukernel_provider}>
4+
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
5+
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
6+
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
7+
// Test module whose executable target selects #rocm.tensor_ukernel_provider.
// The linalg.generic in @replace_generic_with_ukernel_impl carries a tensor
// ukernel descriptor named "test"; the FileCheck directives after this module
// expect the generic to be gone and a call to @ukernel_impl (the body of
// @test) to take its place, fed by the function arguments and the linalg.fill
// result. Presumably the pass resolves "test" to func @test through the
// provider and inlines it - confirm against the lowering pass.
module attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {
8+
func.func private @ukernel_impl(tensor<16x32xf32>, tensor<16x32xf32>, tensor<16x16xf32>) -> tensor<16x16xf32>
9+
// Function whose symbol name matches the descriptor name used below.
func.func @test(%arg0: tensor<16x32xf32>, %arg1: tensor<16x32xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> {
10+
%0 = call @ukernel_impl(%arg0, %arg1, %arg2) : (tensor<16x32xf32>, tensor<16x32xf32>, tensor<16x16xf32>) -> tensor<16x16xf32>
11+
return %0 : tensor<16x16xf32>
12+
}
13+
func.func @replace_generic_with_ukernel_impl(%arg0: tensor<16x32xf32>, %arg1: tensor<16x32xf32>) -> tensor<16x16xf32> {
14+
%cst = arith.constant 0.000000e+00 : f32
15+
%0 = tensor.empty() : tensor<16x16xf32>
16+
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x16xf32>) -> tensor<16x16xf32>
17+
// Multiply-accumulate generic annotated with the ukernel descriptor.
%2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<16x32xf32>, tensor<16x32xf32>) outs(%1 : tensor<16x16xf32>) attrs = {iree_codegen.ukernel = #iree_codegen.ukernel_descriptor<"test", tensor>} {
18+
^bb0(%in: f32, %in_0: f32, %out: f32):
19+
%3 = arith.mulf %in, %in_0 : f32
20+
%4 = arith.addf %out, %3 : f32
21+
linalg.yield %4 : f32
22+
} -> tensor<16x16xf32>
23+
return %2 : tensor<16x16xf32>
24+
}
25+
}
26+
// CHECK-LABEL: @ukernel_impl
27+
// CHECK-LABEL: @replace_generic_with_ukernel_impl
28+
// CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<16x32xf32>
29+
// CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<16x32xf32>
30+
// CHECK-NOT: linalg.generic
31+
// CHECK: %[[OUT:.+]] = linalg.fill
32+
// CHECK: %[[CALL:.+]] = call @ukernel_impl(%[[LHS]], %[[RHS]], %[[OUT]]) : (tensor<16x32xf32>, tensor<16x32xf32>, tensor<16x16xf32>) -> tensor<16x16xf32>
33+
// CHECK: return %[[CALL]]
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-materialize-device-encoding))" --split-input-file %s | FileCheck %s
2+
3+
// Note the ukernel provider being specified in the executable target. This should be used to determine the data tiling.
4+
5+
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {
6+
abi = "hip",
7+
iree.encoding.resolver = #iree_gpu.gpu_encoding_resolver<>,
8+
iree_codegen.target_info = #iree_gpu.target<
9+
arch = "gfx942",
10+
features = "",
11+
wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8,
12+
storage = b64|b32|b16|b8,
13+
subgroup = shuffle|arithmetic,
14+
dot = dp4xi8toi32,
15+
mma = [<MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>,
16+
<MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>,
17+
<MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>,
18+
<MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>,
19+
<MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>,
20+
<MFMA_I32_32x32x16_I8>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x4_F32>,
21+
<MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>
22+
],
23+
subgroup_size_choices = [64],
24+
max_workgroup_sizes = [1024, 1024, 1024],
25+
max_thread_count_per_workgroup = 1024,
26+
max_workgroup_memory_bytes = 65536,
27+
max_workgroup_counts = [2147483647, 2147483647, 2147483647],
28+
max_load_instruction_bits = 128,
29+
simds_per_wgp = 4,
30+
vgpr_space_bits = 16384>
31+
>,
32+
iree_codegen.ukernel_provider = #rocm.tensor_ukernel_provider,
33+
ukernels = "none"
34+
}>
35+
36+
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
37+
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
38+
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
39+
#encoding_lhs = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f16, f16, f32], user_indexing_maps = [#map, #map1, #map2], iteration_sizes = [?, ?, ?]>
40+
#encoding_rhs = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f16, f16, f32], user_indexing_maps = [#map, #map1, #map2], iteration_sizes = [?, ?, ?]>
41+
#encoding_result = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f16, f16, f32], user_indexing_maps = [#map, #map1, #map2], iteration_sizes = [?, ?, ?]>
42+
#pipeline_layout_3 = #hal.pipeline.layout<constants = 3, bindings = [
43+
#hal.pipeline.binding<storage_buffer>,
44+
#hal.pipeline.binding<storage_buffer>,
45+
#hal.pipeline.binding<storage_buffer>
46+
]>
47+
48+
// Dynamic-shape matmul with f16 x f16 -> f32 encodings on the three operands.
// It loads the encoded LHS, RHS and accumulator from bindings 0, 1 and 2, runs
// linalg.matmul, and stores the result back to binding 2. The FileCheck
// directives after this function expect an iree_codegen.inner_tiled op whose
// kind is a data_tiled_mma_layout built on MFMA_F32_16x16x16_F16.
func.func @matmul_lowering_ukernel_provider() attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {
49+
%c0 = arith.constant 0 : index
50+
// M, N and K are read from push constants 0, 1 and 2.
%M = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(0) : index
51+
%N = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(1) : index
52+
%K = hal.interface.constant.load layout(#pipeline_layout_3) ordinal(2) : index
53+
%0 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(0) alignment(64) offset(%c0)
54+
: !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x?xf16, #encoding_lhs>>{%M, %K}
55+
%1 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(1) alignment(64) offset(%c0)
56+
: !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x?xf16, #encoding_rhs>>{%K, %N}
57+
%2 = hal.interface.binding.subspan layout(#pipeline_layout_3) binding(2) alignment(64) offset(%c0)
58+
: !iree_tensor_ext.dispatch.tensor<readwrite:tensor<?x?xf32, #encoding_result>>{%M, %N}
59+
%3 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0], sizes = [%M, %K], strides = [1, 1]
60+
: !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x?xf16, #encoding_lhs>>{%M, %K}
61+
-> tensor<?x?xf16, #encoding_lhs>
62+
%4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0, 0], sizes = [%K, %N], strides = [1, 1]
63+
: !iree_tensor_ext.dispatch.tensor<readonly:tensor<?x?xf16, #encoding_rhs>>{%K, %N}
64+
-> tensor<?x?xf16, #encoding_rhs>
65+
%5 = iree_tensor_ext.dispatch.tensor.load %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
66+
: !iree_tensor_ext.dispatch.tensor<readwrite:tensor<?x?xf32, #encoding_result>>{%M, %N}
67+
-> tensor<?x?xf32, #encoding_result>
68+
%6 = linalg.matmul
69+
ins(%3, %4 : tensor<?x?xf16, #encoding_lhs>,
70+
tensor<?x?xf16, #encoding_rhs>)
71+
outs(%5 : tensor<?x?xf32, #encoding_result>)
72+
-> tensor<?x?xf32, #encoding_result>
73+
iree_tensor_ext.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [%M, %N], strides = [1, 1]
74+
: tensor<?x?xf32, #encoding_result>
75+
-> !iree_tensor_ext.dispatch.tensor<readwrite:tensor<?x?xf32, #encoding_result>>{%M, %N}
76+
return
77+
}
78+
// CHECK-LABEL: matmul_lowering_ukernel_provider
79+
// CHECK: iree_codegen.inner_tiled
80+
// CHECK-SAME: iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>]
81+
// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_F16, intrinsics_m = 8, subgroups_m = 2, intrinsics_n = 4, subgroups_n = 4>

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,25 @@ def IREECodegen_UKernelProviderInterface :
575575
return failure();
576576
}]
577577
>,
578+
InterfaceMethod<
579+
/*desc=*/[{
580+
Returns a data layout attribute for the provided |encoding| and
581+
|target_configuration|.
582+
583+
The return type is deliberately '::mlir::Attribute' to accommodate
584+
various data layout specifications. Callers, such as encoding
585+
resolvers, are expected to handle a range of possible attributes
586+
and gracefully manage cases where an unsupported attribute is returned.
587+
}],
588+
/*retTy=*/"::mlir::Attribute",
589+
/*methodName=*/"getDataLayoutForUKernel",
590+
/*args=*/(ins "::mlir::Attribute":$encoding,
591+
"::mlir::DictionaryAttr":$target_configuration),
592+
/*methodBody=*/"",
593+
/*defaultImplementation=*/[{
594+
return {};
595+
}]
596+
>,
578597
];
579598
}
580599

0 commit comments

Comments
 (0)