Skip to content

Commit fdd694d

Browse files
authored
[Triton][Gluon] Add map_elementwise (#7564)
Dug this one up from the archive. I originally wrote this so that you could use control flow in elementwise computation, so e.g. if computing a multi-branch function you don't need to compute all branches and combine them with `where`. For our purposes though, it also has the effect of changing the order in which we emit the LLVM IR, so that each element in a tensor is processed one at a time instead of all at once.
1 parent 03cdcdb commit fdd694d

File tree

15 files changed

+535
-44
lines changed

15 files changed

+535
-44
lines changed

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,19 @@ Value transferWithinBlockPadding(triton::gpu::ConvertLayoutOp op, Value src,
642642
const TargetInfoBase &targetInfo,
643643
const LLVMTypeConverter *typeConverter,
644644
RewriterBase &rewriter);
645+
646+
/// Inlines a (possibly multi-block) `region` at the rewriter's current
/// insertion point.
///
/// The current block is split at the insertion point, a clone of `region` is
/// placed between the two halves, and control branches into the cloned entry
/// block with `args` as its block arguments. Every cloned terminator whose
/// operation TypeID equals `terminatorTypeId` is replaced by a branch to the
/// continuation block.
///
/// \returns the block arguments of the continuation block, one per operand
///          of the replaced terminator.
SmallVector<Value> inlineRegionImpl(RewriterBase &rewriter, Region &region,
                                    ArrayRef<Value> args,
                                    mlir::TypeID terminatorTypeId,
                                    Location loc);

/// Typed convenience wrapper around `inlineRegionImpl`: `TerminatorOp` names
/// the terminator operation class that marks "return from the region".
template <typename TerminatorOp>
SmallVector<Value> inlineRegion(RewriterBase &rewriter, Region &region,
                                ArrayRef<Value> args, Location loc) {
  const mlir::TypeID terminatorId = mlir::TypeID::get<TerminatorOp>();
  return inlineRegionImpl(rewriter, region, args, terminatorId, loc);
}
657+
645658
} // namespace mlir
646659

647660
#endif

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,26 @@ def TT_ScanReturnOp: TT_Op<"scan.return",
797797
let assemblyFormat = "$result attr-dict `:` type($result)";
798798
}
799799

800+
//
801+
// Map Elementwise op
802+
//
803+
def TT_MapElementwiseOp: TT_Op<"map_elementwise", [SameOperandsAndResultEncoding,
                                                  SameOperandsAndResultShape,
                                                  RecursiveMemoryEffects]> {
  let summary = "Map a scalar subregion over a tensor";
  let description = [{
    Applies the scalar region `scalarOp` to the elements of the `srcs`
    tensors, `pack` elements at a time.

    The entry block of the region takes `pack` scalar arguments for each
    operand, grouped by operand: first the `pack` values of `srcs[0]`, then
    the `pack` values of `srcs[1]`, and so on. Each argument has the element
    type of the corresponding operand. The region is terminated by
    `tt.map_elementwise.return`, which yields `pack` scalars per result,
    grouped by result in the same way.

    Constraints checked by the verifier:
      * at least one operand is required;
      * `pack` must be a power of two;
      * operations with a write memory effect are not allowed in the region.

    When lowering to LLVM, `pack` must additionally divide the number of
    elements each thread holds.
  }];
  let arguments = (ins Variadic<TT_Tensor>:$srcs, I32Attr:$pack);
  let results = (outs Variadic<TT_Tensor>:$result);
  let regions = (region AnyRegion:$scalarOp);
  let hasVerifier = 1;
  let hasRegionVerifier = 1;
}
813+
814+
def TT_MapElementwiseReturnOp: TT_Op<"map_elementwise.return",
                                     [HasParent<"MapElementwiseOp">, Pure, Terminator, ReturnLike]> {
  let summary = "terminator for map elementwise operator";
  let description = [{
    Terminates a block in the region of `tt.map_elementwise` and yields the
    scalar values computed for the current pack of elements. The enclosing
    op's region verifier requires the operand types to be the element type of
    each result repeated `pack` times, grouped by result.
  }];
  let arguments = (ins Variadic<AnyType>:$result);
  let assemblyFormat = "attr-dict ($result^ `:` type($result))?";
}
800820

