Skip to content

Commit a9c7ec1

Browse files
authored
[Util][GPU] Add TiedOpInterface implementation for iree_gpu.multi_mma (#18626)
This PR is part of what was originally #18608. It implements the TiedOpInterface for the iree_gpu.multi_mma op. This is a temporary solution for handling multi_mma ops before dispatch-workgroup creation, and is only needed right now because we rely on early materialization. It will enable e2e matmul tests with GPU data tiling while that feature is still being developed; this change can be dropped once we switch to late materialization. --------- Signed-off-by: Max Dawkins <[email protected]>
1 parent b7ac442 commit a9c7ec1

File tree

4 files changed

+67
-0
lines changed

4 files changed

+67
-0
lines changed

compiler/src/iree/compiler/DispatchCreation/test/convert_region_to_workgroups.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,40 @@ util.func public @foo(%argA: tensor<?x?xf32>, %argB: tensor<5x10xf32>, %argC: te
4747
// CHECK: util.return %[[r0]], %[[r1]]
4848
util.return %r0, %r1 : tensor<?x?xf32>, tensor<5x11xf32>
4949
}
50+
51+
// -----
52+
53+
// TODO(Max191): Remove this test once GPU data tiling stops using early
// materialization.
// Checks that a flow.dispatch.region wrapping an iree_gpu.multi_mma converts
// to flow.dispatch.workgroups, with the accumulator operand (%arg2) bound as
// a readwrite tensor — i.e. the TiedOpInterface ties the op's result to its
// acc operand instead of allocating a separate writeonly output.
util.func public @multi_mma(
    %arg0: tensor<4x16x8x4x16x2x4xf16>,
    %arg1: tensor<4x16x4x2x4x16x2x4xf16>,
    %arg2: tensor<4x4x8x4x2x4x16x4xf32>) -> (tensor<4x4x8x4x2x4x16x4xf32>) {
  %9 = flow.dispatch.region -> (tensor<4x4x8x4x2x4x16x4xf32>) {
    %13 = iree_gpu.multi_mma %arg0, %arg1, %arg2 {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
                       affine_map<(d0, d1, d2) -> (d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>],
      iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
      kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_F16, unroll_m = 8, unroll_n = 2, unroll_n_to_subgroups = 4, unroll_k = 2>}
      : tensor<4x16x8x4x16x2x4xf16>, tensor<4x16x4x2x4x16x2x4xf16> into tensor<4x4x8x4x2x4x16x4xf32>
    flow.return %13 : tensor<4x4x8x4x2x4x16x4xf32>
  }
  util.return %9 : tensor<4x4x8x4x2x4x16x4xf32>
}
// CHECK-LABEL: util.func public @multi_mma(
//       CHECK:   %[[arg0:.*]]: tensor<4x16x8x4x16x2x4xf16>, %[[arg1:.*]]: tensor<4x16x4x2x4x16x2x4xf16>, %[[arg2:.*]]: tensor<4x4x8x4x2x4x16x4xf32>
//       CHECK:   %[[r0:.*]] = flow.dispatch.workgroups(%[[arg0]], %[[arg1]], %[[arg2]])
//  CHECK-SAME:       : (tensor<4x16x8x4x16x2x4xf16>, tensor<4x16x4x2x4x16x2x4xf16>, tensor<4x4x8x4x2x4x16x4xf32>)
//  CHECK-NEXT:       (%[[arg3:.*]]: !flow.dispatch.tensor<readonly:tensor<4x16x8x4x16x2x4xf16>>,
//  CHECK-SAME:        %[[arg4:.*]]: !flow.dispatch.tensor<readonly:tensor<4x16x4x2x4x16x2x4xf16>>,
//  CHECK-SAME:        %[[arg5:.*]]: !flow.dispatch.tensor<readwrite:tensor<4x4x8x4x2x4x16x4xf32>>)
//   CHECK-DAG:     %[[loadLHS:.*]] = flow.dispatch.tensor.load %[[arg3]]
//   CHECK-DAG:     %[[loadRHS:.*]] = flow.dispatch.tensor.load %[[arg4]]
//   CHECK-DAG:     %[[loadACC:.*]] = flow.dispatch.tensor.load %[[arg5]]
//       CHECK:     %[[MULTI_MMA:.*]] = iree_gpu.multi_mma %[[loadLHS]], %[[loadRHS]], %[[loadACC]]
// The result is stored back into the acc binding (%[[arg5]]), confirming the
// in-place tie rather than a separate result buffer.
//       CHECK:     flow.dispatch.tensor.store %[[MULTI_MMA]], %[[arg5]]
//       CHECK:     flow.return
//       CHECK:   }
//       CHECK:   util.return %[[r0]]

compiler/src/iree/compiler/ExternalInterfaces/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ iree_compiler_cc_library(
2727
"UtilExternalModels.h",
2828
],
2929
deps = [
30+
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
3031
"//compiler/src/iree/compiler/Dialect/Encoding/IR",
3132
"//compiler/src/iree/compiler/Dialect/Flow/IR",
3233
"//compiler/src/iree/compiler/Dialect/HAL/IR",

compiler/src/iree/compiler/ExternalInterfaces/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ iree_cc_library(
3232
MLIRMLProgramDialect
3333
MLIRTensorDialect
3434
MLIRValueBoundsOpInterface
35+
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
3536
iree::compiler::Dialect::Encoding::IR
3637
iree::compiler::Dialect::Flow::IR
3738
iree::compiler::Dialect::HAL::IR

compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
#include "iree/compiler/ExternalInterfaces/UtilExternalModels.h"
88

9+
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
10+
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
911
#include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h"
1012
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
1113
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
@@ -168,6 +170,27 @@ struct LinalgOpTiedOpInterfaceHelper {
168170
}
169171
};
170172

173+
// TODO(Max191): Remove this interface once GPU data tiling stops using early
174+
// materialization. This only exists for handling multi_mma ops before dispatch
175+
// workgroups are created, which only happens with early materialization.
176+
struct MultiMmaOpTiedOpInterface
177+
: public IREE::Util::TiedOpInterface::ExternalModel<
178+
MultiMmaOpTiedOpInterface, IREE::GPU::MultiMmaOp> {
179+
Value getTiedResult(Operation *op, unsigned resultIndex) const {
180+
auto linalgOp = cast<IREE::GPU::MultiMmaOp>(op);
181+
return IREE::Util::TiedOpInterface::findTiedBaseValue(linalgOp.getAcc());
182+
}
183+
184+
::std::optional<unsigned>
185+
getTiedResultOperandIndex(Operation *op, unsigned resultIndex) const {
186+
return {2}; // acc
187+
}
188+
189+
SmallVector<int64_t> getTiedResultOperandIndices(Operation *op) const {
190+
return {2}; // acc
191+
}
192+
};
193+
171194
//===----------------------------------------------------------------------===//
172195
// HoistableOpInterface
173196
//===----------------------------------------------------------------------===//
@@ -289,6 +312,11 @@ void registerUtilExternalModels(DialectRegistry &registry) {
289312
*context);
290313
});
291314

315+
registry.addExtension(+[](MLIRContext *context,
316+
IREE::GPU::IREEGPUDialect *dialect) {
317+
IREE::GPU::MultiMmaOp::attachInterface<MultiMmaOpTiedOpInterface>(*context);
318+
});
319+
292320
registry.addExtension(
293321
+[](MLIRContext *context, linalg::LinalgDialect *dialect) {
294322
// Register all Linalg structured ops. `LinalgOp` is an interface and it

0 commit comments

Comments
 (0)