Skip to content

Commit 93bb9e4

Browse files
authored
[GPU] Add col_major optional attribute to MMAAttr (#19860)
For MMA intrinsics with symmetric register layouts, we have the freedom to choose between row-major and column-major output layouts by simply swapping the input operands. This adds a `col_major` flag to `MMAAttr` that allows specifying whether the result should be produced column major.
1 parent 6d25b91 commit 93bb9e4

File tree

6 files changed

+135
-7
lines changed

6 files changed

+135
-7
lines changed

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,21 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
345345
return {};
346346
}
347347

348+
MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
349+
MMAFragment fragment,
350+
bool colMajor) {
351+
MMASingleSubgroupLayout baseLayout =
352+
getSingleSubgroupLayout(intrinsic, fragment);
353+
assert(baseLayout.element.size() == 2 && "expected 2d layout");
354+
if (colMajor) {
355+
std::swap(baseLayout.element[0], baseLayout.element[1]);
356+
std::swap(baseLayout.thread[0], baseLayout.thread[1]);
357+
std::swap(baseLayout.outer[0], baseLayout.outer[1]);
358+
std::swap(baseLayout.tstrides[0], baseLayout.tstrides[1]);
359+
}
360+
return baseLayout;
361+
}
362+
348363
// Struct describing the shape of a MMA operation, but not the detailed layout.
349364
struct OpaqueMmaLayout {
350365
int64_t mSize = 0;
@@ -388,7 +403,11 @@ static OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
388403
MMASingleSubgroupLayout getSingleSubgroupLayout(MmaInterfaceAttr mmaKind,
389404
MMAFragment fragment) {
390405
if (auto mmaAttr = dyn_cast<MMAAttr>(mmaKind)) {
391-
return getSingleSubgroupLayout(mmaAttr.getIntrinsic(), fragment);
406+
// |colMajor| indicates that the accumulator layout should be returned
407+
// column major.
408+
return getSingleSubgroupLayout(mmaAttr.getIntrinsic(), fragment,
409+
fragment == MMAFragment::Acc &&
410+
mmaAttr.getColMajor());
392411
}
393412
if (auto vmmaAttr = dyn_cast<VirtualMMAAttr>(mmaKind)) {
394413
return getSingleSubgroupLayout(vmmaAttr.getIntrinsic(), fragment);
@@ -401,6 +420,10 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MmaInterfaceAttr mmaKind,
401420
// MMA Attributes
402421
//===----------------------------------------------------------------------===//
403422

423+
MMAAttr MMAAttr::get(MLIRContext *context, MMAIntrinsic type) {
424+
return Base::get(context, type, /*colMajor=*/false);
425+
}
426+
404427
std::tuple<Type, Type, Type> MMAAttr::getABCElementTypes() const {
405428
return IREE::GPU::getABCElementTypes(getContext(), getIntrinsic());
406429
}
@@ -468,7 +491,7 @@ SmallVector<VirtualMMAIntrinsic> MMAAttr::getVirtualIntrinsics() const {
468491

469492
static Value createMmaOp(OpBuilder &builder, Location loc,
470493
MMAIntrinsic intrinsic, Type resultType, Value lhs,
471-
Value rhs, Value acc) {
494+
Value rhs, Value acc, bool colMajor = false) {
472495
auto getVecOrSingleElem = [&](Value vec) -> Value {
473496
bool one = llvm::cast<VectorType>(vec.getType()).getNumElements() == 1;
474497
return one ? builder.create<vector::ExtractOp>(loc, vec, 0) : vec;
@@ -478,6 +501,13 @@ static Value createMmaOp(OpBuilder &builder, Location loc,
478501
// MFMA intrinsics want single-element operands of element type, not vector.
479502
lhs = getVecOrSingleElem(lhs);
480503
rhs = getVecOrSingleElem(rhs);
504+
505+
// Because the thread layout of the lhs and rhs are transpositions of one
506+
// another for all MFMA variants, to produce a column major result we can
507+
// simply swap the operands to the MFMA.
508+
if (colMajor) {
509+
std::swap(lhs, rhs);
510+
}
481511
return builder
482512
.create<amdgpu::MFMAOp>(loc, resultType, layout.mSize, layout.nSize,
483513
layout.kSize, getBlockSize(intrinsic), lhs, rhs,
@@ -507,7 +537,7 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
507537
return failure();
508538
}
509539
if (Value value = createMmaOp(builder, loc, getIntrinsic(), resultType, lhs,
510-
rhs, acc)) {
540+
rhs, acc, getColMajor())) {
511541
return value;
512542
}
513543
return failure();
@@ -592,8 +622,8 @@ LogicalResult MMAAttr::populateOperandOffsetsSizesStrides(
592622
SmallVector<OpFoldResult> &offsets, SmallVector<OpFoldResult> &sizes,
593623
SmallVector<OpFoldResult> &strides) const {
594624

595-
MMASingleSubgroupLayout subgroupLayout =
596-
getSingleSubgroupLayout(getIntrinsic(), fragment);
625+
MMASingleSubgroupLayout subgroupLayout = getSingleSubgroupLayout(
626+
getIntrinsic(), fragment, fragment == MMAFragment::Acc && getColMajor());
597627
SmallVector<OpFoldResult> canonicalOffsets;
598628
SmallVector<OpFoldResult> canonicalSizes;
599629
if (failed(populateCanonicalOffsetsSizesAndStrides(

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ int64_t getKSize(MMAIntrinsic intrinsic);
7272
MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
7373
MMAFragment fragment);
7474

75+
MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
76+
MMAFragment fragment,
77+
bool colMajor);
78+
7579
MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic,
7680
MMAFragment fragment);
7781

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,25 @@ def IREEGPU_MMAAttr : AttrDef<IREEGPU_Dialect, "MMA", [
150150
The |intrinsic| field specifies which particular MMA intrinsic this refers
151151
to, with each intrinsic implicating a specific MNK shape and operand types.
152152
See IREEGPUEnums.td for the definition of the intrinsics.
153+
154+
If set to true, |col_major| indicates that the result should be produced
155+
column major. This is equivalent to instead computing:
156+
157+
```
158+
C^T += B^T x A^T
159+
```
153160
}];
154161

155162
let parameters = (ins
156-
EnumParameter<IREEGPU_MMAIntrinsic>:$intrinsic
163+
EnumParameter<IREEGPU_MMAIntrinsic>:$intrinsic,
164+
DefaultValuedParameter<"bool", "false">:$col_major
157165
);
158166

159-
let assemblyFormat = "`<` params `>`";
167+
let assemblyFormat = "`<` $intrinsic (`,` `col_major` `=` $col_major^)? `>`";
168+
169+
let builders = [
170+
AttrBuilder<(ins "MMAIntrinsic":$intrinsic)>
171+
];
160172

161173
let extraClassDeclaration = [{
162174
int64_t getBlockSize() const;

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/iree_gpu_attrs.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@ module {
1818
// CHECK-LABEL: func @test_mfma_f16_32x32x8_f32
1919
// CHECK-SAME: mma_types = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>
2020

21+
module {
22+
func.func @test_col_major_mfma_f16_16x16x16_f32() attributes {
23+
mma_types = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>} {
24+
return
25+
}
26+
}
27+
// CHECK-LABEL: func @test_col_major_mfma_f16_16x16x16_f32
28+
// CHECK-SAME: mma_types = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16, col_major = true>
29+
2130
module {
2231
func.func @test_WMMAR3_f16_16x16x16_f32() attributes {
2332
mma_types = #iree_gpu.mma_layout<WMMAR3_F32_16x16x16_F16>} {

compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/lower_multi_mma.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,40 @@ module attributes { transform.with_named_sequence } {
6868

6969
// -----
7070

71+
#contraction_accesses = [
72+
affine_map<() -> ()>,
73+
affine_map<() -> ()>,
74+
affine_map<() -> ()>
75+
]
76+
func.func @lower_col_major_multi_mma_mfma_32x32x8(%lhs: vector<4xf16>, %rhs: vector<4xf16>, %acc: vector<16xf32>) -> vector<16xf32> {
77+
%0 = iree_gpu.multi_mma %lhs, %rhs, %acc {
78+
indexing_maps = #contraction_accesses,
79+
iterator_types = [],
80+
kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16, col_major = true>
81+
} : vector<4xf16>, vector<4xf16> into vector<16xf32>
82+
return %0 : vector<16xf32>
83+
}
84+
85+
module attributes { transform.with_named_sequence } {
86+
transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
87+
%func = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
88+
transform.apply_patterns to %func {
89+
transform.apply_patterns.iree.lower_multi_mma
90+
} : !transform.any_op
91+
transform.yield
92+
}
93+
}
94+
95+
// CHECK-LABEL: func @lower_col_major_multi_mma_mfma_32x32x8
96+
// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: vector<4xf16>
97+
// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: vector<4xf16>
98+
// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]: vector<16xf32>
99+
// CHECK: amdgpu.mfma %[[RHS]] * %[[LHS]] + %[[ACC]]
100+
// CHECK-SAME: blocks = 1 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32
101+
// CHECK-SAME: blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>
102+
103+
// -----
104+
71105
#contraction_accesses = [
72106
affine_map<() -> ()>,
73107
affine_map<() -> ()>,

compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,45 @@ module {
107107

108108
// -----
109109

110+
#contraction_accesses = [
111+
affine_map<(i, j, k) -> (i, k)>,
112+
affine_map<(i, j, k) -> (k, j)>,
113+
affine_map<(i, j, k) -> (i, j)>
114+
]
115+
module {
116+
func.func @col_major_matmul_32x32x8(%arg0: tensor<2x8x32x8xf16>, %arg1: tensor<8x2x32x8xf16>, %arg2: tensor<2x2x32x4x8xf32>) -> tensor<2x2x32x4x8xf32> {
117+
%mm = iree_gpu.multi_mma %arg0, %arg1, %arg2 {
118+
indexing_maps = #contraction_accesses,
119+
iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
120+
kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16, col_major = true>,
121+
rhs_permutation = array<i64: 1, 0>
122+
} : tensor<2x8x32x8xf16>, tensor<8x2x32x8xf16> into tensor<2x2x32x4x8xf32>
123+
return %mm : tensor<2x2x32x4x8xf32>
124+
}
125+
}
126+
127+
// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
128+
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
129+
// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
130+
131+
// CHECK-LABEL: func @col_major_matmul_32x32x8
132+
// CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: tensor<2x8x32x8xf16>
133+
// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<8x2x32x8xf16>
134+
// CHECK: scf.forall (%[[LANEID:.+]]) in (64) shared_outs(%[[ACC:.+]] = {{.*}}) -> (tensor<2x2x32x4x8xf32>)
135+
// CHECK-DAG: %[[ID:.+]]:3 = affine.delinearize_index %[[LANEID]] into (2, 32)
136+
// CHECK-DAG: %[[IDY:.+]] = affine.linearize_index disjoint [%[[ID]]#1, %c0] by (2, 4)
137+
// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[ID]]#2, %[[IDY]]] [2, 8, 1, 4]
138+
// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[ID]]#2, %[[IDY]]] [8, 2, 1, 4]
139+
// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC]][0, 0, %[[ID]]#2, 0, %[[IDY]]] [2, 2, 1, 4, 4]
140+
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS_SLICE]], %[[RHS_SLICE]], %[[ACC_SLICE]]
141+
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
142+
// CHECK-SAME: kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16, col_major = true>
143+
// CHECK-SAME: : tensor<2x8x1x4xf16>, tensor<8x2x1x4xf16> into tensor<2x2x1x4x4xf32>
144+
// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC]][0, 0, %[[ID]]#2, 0, %[[IDY]]] [2, 2, 1, 4, 4]
145+
// CHECK: mapping = [#iree_gpu.lane_id<0>]
146+
147+
// -----
148+
110149
#contraction_accesses = [
111150
affine_map<(i, j, k) -> (i, k)>,
112151
affine_map<(i, j, k) -> (k, j)>,

0 commit comments

Comments (0)