Commit 2a2bd06

More IREEGPUAttrs.cpp cleanups (#19142)
Two things in this PR:

1. Make a big switch statement more concise.

2. Currently, `DataTiledMMAAttr::buildMmaOperation` creates a `MMAAttr` just to call `buildMmaOperation` on it, to reuse that implementation. Besides being roundabout, this required a comment explaining why we discarded the error status: `MMAAttr::buildMmaOperation` is fallible, but here we were already past validation and mutating IR. This PR refactors that so that both call a shared, infallible helper.

Signed-off-by: Benoit Jacob <[email protected]>
1 parent f828914 commit 2a2bd06
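The refactor in point 2 follows a common shape: the fallible entry point (`MMAAttr::buildMmaOperation`, returning `FailureOr<Value>`) validates and then delegates, while the infallible helper (`createMmaOp`) does the actual op creation and can also be called directly by already-validated code such as `DataTiledMMAAttr::buildMmaOperation`. A minimal sketch of that shape in plain C++, not the actual IREE API (`buildOp`, `createOp`, and the validation rule are made up for illustration):

// Sketch only: a fallible wrapper delegating to an infallible helper
// that already-validated call sites can use directly.
#include <iostream>
#include <optional>
#include <string>

// Infallible helper: assumes the caller has validated m, n, k.
static std::string createOp(int m, int n, int k) {
  return "mma." + std::to_string(m) + "x" + std::to_string(n) + "x" +
         std::to_string(k);
}

// Fallible entry point: validates, then delegates to the helper.
static std::optional<std::string> buildOp(int m, int n, int k) {
  if (m <= 0 || n <= 0 || k <= 0)
    return std::nullopt; // fail before doing any work
  return createOp(m, n, k);
}

int main() {
  // A caller already past validation (like the unrolling loop in
  // DataTiledMMAAttr::buildMmaOperation) calls the helper directly,
  // so there is no error status to discard.
  std::cout << createOp(16, 16, 4) << '\n';
  // External callers go through the fallible wrapper and handle failure.
  if (!buildOp(0, 16, 4))
    std::cout << "validation failed\n";
}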

File tree

1 file changed

+66 -99 lines changed

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp

Lines changed: 66 additions & 99 deletions
@@ -220,88 +220,45 @@ static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
   Type i8 = IntegerType::get(context, 8);
   Type i32 = IntegerType::get(context, 32);
   switch (intrinsic) {
-  case MMAIntrinsic::MFMA_F64_16x16x4_F64: {
+  case MMAIntrinsic::MFMA_F64_16x16x4_F64:
     return {f64, f64, f64};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
+  case MMAIntrinsic::MFMA_F32_16x16x4_F32:
     return {f32, f32, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
-    return {f16, f16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
+  case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+  case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+  case MMAIntrinsic::WMMA_F32_16x16x16_F16:
+  case MMAIntrinsic::NV_WMMA_F32_16x16x16_F16:
     return {f16, f16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x8_BF16: {
-    return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x4_BF16: {
-    return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x16_BF16: {
-    return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
+  case MMAIntrinsic::WMMA_F16_16x16x16_F16:
+  case MMAIntrinsic::NV_WMMA_F16_16x16x16_F16:
+    return {f16, f16, f16};
+  case MMAIntrinsic::MFMA_F32_16x16x8_BF16:
+  case MMAIntrinsic::MFMA_F32_32x32x4_BF16:
+  case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
+  case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
+  case MMAIntrinsic::WMMA_F32_16x16x16_BF16:
     return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: {
-    return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ: {
-    return {f8E5M2FNUZ, f8E5M2FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ: {
-    return {f8E4M3FNUZ, f8E5M2FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ: {
-    return {f8E5M2FNUZ, f8E4M3FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ: {
+  case MMAIntrinsic::WMMA_BF16_16x16x16_BF16:
+    return {bf16, bf16, bf16};
+  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ:
     return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ: {
+  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ:
     return {f8E5M2FNUZ, f8E5M2FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ: {
+  case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ:
     return {f8E4M3FNUZ, f8E5M2FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ: {
+  case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ:
+  case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ:
     return {f8E5M2FNUZ, f8E4M3FNUZ, f32};
-  }
-  case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
-    return {i8, i8, i32};
-  }
-  case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
-    return {i8, i8, i32};
-  }
-  case MMAIntrinsic::MFMA_I32_32x32x8_I8: {
-    return {i8, i8, i32};
-  }
-  case MMAIntrinsic::MFMA_I32_16x16x16_I8: {
-    return {i8, i8, i32};
-  }
-  case MMAIntrinsic::WMMA_F32_16x16x16_F16: {
-    return {f16, f16, f32};
-  }
-  case MMAIntrinsic::WMMA_F16_16x16x16_F16: {
-    return {f16, f16, f16};
-  }
-  case MMAIntrinsic::WMMA_F32_16x16x16_BF16: {
-    return {bf16, bf16, f32};
-  }
-  case MMAIntrinsic::WMMA_BF16_16x16x16_BF16: {
-    return {bf16, bf16, bf16};
-  }
-  case MMAIntrinsic::WMMA_I32_16x16x16_I8: {
+  case MMAIntrinsic::MFMA_I32_16x16x16_I8:
+  case MMAIntrinsic::MFMA_I32_32x32x8_I8:
+  case MMAIntrinsic::MFMA_I32_16x16x32_I8:
+  case MMAIntrinsic::MFMA_I32_32x32x16_I8:
+  case MMAIntrinsic::WMMA_I32_16x16x16_I8:
     return {i8, i8, i32};
   }
-  case MMAIntrinsic::NV_WMMA_F16_16x16x16_F16: {
-    return {f16, f16, f16};
-  }
-  case MMAIntrinsic::NV_WMMA_F32_16x16x16_F16: {
-    return {f16, f16, f32};
-  }
-  }
   assert(false && "unexpected enum value");
   return {};
 }
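As an aside for readers skimming the hunk above: it leans on ordinary `case` fallthrough, where consecutive labels with no statements between them share the next body, which is what makes the switch concise. A standalone illustration of the idiom (the `Day` enum is made up for the example):

// Illustrative only: consecutive case labels share one body, as the hunk
// above does for MMA intrinsics with identical element types.
#include <cstdio>

enum class Day { Sat, Sun, Mon, Tue };

static const char *kind(Day d) {
  switch (d) {
  case Day::Sat:
  case Day::Sun:
    return "weekend"; // Sat falls through to the shared return
  case Day::Mon:
  case Day::Tue:
    return "weekday";
  }
  return nullptr; // unreachable when every enum value is covered
}

int main() { std::printf("%s\n", kind(Day::Sun)); }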
@@ -498,11 +455,15 @@ MMAAttr::getContractionLayout(vector::ContractionOp contract) const {
   return IREE::GPU::getContractionLayout(contract, layout);
 }
 
-int64_t MMAAttr::getBlockSize() const {
+static int getBlockSize(MMAIntrinsic /*intrinsic*/) {
   // Not supporting any block size other than 1 at the moment.
   return 1;
 }
 
+int64_t MMAAttr::getBlockSize() const {
+  return IREE::GPU::getBlockSize(getIntrinsic().getValue());
+}
+
 static uint32_t getArchID(MMAIntrinsic intrinsic) {
   return static_cast<int>(intrinsic) & 0xFF00;
 }
@@ -704,6 +665,31 @@ SmallVector<VirtualMMAIntrinsic> MMAAttr::getVirtualIntrinsics() const {
   }
 }
 
+static Value createMmaOp(OpBuilder &builder, Location loc,
+                         MMAIntrinsic intrinsic, Type resultType, Value lhs,
+                         Value rhs, Value acc) {
+  auto getVecOrSingleElem = [&](Value vec) -> Value {
+    bool one = llvm::cast<VectorType>(vec.getType()).getNumElements() == 1;
+    return one ? builder.create<vector::ExtractOp>(loc, vec, 0) : vec;
+  };
+  auto layout = getOpaqueMMALayout(builder.getContext(), intrinsic);
+  if (is_AMD_MFMA(intrinsic)) {
+    // MFMA intrinsics want single-element operands of element type, not vector.
+    lhs = getVecOrSingleElem(lhs);
+    rhs = getVecOrSingleElem(rhs);
+    return builder
+        .create<amdgpu::MFMAOp>(loc, resultType, layout.mSize, layout.nSize,
+                                layout.kSize, getBlockSize(intrinsic), lhs, rhs,
+                                acc)
+        .getResult();
+  }
+  if (is_AMD_WMMA(intrinsic)) {
+    return builder.create<amdgpu::WMMAOp>(loc, resultType, lhs, rhs, acc)
+        .getResult();
+  }
+  return {};
+}
+
 // Generates amdgpu.mfma/wmma operation on the given inputs for this attribute
 // type.
 FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
@@ -719,23 +705,9 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
   if (cType != resultType) {
     return failure();
   }
-  auto getVecOrSingleElem = [&](Value vec) -> Value {
-    bool one = llvm::cast<VectorType>(vec.getType()).getNumElements() == 1;
-    return one ? builder.create<vector::ExtractOp>(loc, vec, 0) : vec;
-  };
-  MMAIntrinsic intrinsic = getIntrinsic().getValue();
-  if (is_AMD_MFMA(intrinsic)) {
-    // MFMA intrinsics want single-element operands of element type, not vector.
-    lhs = getVecOrSingleElem(lhs);
-    rhs = getVecOrSingleElem(rhs);
-    auto [m, n, k] = getMNKShape();
-    return builder
-        .create<amdgpu::MFMAOp>(loc, resultType, m, n, k, getBlockSize(), lhs,
-                                rhs, acc)
-        .getResult();
-  } else if (is_AMD_WMMA(intrinsic)) {
-    return builder.create<amdgpu::WMMAOp>(loc, resultType, lhs, rhs, acc)
-        .getResult();
+  if (Value value = createMmaOp(builder, loc, getIntrinsic().getValue(),
+                                resultType, lhs, rhs, acc)) {
+    return value;
   }
   return failure();
 }
@@ -1168,23 +1140,18 @@ FailureOr<Value> DataTiledMMAAttr::buildMmaOperation(OpBuilder &builder,
   SmallVector<Value> intrinsicsAcc =
       distributeMmaFragmentToIntrinsics(builder, loc, acc, accSwizzle);
 
-  // Get a MMAAttr for the intrinsic itself, to reuse MMAAttr::buildMmaOperation
-  // to create the target intrinsics.
-  auto intrinsicMma = MMAAttr::get(getContext(), getIntrinsic().getValue());
-  auto [intrinsicAType, intrinsicBType, intrinsicCType] =
-      intrinsicMma.getABCVectorTypes();
+  MMAIntrinsic intrinsic = getIntrinsic().getValue();
+  VectorType intrinCType =
+      getVectorType(builder.getContext(), intrinsic, MMAFragment::Acc);
 
   // Loop over the 3 unroll_{m,n,k} dimensions to create the intrinsics.
   for (int mu = 0; mu < getUnrollM(); ++mu) {
     for (int nu = 0; nu < getUnrollN(); ++nu) {
       for (int ku = 0; ku < getUnrollK(); ++ku) {
-        // Assume intrinsicMma.buildMmaOperation() success: validation should be
-        // completed prior to mutating IR.
         Value lhs = intrinsicsLhs[mu * getUnrollK() + ku];
         Value rhs = intrinsicsRhs[nu * getUnrollK() + ku];
         Value &acc = intrinsicsAcc[mu * getUnrollN() + nu];
-        acc = *intrinsicMma.buildMmaOperation(builder, loc, intrinsicCType, lhs,
-                                              rhs, acc);
+        acc = createMmaOp(builder, loc, intrinsic, intrinCType, lhs, rhs, acc);
       }
     }
   }
