Skip to content

Commit 3ddee06

Browse files
authored
[AMD] NFC: Refactor DotOpMFMAConversionHelper (triton-lang#5862)
This PR refactors `DotOpMFMAConversionHelper` by extracting utility functions from `convertDot`, making the helper easier to extend in triton-lang#5845.
1 parent e4e6687 commit 3ddee06

File tree

1 file changed

+66
-56
lines changed
  • third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM

1 file changed

+66
-56
lines changed

third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp

Lines changed: 66 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,67 @@ struct DotOpMFMAConversionHelper {
165165
return processSubBlocks(numSubBlocks, acc, false, true);
166166
}
167167

168+
/// Dot operand layout minimal tile is kDimInstrSize elements across
169+
/// K dimension. If dot operand K dimension is smaller, layout
170+
/// assigns tensor elements to multiple different hardware locations.
171+
/// In this case mfma instruction adds elements in accumulator
172+
/// multiple times.
173+
///
174+
/// Let say A=[1,2]; B=[3,4], C = A*B = 1*3+2*4 = 11
175+
/// Consider instruction K size is 4,
176+
/// in this case operands will be duplicated:
177+
/// A' = [1,2,1,2] B' = [3,4,3,4]
178+
/// C' = (1*3+2*4) + (1*3+2*4) = 22
179+
///
180+
/// Following code adjusts accumulator values in such cases.
181+
/// If accumulator is integer, shift accumulator right by
182+
/// log2(duplicationRate). If accumulator is float, multiply accum
183+
/// with 1/duplicationRate constant.
184+
void adjustAccForSmallKDim(SmallVector<Value> &fc, Value &acc, Type dstElemTy,
185+
int b, int m, int n, int64_t numRepM,
186+
int64_t numRepN, int64_t kDimInstrSize,
187+
int64_t kDimOperandSize,
188+
unsigned elemsPerVec) const {
189+
auto tb = TritonLLVMOpBuilder(loc, rewriter);
190+
for (unsigned v = 0; v < elemsPerVec; ++v) {
191+
Value accElem = tb.extract_element(dstElemTy, acc, tb.i32_val(v));
192+
if (kDimInstrSize > kDimOperandSize) {
193+
assert(kDimInstrSize % kDimOperandSize == 0);
194+
int duplicationRate = kDimInstrSize / kDimOperandSize;
195+
assert(llvm::isPowerOf2_32(duplicationRate));
196+
if (dstElemTy.isInteger()) {
197+
auto shiftSize = llvm::Log2_32(duplicationRate);
198+
assert(!accElem.getType().isUnsignedInteger() &&
199+
"MFMA uses signed accumulator");
200+
accElem = tb.ashr(accElem, tb.i32_val(shiftSize));
201+
} else {
202+
auto multiplierAttr =
203+
rewriter.getFloatAttr(dstElemTy, 1.0 / duplicationRate);
204+
auto multiplierVal =
205+
rewriter.create<LLVM::ConstantOp>(loc, dstElemTy, multiplierAttr);
206+
accElem = tb.fmul(accElem, multiplierVal);
207+
}
208+
}
209+
auto linearIdx = b * numRepM * numRepN * elemsPerVec +
210+
m * numRepN * elemsPerVec + n * elemsPerVec + v;
211+
fc[linearIdx] = accElem;
212+
}
213+
}
214+
215+
/// Finalizes the dot conversion: packs the converted scalars in `fc` into a
/// literal LLVM struct with fc.size() fields of type `dstElemTy`, forwards
/// `mmaCount`, the M/N/K dimensions reported by `maybeMfmaInsn`, and
/// `elemtTy` to setNumGeneratedMMAs, and finally replaces `op` with the
/// packed value.
///
/// `maybeMfmaInsn` is dereferenced without a check here, so it must hold a
/// value when this is called.
void packAndReplaceResult(DotOp &op, SmallVector<Value> &fc,
                          FailureOr<MfmaInsn> maybeMfmaInsn, Type dstElemTy,
                          Type elemtTy, size_t mmaCount) const {
  // One struct field per result scalar, all sharing the destination type.
  SmallVector<Type> fieldTypes(fc.size(), dstElemTy);
  Type packedTy = LLVM::LLVMStructType::getLiteral(ctx, fieldTypes);
  Value packed = packLLElements(loc, typeConverter, fc, rewriter, packedTy);

  MfmaInsn &insn = *maybeMfmaInsn;
  setNumGeneratedMMAs(op, mmaCount, insn.getMDim(), insn.getNDim(),
                      insn.getKDim(), elemtTy);

  rewriter.replaceOp(op, packed);
}
228+
168229
// Conduct the Dot conversion.
169230
LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor) const {
170231
auto tb = TritonLLVMOpBuilder(loc, rewriter);
@@ -243,11 +304,6 @@ struct DotOpMFMAConversionHelper {
243304
auto elemsPerVec = mDim * nDim * subBlocks / warpSize;
244305

245306
Value firstMfma;
246-
auto setFirstMfma = [&](Value mfma) {
247-
if (!firstMfma)
248-
firstMfma = mfma;
249-
};
250-
251307
auto vecTy = vec_ty(dstElemTy, elemsPerVec);
252308
for (int b = 0; b < numRepB; ++b) {
253309
for (int m = 0; m < numRepM; ++m) {
@@ -269,49 +325,13 @@ struct DotOpMFMAConversionHelper {
269325
operandA[kPack][{b, m, k}], acc)
270326
: generateMFMAOp(mfmaInsnName, operandA[kPack][{b, m, k}],
271327
operandB[kPack][{b, n, k}], acc);
272-
setFirstMfma(acc);
328+
if (!firstMfma)
329+
firstMfma = acc;
273330
}
274331
}
275332
acc = reduceSubBlocks(subBlocks, acc);
276-
for (unsigned v = 0; v < elemsPerVec; ++v) {
277-
Value accElem = tb.extract_element(dstElemTy, acc, tb.i32_val(v));
278-
// Dot operand layout minimal tile is kDimInstrSize elements across
279-
// K dimension. If dot operand K dimension is smaller, layout
280-
// assigns tensor elements to multiple different hardware locations.
281-
// In this case mfma instruction adds elements in accumulator
282-
// multiple times.
283-
//
284-
// Let say A=[1,2]; B=[3,4], C = A*B = 1*3+2*4 = 11
285-
// Consider instruction K size is 4,
286-
// in this case operands will be duplicated:
287-
// A' = [1,2,1,2] B' = [3,4,3,4]
288-
// C' = (1*3+2*4) + (1*3+2*4) = 22
289-
//
290-
// Following code adjusts accumulator values in such cases.
291-
// If accumulator is integer, shift accumulator right by
292-
// log2(duplicationRate). If accumulator is float, multiply accum
293-
// with 1/duplicationRate constant.
294-
if (kDimInstrSize > kDimOperandSize) {
295-
assert(kDimInstrSize % kDimOperandSize == 0);
296-
int duplicationRate = kDimInstrSize / kDimOperandSize;
297-
assert(llvm::isPowerOf2_32(duplicationRate));
298-
if (dstElemTy.isInteger()) {
299-
auto shiftSize = llvm::Log2_32(duplicationRate);
300-
assert(!accElem.getType().isUnsignedInteger() &&
301-
"MFMA uses signed accumulator");
302-
accElem = tb.ashr(accElem, tb.i32_val(shiftSize));
303-
} else {
304-
auto multiplierAttr =
305-
rewriter.getFloatAttr(dstElemTy, 1.0 / duplicationRate);
306-
auto multiplierVal = rewriter.create<LLVM::ConstantOp>(
307-
loc, dstElemTy, multiplierAttr);
308-
accElem = tb.fmul(accElem, multiplierVal);
309-
}
310-
}
311-
auto linearIdx = b * numRepM * numRepN * elemsPerVec +
312-
m * numRepN * elemsPerVec + n * elemsPerVec + v;
313-
fc[linearIdx] = accElem;
314-
}
333+
adjustAccForSmallKDim(fc, acc, dstElemTy, b, m, n, numRepM, numRepN,
334+
kDimInstrSize, kDimOperandSize, elemsPerVec);
315335
}
316336
}
317337
}
@@ -325,19 +345,9 @@ struct DotOpMFMAConversionHelper {
325345
if (setPrioOp && firstMfma)
326346
setPrioOp->moveAfter(firstMfma.getDefiningOp());
327347

328-
// replace with new packed result
329-
Type structTy = LLVM::LLVMStructType::getLiteral(
330-
ctx, SmallVector<Type>(fc.size(), dstElemTy));
331-
Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy);
332-
333-
Type elemtTy = elemTyA;
334348
const size_t mmaCount =
335349
numRepB * numRepM * numRepN * numRepK * kWidth / kBase;
336-
setNumGeneratedMMAs(op, mmaCount, maybeMfmaInsn->getMDim(),
337-
maybeMfmaInsn->getNDim(), maybeMfmaInsn->getKDim(),
338-
elemtTy);
339-
340-
rewriter.replaceOp(op, res);
350+
packAndReplaceResult(op, fc, maybeMfmaInsn, dstElemTy, elemTyA, mmaCount);
341351

342352
return success();
343353
}

0 commit comments

Comments
 (0)