801821
//
802822
// External Elementwise op

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,77 @@ struct ClampFOpConversion
571571
const TargetInfoBase &targetInfo;
572572
};
573573

574+
struct MapElementwiseOpConversion
575+
: public ConvertOpToLLVMPattern<MapElementwiseOp> {
576+
using Base = ConvertOpToLLVMPattern<MapElementwiseOp>;
577+
using Adaptor = typename Base::OpAdaptor;
578+
579+
using Base::Base;
580+
581+
LogicalResult matchAndRewrite(MapElementwiseOp op, OpAdaptor adaptor,
582+
ConversionPatternRewriter &rewriter) const {
583+
Location loc = op->getLoc();
584+
auto typeConverter = getTypeConverter();
585+
586+
auto operands = adaptor.getOperands();
587+
const auto nOperands = operands.size();
588+
const auto nElems =
589+
cast<LLVM::LLVMStructType>(operands[0].getType()).getBody().size();
590+
const auto nElemsPerPack = op.getPack();
591+
if (nElems % nElemsPerPack != 0)
592+
return op->emitError()
593+
<< "pack size must be a divisor of the number of elements per "
594+
"thread, but got pack = "
595+
<< nElemsPerPack << ", elements per thread = " << nElems << "\n";
596+
597+
const auto nPacks = nElems / nElemsPerPack;
598+
auto nArgsUnpacked = nElemsPerPack * nOperands;
599+
600+
SmallVector<Value> scalarOperands(nOperands * nElems);
601+
for (auto iOp : llvm::seq(nOperands)) {
602+
auto elems = unpackLLElements(loc, operands[iOp], rewriter);
603+
assert(elems.size() == nElems);
604+
for (auto iPack : llvm::seq(nPacks)) {
605+
auto *packOperands =
606+
&scalarOperands[iPack * nArgsUnpacked + iOp * nElemsPerPack];
607+
auto *packElems = &elems[iPack * nElemsPerPack];
608+
for (auto iElem : llvm::seq(nElemsPerPack)) {
609+
packOperands[iElem] = packElems[iElem];
610+
}
611+
}
612+
}
613+
614+
auto &scalarOp = op.getScalarOp();
615+
Region &parent = *rewriter.getBlock()->getParent();
616+
617+
auto nOutputs = op.getNumResults();
618+
SmallVector<Value> scalarOutputs(nOutputs * nElems);
619+
for (auto iPack : llvm::seq(nPacks)) {
620+
ArrayRef<Value> packedArgs(&scalarOperands[iPack * nArgsUnpacked],
621+
nArgsUnpacked);
622+
auto packResults = inlineRegion<triton::MapElementwiseReturnOp>(
623+
rewriter, scalarOp, packedArgs, loc);
624+
assert(packResults.size() == nOutputs * nElemsPerPack);
625+
for (auto iOut : llvm::seq(nOutputs)) {
626+
auto *packOutputs =
627+
&scalarOutputs[iOut * nElems + iPack * nElemsPerPack];
628+
for (auto iElem : llvm::seq(nElemsPerPack)) {
629+
packOutputs[iElem] = packResults[iOut * nElemsPerPack + iElem];
630+
}
631+
}
632+
}
633+
634+
SmallVector<Value> packedOutputs(nOutputs);
635+
for (auto iOut : llvm::seq(nOutputs)) {
636+
ArrayRef<Value> vals(&scalarOutputs[iOut * nElems], nElems);
637+
packedOutputs[iOut] =
638+
packLLElements(loc, typeConverter, vals, rewriter, op.getType(iOut));
639+
}
640+
rewriter.replaceOp(op, packedOutputs);
641+
return success();
642+
}
643+
};
644+
574645
} // namespace
575646

