
Commit b4fe355

Merge commit '620c59165a1452ddd5dd685054b84485a35dc92e'
2 parents: f8ce49d + 620c591

57 files changed, +1726 -2006 lines changed


include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 2 additions & 2 deletions
@@ -25,8 +25,8 @@ class TargetInfoBase {
                               std::optional<Value> ctaId, Value val,
                               Value pred) const = 0;
   virtual Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr,
-                            std::optional<Value> ctaId, Type elemTy,
-                            Value pred) const = 0;
+                            std::optional<Value> ctaId, Type elemTy, Value pred,
+                            Operation *localLoadOp = nullptr) const = 0;
 
   void storeShared(RewriterBase &rewriter, Location loc, Value ptr, Value val,
                    Value pred) const {
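
Note: a minimal sketch of a call site using the new defaulted parameter. The surrounding values (targetInfo, rewriter, loc, ptr, elemTy, pred) are assumed from a typical lowering pattern and are not part of this diff; existing callers keep compiling because the argument defaults to nullptr.

  // Sketch only (assumed surrounding variables from a lowering pattern).
  // Passing nullptr preserves the previous behaviour; a caller lowering a
  // LocalLoadOp can forward the originating op instead.
  Operation *localLoadOp = nullptr;
  Value loaded = targetInfo.loadDShared(rewriter, loc, ptr, /*ctaId=*/std::nullopt,
                                        elemTy, pred, localLoadOp);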

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 35 additions & 24 deletions
@@ -515,10 +515,13 @@ SmallVector<SmallVector<Value>>
 emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
             Attribute layout, RankedTensorType type, bool withCTAOffset);
 
-// Emits the required padding in elements for the given shared memory offset
+// Emits the required padding given shared memory offset
+// - If `offsetInBytes` is true, smemOffset and padding is assumed in bytes.
+// - If false, smemOffset and padding are assumed to be scaled by element
+//   bitwidth, in which case, `bitwidth` is not used.
 Value emitPadding(Location loc, RewriterBase &rewriter,
                   triton::gpu::PaddedSharedEncodingAttr layout,
-                  Value smemOffset);
+                  unsigned bitwidth, Value smemOffset, bool offsetInBytes);
 
 // Emits IR to load data from shared memory into registers, or to store data
 // from registers into shared memory.
@@ -546,39 +549,33 @@ Value emitPadding(Location loc, RewriterBase &rewriter,
                Value laneId, Value warpId,
                std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback);
 
-SmallVector<Value> loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp,
-                                           Type elemLlvmTy,
-                                           const SharedMemoryObject &smemObj,
-                                           Location loc, RewriterBase &rewriter,
-                                           const TargetInfoBase &target);
-
-void storeDistributedToShared(triton::gpu::MemDescType dstTy,
-                              RankedTensorType srcTy, Type elemLlvmTy,
-                              ArrayRef<Value> srcVals,
-                              const SharedMemoryObject &smemObj, Location loc,
-                              RewriterBase &rewriter,
-                              const TargetInfoBase &target);
-
 // Close cousin of lowerLdStMatrix in MemoryOpToLLVM.cpp
 // We might want to merge them at some point, but having to support
 // ldmatrix.trans makes the code in lowerLdStMatrix a bit specific
 // Lowers to st when valArrays is empty, and to ld when it is not,
 // and returns the output values.
-SmallVector<Value>
-lowerLdStShared(Location loc, MLIRContext *ctx, LinearLayout cvt,
-                ArrayRef<Value> valsArray, // Input for store, output for load
-                Type llvmElemTy, Value smemBase, Value affineOffset,
-                uint64_t maskSpanAffineOffset,
-                ConversionPatternRewriter &rewriter,
-                const TargetInfoBase &targetInfo);
+// calcPaddedOffset is a lambda that takes a base offset (mlir::Value)
+// and computes a new offset (mlir::Value) by applying padding based on
+// shared memory layout.
+SmallVector<Value> lowerLdStShared(
+    Location loc, MLIRContext *ctx, LinearLayout cvt,
+    ArrayRef<Value> valsArray, // Input for store, output for load
+    Type llvmElemTy, Value smemBase,
+    std::function<Value(Value)> calcPaddedOffset, Value affineOffset,
+    uint64_t maskSpanAffineOffset, ConversionPatternRewriter &rewriter,
+    const TargetInfoBase &targetInfo, Operation *localLoadOp = nullptr);
 
 // Lower an ld/st-like operation given a layout and a callback that creates the
 // PTX instruction Lowers to st when valArrays is empty, and to ld when it is
 // not, and returns the output values.
+// calcPaddedOffset is a lambda that takes a base offset (mlir::Value)
+// and computes a new offset (mlir::Value) by applying padding based on
+// shared memory layout.
 SmallVector<Value> lowerLdSt(
     Location loc, MLIRContext *ctx, LinearLayout cvt,
     ArrayRef<Value> valsArray, // Input for store, output for load
-    Type llvmElemTy, Value smemBase, Value affineOffset,
+    Type llvmElemTy, Value smemBase,
+    std::function<Value(Value)> calcPaddedOffset, Value affineOffset,
     uint64_t maskSpanAffineOffset, ConversionPatternRewriter &rewriter,
     const TargetInfoBase &targetInfo, std::optional<int> maybeMaxVecElems,
     std::function<SmallVector<Value>(ConversionPatternRewriter &, Location,
@@ -592,7 +589,8 @@ lowerLocalLdSt(Location loc, MLIRContext *ctx,
                ArrayRef<Value> valsArray, // Input for store, empty for load
                Type llvmElemTy, triton::gpu::MemDescType srcTy,
                SharedMemoryObject smemObj, ConversionPatternRewriter &rewriter,
-               const TargetInfoBase &targetInfo);
+               const TargetInfoBase &targetInfo,
+               Operation *localLoadOp = nullptr);
 
 SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
                                     RewriterBase &rewriter);
@@ -644,6 +642,19 @@ Value transferWithinBlockPadding(triton::gpu::ConvertLayoutOp op, Value src,
                                  const TargetInfoBase &targetInfo,
                                  const LLVMTypeConverter *typeConverter,
                                  RewriterBase &rewriter);
+
+SmallVector<Value> inlineRegionImpl(RewriterBase &rewriter, Region &region,
+                                    ArrayRef<Value> args,
+                                    mlir::TypeID terminatorTypeId,
+                                    Location loc);
+
+template <typename TerminatorOp>
+SmallVector<Value> inlineRegion(RewriterBase &rewriter, Region &region,
+                                ArrayRef<Value> args, Location loc) {
+  return inlineRegionImpl(rewriter, region, args,
+                          mlir::TypeID::get<TerminatorOp>(), loc);
+}
+
 } // namespace mlir
 
 #endif
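
Note: a hedged sketch of how a caller might build the new calcPaddedOffset argument for a PaddedSharedEncodingAttr layout, using the emitPadding declaration above. The surrounding values (loc, ctx, rewriter, b as the usual TritonLLVMOpBuilder helper, srcTy as a MemDescType, loadCvt, llvmElemTy, smemBase, affineOffset, maskSpanAffineOffset, targetInfo) are assumed from the calling pattern and are not part of this diff; for layouts without padding an identity lambda is enough, as ConvertLayoutOpToLLVM.cpp does below.

  // Sketch only, under the assumptions stated above.
  std::function<Value(Value)> calcPaddedOffset = [&](Value smemOffset) -> Value {
    auto padded =
        dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(srcTy.getEncoding());
    if (!padded)
      return smemOffset; // identity: nothing to pad
    // Offsets here are in elements, so offsetInBytes = false and bitwidth is
    // unused, per the comment on emitPadding.
    Value padding = emitPadding(loc, rewriter, padded, /*bitwidth=*/0,
                                smemOffset, /*offsetInBytes=*/false);
    return b.add(smemOffset, padding);
  };
  SmallVector<Value> outVals =
      lowerLdStShared(loc, ctx, loadCvt, /*valsArray=*/{}, llvmElemTy, smemBase,
                      calcPaddedOffset, affineOffset, maskSpanAffineOffset,
                      rewriter, targetInfo);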

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 20 additions & 0 deletions
@@ -797,6 +797,26 @@ def TT_ScanReturnOp: TT_Op<"scan.return",
   let assemblyFormat = "$result attr-dict `:` type($result)";
 }
 
+//
+// Map Elementwise op
+//
+def TT_MapElementwiseOp: TT_Op<"map_elementwise", [SameOperandsAndResultEncoding,
+                                                   SameOperandsAndResultShape,
+                                                   RecursiveMemoryEffects]> {
+  let summary = "Map a scalar subregion over a tensor";
+  let arguments = (ins Variadic<TT_Tensor>:$srcs, I32Attr:$pack);
+  let results = (outs Variadic<TT_Tensor>:$result);
+  let regions = (region AnyRegion:$scalarOp);
+  let hasVerifier = 1;
+  let hasRegionVerifier = 1;
+}
+
+def TT_MapElementwiseReturnOp: TT_Op<"map_elementwise.return",
+    [HasParent<"MapElementwiseOp">, Pure, Terminator, ReturnLike]> {
+  let summary = "terminator for map elementwise operator";
+  let arguments = (ins Variadic<AnyType>:$result);
+  let assemblyFormat = "attr-dict ($result^ `:` type($result))?";
+}
 
 //
 // External Elementwise op

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 1 addition & 1 deletion
@@ -370,7 +370,7 @@ When vec=2, elements are swizzled in pairs of 2. In other words, the element at
   let genVerifyDecl = 1;
 }
 
-def PaddeddSharedEncodingAttr
+def PaddedSharedEncodingAttr
     : TritonGPU_Attr<"PaddedSharedEncoding", "padded_shared_encoding",
                      [SharedEncodingTrait, LayoutEncodingTrait]> {
   let mnemonic = "padded_shared";

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 5 additions & 3 deletions
@@ -201,6 +201,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
 
     assert(permutedInVals.size() == tileSize * nReps);
     SmallVector<Value> outVals;
+    auto noPaddingOffset = [](Value v) { return v; };
    auto affineOffset = b.i32_val(0);
    auto maskSpanAffineOffset = 0;
    for (int i = 0; i < nReps; ++i) {
@@ -211,12 +212,13 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
           ArrayRef<Value>(permutedInVals).slice(i * tileSize, tileSize);
       // Store
       lowerLdStShared(loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase,
-                      affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
+                      noPaddingOffset, affineOffset, maskSpanAffineOffset,
+                      rewriter, targetInfo);
       b.barrier();
       // Load
       SmallVector<Value> tileOutVals = lowerLdStShared(
-          loc, ctx, loadCvt, {}, llvmElemTy, smemBase, affineOffset,
-          maskSpanAffineOffset, rewriter, targetInfo);
+          loc, ctx, loadCvt, {}, llvmElemTy, smemBase, noPaddingOffset,
+          affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
       llvm::append_range(outVals, tileOutVals);
     }

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 72 additions & 0 deletions
@@ -571,6 +571,77 @@ struct ClampFOpConversion
   const TargetInfoBase &targetInfo;
 };
 
+struct MapElementwiseOpConversion
+    : public ConvertOpToLLVMPattern<MapElementwiseOp> {
+  using Base = ConvertOpToLLVMPattern<MapElementwiseOp>;
+  using Adaptor = typename Base::OpAdaptor;
+
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(MapElementwiseOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const {
+    Location loc = op->getLoc();
+    auto typeConverter = getTypeConverter();
+
+    auto operands = adaptor.getOperands();
+    const auto nOperands = operands.size();
+    const auto nElems =
+        cast<LLVM::LLVMStructType>(operands[0].getType()).getBody().size();
+    const auto nElemsPerPack = op.getPack();
+    if (nElems % nElemsPerPack != 0)
+      return op->emitError()
+             << "pack size must be a divisor of the number of elements per "
+                "thread, but got pack = "
+             << nElemsPerPack << ", elements per thread = " << nElems << "\n";
+
+    const auto nPacks = nElems / nElemsPerPack;
+    auto nArgsUnpacked = nElemsPerPack * nOperands;
+
+    SmallVector<Value> scalarOperands(nOperands * nElems);
+    for (auto iOp : llvm::seq(nOperands)) {
+      auto elems = unpackLLElements(loc, operands[iOp], rewriter);
+      assert(elems.size() == nElems);
+      for (auto iPack : llvm::seq(nPacks)) {
+        auto *packOperands =
+            &scalarOperands[iPack * nArgsUnpacked + iOp * nElemsPerPack];
+        auto *packElems = &elems[iPack * nElemsPerPack];
+        for (auto iElem : llvm::seq(nElemsPerPack)) {
+          packOperands[iElem] = packElems[iElem];
+        }
+      }
+    }
+
+    auto &scalarOp = op.getScalarOp();
+    Region &parent = *rewriter.getBlock()->getParent();
+
+    auto nOutputs = op.getNumResults();
+    SmallVector<Value> scalarOutputs(nOutputs * nElems);
+    for (auto iPack : llvm::seq(nPacks)) {
+      ArrayRef<Value> packedArgs(&scalarOperands[iPack * nArgsUnpacked],
+                                 nArgsUnpacked);
+      auto packResults = inlineRegion<triton::MapElementwiseReturnOp>(
+          rewriter, scalarOp, packedArgs, loc);
+      assert(packResults.size() == nOutputs * nElemsPerPack);
+      for (auto iOut : llvm::seq(nOutputs)) {
+        auto *packOutputs =
+            &scalarOutputs[iOut * nElems + iPack * nElemsPerPack];
+        for (auto iElem : llvm::seq(nElemsPerPack)) {
+          packOutputs[iElem] = packResults[iOut * nElemsPerPack + iElem];
+        }
+      }
+    }
+
+    SmallVector<Value> packedOutputs(nOutputs);
+    for (auto iOut : llvm::seq(nOutputs)) {
+      ArrayRef<Value> vals(&scalarOutputs[iOut * nElems], nElems);
+      packedOutputs[iOut] =
+          packLLElements(loc, typeConverter, vals, rewriter, op.getType(iOut));
+    }
+    rewriter.replaceOp(op, packedOutputs);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::triton::populateMinMaxFOpToLLVMPattern(
@@ -662,4 +733,5 @@ void mlir::triton::populateElementwiseOpToLLVMPatterns(
   patterns.add<AbsIOpConversion>(typeConverter, axisInfoAnalysis, benefit);
   patterns.add<AbsFOpConversion>(typeConverter, axisInfoAnalysis, benefit);
   patterns.add<SelectOpConversion>(typeConverter, axisInfoAnalysis, benefit);
+  patterns.add<MapElementwiseOpConversion>(typeConverter, benefit);
 }
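
Note: the pack/unpack index arithmetic in MapElementwiseOpConversion is easiest to see with a small worked example. The standalone sketch below uses assumed toy values (nOperands = 2, pack = 2, nElems = 4 per thread, no MLIR dependency) and mirrors the scalarOperands indexing above: the scalar region is invoked nPacks = 2 times, and each invocation receives a pack-sized slice of every operand, back to back.

  // Standalone sketch of the pack layout used above (assumed toy values).
  #include <cstdio>
  #include <utility>
  #include <vector>

  int main() {
    const int nOperands = 2, nElems = 4, nElemsPerPack = 2;
    const int nPacks = nElems / nElemsPerPack;           // 2 region invocations
    const int nArgsUnpacked = nElemsPerPack * nOperands; // 4 args per invocation

    // scalarOperands[iPack * nArgsUnpacked + iOp * nElemsPerPack + iElem]
    std::vector<std::pair<int, int>> scalarOperands(nOperands * nElems);
    for (int iOp = 0; iOp < nOperands; ++iOp)
      for (int iPack = 0; iPack < nPacks; ++iPack)
        for (int iElem = 0; iElem < nElemsPerPack; ++iElem)
          scalarOperands[iPack * nArgsUnpacked + iOp * nElemsPerPack + iElem] = {
              iOp, iPack * nElemsPerPack + iElem}; // (operand, element index)

    // Expected: call 0 sees op0[0] op0[1] op1[0] op1[1]; call 1 sees the rest.
    for (int iPack = 0; iPack < nPacks; ++iPack) {
      std::printf("call %d:", iPack);
      for (int i = 0; i < nArgsUnpacked; ++i) {
        auto [op, elem] = scalarOperands[iPack * nArgsUnpacked + i];
        std::printf(" op%d[%d]", op, elem);
      }
      std::printf("\n");
    }
    return 0;
  }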
