Skip to content
177 changes: 177 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/FMADotUtility.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
#include "triton/Conversion/TritonGPUToLLVM/FMADotUtility.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Value.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

Expand Down Expand Up @@ -71,10 +77,57 @@ ValueTableFMA getValueTableFromStructFMA(
return res;
}

// Create an empty loop driven by the induction variable \p iv, the upper bound
// \p ub and the step \p step. The current block is split at the rewriter's
// insertion point into a header, a body and an end block:
//   - the header compares `iv < ub` (signed) and branches to the body or to
//     the end block;
//   - the body computes `iv + step` and branches back to the header;
//   - the end block receives everything that followed the insertion point.
// On return the rewriter's insertion point is at the start of the end block.
// Return the body block of the created loop; it already holds the increment
// and the back-edge, so callers insert their code at its start.
//
// NOTE(review): the branches below forward a value as a branch operand, but
// the blocks produced by splitBlock are not given block arguments here, and
// \p iv itself is never replaced by a loop-carried value. Confirm that the
// generated CFG verifies and that the induction variable actually advances.
Block *createEmptyLoop(Value iv, Value ub, Value step,
                       ConversionPatternRewriter &rewriter, Location loc) {
  Block *insertionBlock = rewriter.getInsertionBlock();
  // Each split moves the trailing operations into the newly created block, so
  // after the three splits the header and body are empty and the end block
  // owns the operations that originally followed the insertion point.
  Block *headerBlock =
      rewriter.splitBlock(insertionBlock, rewriter.getInsertionPoint());
  Block *bodyBlock = rewriter.splitBlock(headerBlock, headerBlock->begin());
  Block *endBlock = rewriter.splitBlock(bodyBlock, bodyBlock->begin());
  rewriter.setInsertionPointToEnd(insertionBlock);

  // Loop header.
  rewriter.create<cf::BranchOp>(loc, headerBlock, SmallVector<Value>{iv});
  rewriter.setInsertionPointToStart(headerBlock);
  auto b = TritonLLVMOpBuilder(loc, rewriter);
  rewriter.create<cf::CondBranchOp>(loc, b.icmp_slt(iv, ub), bodyBlock,
                                    endBlock, SmallVector<Value>{iv});
  rewriter.setInsertionPointToStart(bodyBlock);

  // Loop body.
  auto nextIv = b.add(iv, step);
  rewriter.create<cf::BranchOp>(loc, headerBlock, SmallVector<Value>{nextIv});
  rewriter.setInsertionPointToStart(endBlock);

  return bodyBlock;
}

// Materialize \p init through memory: allocate a single element of the type of
// \p init with an alloca, store \p init into it, and hand back the value
// loaded from that allocation.
Value createIV(Value init, ConversionPatternRewriter &rewriter, Location loc) {
  TritonLLVMOpBuilder b(loc, rewriter);
  Type valueTy = init.getType();
  // The element count constant is created first, ahead of the alloca that
  // consumes it.
  Value elemCount = b.i32_val(1);
  auto slot = rewriter.create<LLVM::AllocaOp>(
      loc, ptr_ty(rewriter.getContext()), valueTy, elemCount);
  rewriter.create<LLVM::StoreOp>(loc, init, slot);
  return rewriter.create<LLVM::LoadOp>(loc, valueTy, slot);
}

} // namespace

namespace mlir::triton::gpu {

// Strategy used when lowering the FMA dot: parametricConvertFMADot hands the
// work to genFMALoop when the mode is Loop and otherwise continues with its
// own nested repetition loops.
enum class CodeGenMode {
  Unroll,
  Loop,
};

// Currently selected code generation strategy; starts out as Unroll.
CodeGenMode codeGenMode = CodeGenMode::Unroll;

// Forward declaration: the loop-based FMA lowering is defined further down in
// this file but is called from parametricConvertFMADot.
LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs,
                         ArrayRef<Value> acc, ArrayRef<unsigned> sizePerThread,
                         ArrayRef<unsigned> repetitions, unsigned K, Type dType,
                         ConversionPatternRewriter &rewriter);

LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor,
const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
Expand Down Expand Up @@ -134,6 +187,12 @@ LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor,

SmallVector<Value> acc = cc;

if (codeGenMode == CodeGenMode::Loop) {
Type dType = typeConverter->convertType(dTensorTy);
return genFMALoop(op, has, hbs, acc, sizePerThread, repetitions, K, dType,
rewriter);
}

for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep)
for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep)
for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep)
Expand Down Expand Up @@ -167,4 +226,122 @@ LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor,
return success();
}