576647
void mlir::triton::populateMinMaxFOpToLLVMPattern(
@@ -662,4 +733,5 @@ void mlir::triton::populateElementwiseOpToLLVMPatterns(
662733
patterns.add<AbsIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
663734
patterns.add<AbsFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
664735
patterns.add<SelectOpConversion>(typeConverter, axisInfoAnalysis, benefit);
736+
patterns.add<MapElementwiseOpConversion>(typeConverter, benefit);
665737
}

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1796,4 +1796,63 @@ Value transferWithinBlockPadding(triton::gpu::ConvertLayoutOp op, Value src,
17961796
return result;
17971797
}
17981798

1799+
/// Inlines a clone of `region` at the rewriter's current insertion point.
///
/// \param rewriter          rewriter whose insertion point marks the split.
/// \param region            region to clone; must contain at least one block.
/// \param args              values forwarded to the cloned entry block.
/// \param terminatorTypeId  TypeID of the terminator op that "returns" from
///                          the region; each such terminator becomes a branch
///                          to the continuation block.
/// \param loc               location used for the created branch and for the
///                          continuation block arguments.
/// \returns the arguments added to the continuation block, one per operand of
///          the replaced terminator. Empty if no terminator matched.
///
/// On return the insertion point is at the start of the continuation block.
SmallVector<Value> inlineRegionImpl(RewriterBase &rewriter, Region &region,
                                    ArrayRef<Value> args,
                                    mlir::TypeID terminatorTypeId,
                                    Location loc) {
  // Inline regions with multiple blocks
  //
  //     Before                                      After
  //                                               ┌─────────┐
  //                                               │   op1   │
  //                    ┌──────────┐               │  cf.br  │
  //                    │region[0] │               └────┬────┘
  //                    │cf.cond_br├─┐             ┌────▼─────┐
  //                    └────┬─────┘ │             │region[0] │
  //                         │       │             │cf.cond_br├─┐
  //   ┌───────┐        ┌────▼────┐  │             └────┬─────┘ │
  //   │  op1  │  IP    │region[1]│  │             ┌────▼────┐  │
  //   │       │◄───    │yield ...│  │             │region[1]│  │
  //   │  op2  │        └─────────┘  │           ┌─┤cf.br    │  │
  //   └───────┘                     │           │ └─────────┘  │
  //                    ┌─────────┐  │           │ ┌─────────┐  │
  //                    │region[2]│◄─┘           │ │region[2]│◄─┘
  //                    │yield    │              │ │cf.br    │
  //                    └─────────┘              │ └────┬────┘
  //                                             │ ┌────▼────┐
  //                                             └►│op2      │
  //                                               └─────────┘
  auto *curBlock = rewriter.getInsertionBlock();
  auto opPosition = rewriter.getInsertionPoint();
  auto *remainingOpsBlock = rewriter.splitBlock(curBlock, opPosition);

  IRMapping regionMap;
  Region &parent = *curBlock->getParent();
  rewriter.cloneRegionBefore(region, parent, parent.end(), regionMap);
  rewriter.setInsertionPointToEnd(curBlock);
  rewriter.create<LLVM::BrOp>(loc, args, regionMap.lookup(&region.front()));

  // Result types are copied out *before* the terminator is replaced: a
  // ValueRange over its operands would refer to a replaced (possibly erased)
  // operation by the time the continuation block arguments are created.
  SmallVector<Type> resultTypes;
  for (Block &origBlock : region) {
    Block *newBlock = regionMap.lookup(&origBlock);
    rewriter.moveBlockBefore(newBlock, remainingOpsBlock);

    Operation *terminator = newBlock->getTerminator();
    // OperationName::getTypeID() is safe for unregistered operations, whereas
    // dereferencing getRegisteredInfo() is not.
    if (terminator->getName().getTypeID() != terminatorTypeId)
      continue;

    ValueRange terminatorOperands = terminator->getOperands();
    auto operandTypes = terminatorOperands.getTypes();
    resultTypes.assign(operandTypes.begin(), operandTypes.end());

    // replaceOpWithNewOp creates the new op at the current insertion point.
    rewriter.setInsertionPointAfter(terminator);
    rewriter.replaceOpWithNewOp<LLVM::BrOp>(terminator, terminatorOperands,
                                            remainingOpsBlock);
  }

  rewriter.setInsertionPointToStart(remainingOpsBlock);
  SmallVector<Value> vals;
  vals.reserve(resultTypes.size());
  for (Type resultTy : resultTypes) {
    vals.push_back(remainingOpsBlock->addArgument(resultTy, loc));
  }
  return vals;
}
1857+
17991858
} // namespace mlir

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,32 @@ struct TritonScanPattern : public OpConversionPattern<triton::ScanOp> {
466466
}
467467
};
468468