// Lower \p op to FMAs wrapped in explicit CFG loops (built with
// createEmptyLoop) rather than fully unrolled code.
//
// \p has / \p hbs  value tables for operands A and B, indexed as
//                  {bRep, mRep|nRep, b, m|n, k}.
// \p acc           accumulator (operand C) values; must be non-empty because
//                  the element type is taken from its first entry.
// \p sizePerThread per-thread sizes, indexed [0]=batch, [1]=M, [2]=N.
// \p repetitions   repetition counts, indexed [0]=batch, [1]=M, [2]=N.
// \p K             extent of the reduction dimension.
// \p dType         converted result type loaded from the output buffer.
//
// A, B and C are packed into LLVM vectors. For every (bRep, mRep, nRep, b)
// combination an outer loop with a nested inner loop is emitted whose body
// extracts elements from the vectors and chains the FMAs. Finally a loop
// copies vecD element by element into an alloca'd buffer, the buffer is
// loaded as \p dType, and \p op is replaced by that load. Always returns
// success().
LogicalResult genFMALoop(DotOp op, ValueTableFMA &has, ValueTableFMA &hbs,
                         ArrayRef<Value> acc, ArrayRef<unsigned> sizePerThread,
                         ArrayRef<unsigned> repetitions, unsigned K, Type dType,
                         ConversionPatternRewriter &rewriter) {
  MLIRContext *ctx = rewriter.getContext();
  Location loc = op.getLoc();

  // Copy struct into vector for operand A.
  SmallVector<Value> v1;
  for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep)
    for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep)
      for (unsigned b = 0; b < sizePerThread[0]; ++b)
        for (unsigned m = 0; m < sizePerThread[1]; ++m)
          for (unsigned k = 0; k < K; ++k)
            v1.push_back(has.at({bRep, mRep, b, m, k}));
  Value vecA = packLLVector(loc, v1, rewriter);

  // Copy struct into vector for operand B.
  SmallVector<Value> v2;
  for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep)
    for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep)
      for (unsigned b = 0; b < sizePerThread[0]; ++b)
        for (unsigned n = 0; n < sizePerThread[2]; ++n)
          for (unsigned k = 0; k < K; ++k)
            v2.push_back(hbs.at({bRep, nRep, b, n, k}));
  Value vecB = packLLVector(loc, v2, rewriter);

  // Copy struct into vector for operand C.
  Value vecC = packLLVector(loc, acc, rewriter);

  const unsigned len = acc.size();
  Type elemType = acc.front().getType();
  auto builder = TritonLLVMOpBuilder(loc, rewriter);
  Value vecD = builder.undef(vec_ty(elemType, len));

  Value zero = builder.i32_val(0), one = builder.i32_val(1);
  for (unsigned bRep = 0; bRep < repetitions[0]; ++bRep)
    for (unsigned mRep = 0; mRep < repetitions[1]; ++mRep)
      for (unsigned nRep = 0; nRep < repetitions[2]; ++nRep)
        for (unsigned b = 0; b < sizePerThread[0]; ++b) {
          // Generate the outer loop.
          Value outerIV = createIV(zero, rewriter, loc);
          Value outerUB = builder.i32_val(sizePerThread[1]);
          Value outerStep = builder.i32_val(sizePerThread[2]);
          Block *outerBody =
              createEmptyLoop(outerIV, outerUB, outerStep, rewriter, loc);
          // createEmptyLoop leaves the insertion point just after the loop;
          // remember it so the next iteration continues emitting from there.
          auto afterOuterLoop = rewriter.saveInsertionPoint();
          rewriter.setInsertionPointToStart(outerBody);

          // Get the values for operand A.
          SmallVector<Value> AElems;
          for (unsigned i = 0; i < sizePerThread[2]; ++i) {
            Value idx = builder.add(outerIV, builder.i32_val(i));
            AElems.push_back(builder.extract_element(vecA, idx));
          }

          // Generate the inner loop. It must be driven by its own induction
          // variable; previously outerIV was passed here by mistake, which
          // left innerIV out of the loop control entirely.
          Value innerIV = createIV(zero, rewriter, loc);
          Value innerUB = outerStep;
          Value innerStep = one;
          Block *innerBody =
              createEmptyLoop(innerIV, innerUB, innerStep, rewriter, loc);
          rewriter.setInsertionPointToStart(innerBody);

          // Get the values for operand B.
          SmallVector<Value> BElems;
          for (unsigned j = 0; j < sizePerThread[2]; ++j) {
            Value idx =
                builder.add(innerIV, builder.i32_val(sizePerThread[2] * j));
            BElems.push_back(builder.extract_element(vecB, idx));
          }

          // Get the value for operand C.
          // TODO: generate FMA for integer here.
          // NOTE(review): the operands of this fma are i32 index values;
          // confirm the builder's fma accepts integers or replace it with a
          // mul/add pair.
          Value accIdx = builder.fma(innerUB, outerIV, innerIV);
          // Named accElem so it does not shadow the `acc` parameter.
          Value accElem = builder.extract_element(vecC, accIdx);

          // Perform the FMAs.
          for (unsigned k = 0; k < sizePerThread[2]; ++k) {
            TypeSwitch<Type>(elemType)
                .Case<FloatType>([&](auto) {
                  accElem = rewriter.create<LLVM::FMulAddOp>(loc, AElems[k],
                                                             BElems[k], accElem);
                })
                .Case<IntegerType>([&](auto) {
                  accElem = builder.fma(AElems[k], BElems[k], accElem);
                });
          }

          // Store the result.
          // NOTE(review): the value produced by insert_element is discarded,
          // so vecD as read by the copy loop below is still the undef vector.
          // Carrying the updated vector out of the loops needs loop-carried
          // values (or a memory slot); confirm the intended design.
          builder.insert_element(vecD, accElem, accIdx);
          rewriter.restoreInsertionPoint(afterOuterLoop);
        }

  // Create a loop to copy vecD into a struct.
  Value ub = builder.i32_val(len);
  auto structPtr =
      rewriter.create<LLVM::AllocaOp>(loc, ptr_ty(ctx), elemType, ub);
  Value iv = createIV(zero, rewriter, loc);
  Block *body = createEmptyLoop(iv, ub, one, rewriter, loc);
  auto afterLoop = rewriter.saveInsertionPoint();
  rewriter.setInsertionPointToStart(body);
  Value val = builder.extract_element(vecD, iv);
  Value ptr = builder.gep(ptr_ty(ctx), val.getType(), structPtr, iv);
  rewriter.create<LLVM::StoreOp>(loc, val, ptr);
  // Continue after the copy loop: load the whole buffer as the result type.
  rewriter.restoreInsertionPoint(afterLoop);
  auto loadVal = rewriter.create<LLVM::LoadOp>(loc, dType, structPtr);
  rewriter.replaceOp(op, loadVal);

  return success();
}

} // namespace mlir::triton::gpu
Loading