469+
struct TritonMapElementwisePattern
470+
: public OpConversionPattern<triton::MapElementwiseOp> {
471+
using OpConversionPattern::OpConversionPattern;
472+
473+
LogicalResult
474+
matchAndRewrite(triton::MapElementwiseOp op, OpAdaptor adaptor,
475+
ConversionPatternRewriter &rewriter) const override {
476+
auto converter = getTypeConverter();
477+
SmallVector<Type> resultTys;
478+
auto err = converter->convertTypes(op.getResults().getType(), resultTys);
479+
if (failed(err)) {
480+
return err;
481+
}
482+
483+
auto newMapOp = rewriter.create<triton::MapElementwiseOp>(
484+
op.getLoc(), resultTys, adaptor.getOperands(), op.getPack());
485+
addNamedAttrs(newMapOp, adaptor.getAttributes());
486+
487+
auto &newScalarOp = newMapOp.getScalarOp();
488+
rewriter.cloneRegionBefore(op.getScalarOp(), newScalarOp,
489+
newScalarOp.end());
490+
rewriter.replaceOp(op, newMapOp.getResult());
491+
return success();
492+
}
493+
};
494+
469495
class TritonFuncOpPattern : public OpConversionPattern<triton::FuncOp> {
470496
public:
471497
using OpConversionPattern::OpConversionPattern;
@@ -548,6 +574,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
548574
TritonExpandDimsPattern,
549575
TritonTransPattern,
550576
TritonDotPattern,
577+
TritonMapElementwisePattern,
551578
GatherScatterOpPattern<DescriptorGatherOp>,
552579
GatherScatterOpPattern<DescriptorScatterOp>,
553580
GenericOpPattern<triton::LoadOp>,

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "triton/Dialect/Triton/IR/Types.h"
1010
#include "triton/Dialect/Triton/IR/Utility.h"
1111
#include "llvm/Support/ErrorHandling.h"
12+
#include "llvm/Support/MathExtras.h"
1213

1314
namespace mlir {
1415
namespace triton {
@@ -444,16 +445,9 @@ template <class Op> LogicalResult verifyReduceScan(Op &op) {
444445
return op.emitOpError() << "must have the same number of inputs as outputs";
445446
}
446447

447-
auto getElementType = [](Type ty) {
448-
if (auto tensorType = dyn_cast<RankedTensorType>(ty)) {
449-
return tensorType.getElementType();
450-
}
451-
return ty;
452-
};
453-
454448
for (auto [opElemTy, resTy] :
455449
llvm::zip(op.getElementTypes(), op.getResultTypes())) {
456-
if (opElemTy != getElementType(resTy)) {
450+
if (opElemTy != getElementTypeOrSelf(resTy)) {
457451
return op.emitOpError() << "operand types and result types must agree";
458452
}
459453
}
@@ -517,8 +511,8 @@ getInputTypesImpl(const Operation::operand_range &operands) {
517511
return srcTys;
518512
}
519513

520-
static llvm::SmallVector<Type>
521-
getElementTypesImpl(const Operation::operand_range &operands) {
514+
template <typename ValueRange>
515+
static llvm::SmallVector<Type> getElementTypesImpl(const ValueRange &operands) {
522516
llvm::SmallVector<Type> srcElemTys;
523517
srcElemTys.reserve(operands.size());
524518
for (const auto &op : operands) {
@@ -594,6 +588,59 @@ llvm::SmallVector<Type> ScanOp::getElementTypes() {
594588

595589
unsigned ScanOp::getNumOperands() { return this->getOperands().size(); }
596590

591+
//-- MapElementwiseOp
592+
/// Checks the op-level invariants of `tt.map_elementwise`: it needs at least
/// one operand, and its `pack` attribute must be a power of two.
LogicalResult MapElementwiseOp::verify() {
  const bool hasOperands = !getOperands().empty();
  if (!hasOperands)
    return emitOpError() << "MapElementwiseOp must have at least 1 operand";

  if (llvm::isPowerOf2_32(getPack()))
    return success();
  return emitOpError() << "Pack must be a power of 2";
}
601+
602+
template <typename T>
603+
SmallVector<T> repeatInterleave(const SmallVectorImpl<T> &vs, int nRepeat) {
604+
SmallVector<T> result;
605+
result.reserve(vs.size() * nRepeat);
606+
for (auto v : vs)
607+
for (auto _ : llvm::seq(nRepeat))
608+
result.push_back(v);
609+
return result;
610+
}
611+
612+
/// Verifies the scalar region of `tt.map_elementwise`:
///   * the region is non-empty;
///   * the entry block takes `pack` arguments per operand, grouped by operand,
///     each of the operand's element type;
///   * no operation in the region has a write memory effect;
///   * every `tt.map_elementwise.return` that terminates this op's region
///     yields `pack` values per result, grouped by result, each of the
///     result's element type.
LogicalResult MapElementwiseOp::verifyRegions() {
  // The region constraint is AnyRegion, so a region without blocks gets past
  // the generic verifier; reject it here instead of calling front() on it.
  if (getRegion().empty()) {
    return emitOpError() << "region must contain at least one block";
  }

  // Verify signature
  auto *firstBlock = &getRegion().front();
  if (firstBlock->getNumArguments() != getNumOperands() * getPack()) {
    return emitOpError() << "region has wrong number of arguments";
  }

  auto expectedArgTypes =
      repeatInterleave(getElementTypesImpl(getOperands()), getPack());
  if (firstBlock->getArgumentTypes() != expectedArgTypes) {
    return emitError() << "argument types did not match";
  }
  auto expectedReturnTypes =
      repeatInterleave(getElementTypesImpl(getResults()), getPack());
  auto walkRes = getRegion().walk([&](Operation *op) -> WalkResult {
    auto memEffects = dyn_cast<MemoryEffectOpInterface>(op);
    // Ban stores as we won't get the redundant masking correct by treating it
    // as a scalar.
    if (memEffects && memEffects.hasEffect<MemoryEffects::Write>()) {
      return op->emitOpError()
             << "Stores are not supported inside map_elementwise";
    }
    // The walk also visits nested regions. Only returns that terminate *this*
    // op's region are checked against its result types; a return belonging to
    // a nested map_elementwise is checked by that op's own verifier.
    if (isa<MapElementwiseReturnOp>(op) &&
        op->getParentOp() == getOperation() &&
        op->getOperandTypes() != expectedReturnTypes) {
      return op->emitError()
             << "region return does not match map_elementwise result";
    }
    return WalkResult::advance();
  });
  return success(!walkRes.wasInterrupted());
}
643+
597644
//-- SplatOp --
598645
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
599646
auto value = adaptor.getSrc();

python/src/ir.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1650,6 +1650,15 @@ void init_triton_ir(py::module &&m) {
16501650
}
16511651
return self.create<ScanReturnOp>(return_values);
16521652
})
1653+
.def("create_map_elementwise",
1654+
[](TritonOpBuilder &self, std::vector<Value> inputs,
1655+
std::vector<Type> returnTys, int pack) -> OpState {
1656+
return self.create<MapElementwiseOp>(returnTys, inputs, pack);
1657+
})
1658+
.def("create_map_elementwise_ret",
1659+
[](TritonOpBuilder &self, std::vector<Value> returnVals) -> OpState {
1660+
return self.create<MapElementwiseReturnOp>(returnVals);
1661+
})
16531662
.def("create_ptr_to_int",
16541663
[](TritonOpBuilder &self, Value &val, Type &type) -> Value {
16551664
return self.create<PtrToIntOp>(type, val);

0 commit comments

Comments
 (0)