
Commit 61d3ac6

Move definitions from Conversion/TritonGPUToLLVM/Utility.h to a source file (#6831)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 1572e34 commit 61d3ac6

File tree: 2 files changed (+293, −248 lines)


include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 24 additions & 248 deletions
@@ -338,46 +338,21 @@ using namespace mlir::triton;
 
 class SharedMemoryObject {
 public:
-  SharedMemoryObject(Value base, Type baseElemType, ArrayRef<Value> offsets)
-      : base(base), baseElemType(baseElemType),
-        offsets(offsets.begin(), offsets.end()) {}
+  SharedMemoryObject(Value base, Type baseElemType, ArrayRef<Value> offsets);
 
   SharedMemoryObject(Value base, Type baseElemType, int64_t rank, Location loc,
-                     RewriterBase &rewriter)
-      : base(base), baseElemType(baseElemType) {
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-    offsets.append(rank, b.i32_val(0));
-  }
+                     RewriterBase &rewriter);
 
   SmallVector<Value> getOffsets() const { return offsets; }
   Value getBase() const { return base; }
   Type getBaseElemType() const { return baseElemType; }
 
-  SmallVector<Value> getElems() const {
-    SmallVector<Value> elems;
-    elems.push_back(base);
-    elems.append(offsets.begin(), offsets.end());
-    return elems;
-  }
+  SmallVector<Value> getElems() const;
 
-  SmallVector<Type> getTypes() const {
-    SmallVector<Type> types;
-    types.push_back(base.getType());
-    types.append(offsets.size(), IntegerType::get(base.getContext(), 32));
-    return types;
-  }
+  SmallVector<Type> getTypes() const;
 
   SmallVector<Value> getStrides(triton::gpu::MemDescType memDesc, Location loc,
-                                RewriterBase &rewriter) const {
-    auto allocShape = memDesc.getAllocShape();
-    auto allocShapePerCTA = triton::gpu::getAllocationShapePerCTA(
-        memDesc.getEncoding(), allocShape);
-    auto layoutOrder = triton::gpu::getOrder(memDesc);
-    auto allocStrides = SharedMemoryObject::getStridesForShape(
-        allocShapePerCTA, layoutOrder, loc, rewriter);
-    return SmallVector<Value>(allocStrides.end() - offsets.size(),
-                              allocStrides.end());
-  }
+                                RewriterBase &rewriter) const;
 
   // TODO(Keren): deprecate the method once AMD backend has cleaned up
   Value getCSwizzleOffset(int dim) const {
@@ -386,50 +361,16 @@ class SharedMemoryObject {
   }
 
   // TODO(Keren): deprecate the method once AMD backend has cleaned up
-  Value getBaseBeforeSlice(int dim, Location loc,
-                           RewriterBase &rewriter) const {
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-    Value cSwizzleOffset = getCSwizzleOffset(dim);
-    Value offset = b.sub(b.i32_val(0), cSwizzleOffset);
-    Type type = base.getType();
-    return b.gep(type, baseElemType, base, offset);
-  }
+  Value getBaseBeforeSlice(int dim, Location loc, RewriterBase &rewriter) const;
 
 private:
-  static SmallVector<unsigned>
-  getOrderForShape(ArrayRef<int64_t> shape, ArrayRef<unsigned> layoutOrder) {
-    SmallVector<unsigned> order(shape.size());
-    // Default minor-to-major order
-    std::iota(order.rbegin(), order.rend(), 0);
-    if (layoutOrder.size() > 0) {
-      // If a layout order is provided, we assume it specifies the order in
-      // which the dimensions are first accessed, and unspecified dimensions
-      // retain the minor-to-major order. For example, if order = [2, 1, 0] and
-      // layoutOrder = [0, 1], we need to shift `layoutOrder`
-      // by -1 (move them right). The resulting order will then be [1, 2, 0].
-      int rankDiff = layoutOrder.size() - shape.size();
-      auto minRank = std::min<size_t>(shape.size(), layoutOrder.size());
-      for (size_t i = 0; i < minRank; ++i)
-        order[i] = layoutOrder[i] - rankDiff;
-    }
-    assert(isPermutationOfIota(order) && "Invalid order");
-    return order;
-  }
+  static SmallVector<unsigned> getOrderForShape(ArrayRef<int64_t> shape,
+                                                ArrayRef<unsigned> layoutOrder);
 
   static SmallVector<Value> getStridesForShape(ArrayRef<int64_t> shape,
                                                ArrayRef<unsigned> layoutOrder,
                                                Location loc,
-                                               RewriterBase &rewriter) {
-    SmallVector<Value> strides(shape.size());
-    auto order = SharedMemoryObject::getOrderForShape(shape, layoutOrder);
-    int64_t stride = 1;
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-    for (auto idx : order) {
-      strides[idx] = b.i32_val(stride);
-      stride *= shape[idx];
-    }
-    return strides;
-  }
+                                               RewriterBase &rewriter);
 
   Value base; // i32 ptr. The start address of the shared memory object.
   Type baseElemType;
@@ -486,97 +427,14 @@ inline bool isKernel(FunctionOpInterface funcOp) {
   return funcOp.getVisibility() == SymbolTable::Visibility::Public;
 }
 
-inline Value getStackPointer(RewriterBase &rewriter,
-                             FunctionOpInterface funcOp) {
-  // See NOTE: [Additional Function Arguments]
-  if (!isKernel(funcOp)) {
-    return funcOp.getArgument(funcOp.getNumArguments() - 2);
-  }
-
-  auto mod = funcOp->getParentOfType<ModuleOp>();
-  auto globalBase = dyn_cast<LLVM::GlobalOp>(mod.lookupSymbol("global_smem"));
-  assert(globalBase);
-  return rewriter.create<LLVM::AddressOfOp>(funcOp.getLoc(), globalBase);
-}
-
-inline Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
-                                 const TargetInfoBase &targetInfo,
-                                 FunctionOpInterface funcOp,
-                                 Value allocOffset = {}) {
-  // See NOTE: [Additional Function Arguments]
-  if (!isKernel(funcOp)) {
-    // Base for this function
-    auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
-    if (!allocOffset) {
-      return gmemBase;
-    }
-
-    auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-    return b.gep(ptrTy, i8_ty, gmemBase, allocOffset);
-  }
-
-  // Base for entire kernel
-  auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
-
-  ModuleOp mod = funcOp.getOperation()->getParentOfType<ModuleOp>();
-  auto allocSizeAttr = mod.getOperation()->getAttrOfType<mlir::IntegerAttr>(
-      "ttg.global_scratch_memory_size");
-  if (!allocSizeAttr) {
-    return gmemBase;
-  }
+Value getStackPointer(RewriterBase &rewriter, FunctionOpInterface funcOp);
 
-  Value gridIdx[3];
-  Value gridDim[2];
-  for (int k = 0; k < 3; ++k) {
-    gridIdx[k] = rewriter.create<GetProgramIdOp>(loc, k);
-  }
-  for (int k = 0; k < 2; ++k) {
-    gridDim[k] = rewriter.create<GetNumProgramsOp>(loc, k);
-  }
+Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
+                          const TargetInfoBase &targetInfo,
+                          FunctionOpInterface funcOp, Value allocOffset);
 
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  Value linearId = gridIdx[2];
-  for (int k = 0; k < 2; ++k) {
-    linearId = b.add(gridIdx[1 - k], b.mul(linearId, gridDim[1 - k]));
-  }
-  auto numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
-  if (numCTAs > 1) {
-    linearId = b.mul(linearId, b.i32_val(numCTAs));
-    linearId = b.add(linearId, targetInfo.getClusterCTAId(rewriter, loc));
-  }
-
-  auto allocSize = allocSizeAttr.getValue().getZExtValue();
-
-  Value offset = b.mul(linearId, b.i32_val(allocSize));
-  if (allocOffset) {
-    offset = b.add(offset, allocOffset);
-  }
-
-  auto *ctx = rewriter.getContext();
-  auto res =
-      b.gep(mlir::LLVM::LLVMPointerType::get(ctx, 1), i8_ty, gmemBase, offset);
-  return res;
-}
-
-inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
-                                 const TargetInfoBase &target, Operation *op) {
-  auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(),
-                                          target.getSharedAddressSpace());
-  auto func = op->template getParentOfType<FunctionOpInterface>();
-  if (!func)
-    func = cast<FunctionOpInterface>(op);
-
-  assert(op->hasAttr("allocation.offset"));
-  size_t offset = cast<IntegerAttr>(op->getAttr("allocation.offset"))
-                      .getValue()
-                      .getZExtValue();
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  Value offVal = b.i32_val(offset);
-  Value base =
-      b.gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal);
-  return base;
-}
+Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
+                          const TargetInfoBase &target, Operation *op);
 
 // -----------------------------------------------------------------------
 // MXFP utilities
@@ -619,16 +477,8 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
 using ::mlir::triton::gpu::SliceEncodingAttr;
 
-inline Value dot(RewriterBase &rewriter, Location loc, ArrayRef<Value> offsets,
-                 ArrayRef<Value> strides) {
-  assert(offsets.size() == strides.size());
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  Value ret = b.i32_val(0);
-  for (auto [offset, stride] : llvm::zip(offsets, strides)) {
-    ret = b.add(ret, b.mul(offset, stride));
-  }
-  return ret;
-}
+Value dot(RewriterBase &rewriter, Location loc, ArrayRef<Value> offsets,
+          ArrayRef<Value> strides);
 
 /// Extend 2d shared object to 3d.
 ///
@@ -720,91 +570,17 @@ SmallVector<Value> unpackLLVector(Location loc, Value llvmVec,
 
 Value packLLVector(Location loc, ValueRange vals, RewriterBase &rewriter);
 
-inline std::optional<LLVM::AtomicBinOp> matchAtomicOp(RMWOp atomicOp) {
-  switch (atomicOp) {
-  case RMWOp::AND:
-    return LLVM::AtomicBinOp::_and;
-  case RMWOp::OR:
-    return LLVM::AtomicBinOp::_or;
-  case RMWOp::XOR:
-    return LLVM::AtomicBinOp::_xor;
-  case RMWOp::ADD:
-    return LLVM::AtomicBinOp::add;
-  case RMWOp::FADD:
-    return LLVM::AtomicBinOp::fadd;
-  case RMWOp::MAX:
-    return LLVM::AtomicBinOp::max;
-  case RMWOp::MIN:
-    return LLVM::AtomicBinOp::min;
-  case RMWOp::UMAX:
-    return LLVM::AtomicBinOp::umax;
-  case RMWOp::UMIN:
-    return LLVM::AtomicBinOp::umin;
-  case RMWOp::XCHG:
-    return LLVM::AtomicBinOp::xchg;
-  default:
-    return {};
-  }
-}
+std::optional<LLVM::AtomicBinOp> matchAtomicOp(RMWOp atomicOp);
 
-inline std::optional<LLVM::AtomicOrdering>
-getMemoryOrdering(MemSemantic memOrdering) {
-  switch (memOrdering) {
-  case MemSemantic::RELAXED:
-    return LLVM::AtomicOrdering::monotonic;
-  case MemSemantic::ACQUIRE:
-    return LLVM::AtomicOrdering::acquire;
-  case MemSemantic::RELEASE:
-    return LLVM::AtomicOrdering::release;
-  case MemSemantic::ACQUIRE_RELEASE:
-    return LLVM::AtomicOrdering::acq_rel;
-  default:
-    return {};
-  }
-}
+std::optional<LLVM::AtomicOrdering> getMemoryOrdering(MemSemantic memOrdering);
 
-inline bool
-isSimpleSharedMemoryAccess(ArrayRef<int64_t> shape,
-                           ArrayRef<int64_t> allocShape,
-                           triton::gpu::SharedEncodingTrait sharedEnc) {
-  auto rank = shape.size();
-  auto swizzledLayout =
-      dyn_cast<triton::gpu::SwizzledSharedEncodingAttr>(sharedEnc);
-  auto nvmmaLayout = dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(sharedEnc);
-  bool noSwizzling = (swizzledLayout && swizzledLayout.getMaxPhase() == 1) ||
-                     (nvmmaLayout && nvmmaLayout.getSwizzlingByteWidth() == 0);
-  return /*no swizzling*/ noSwizzling ||
-         /*swizzling but same shape*/ shape == allocShape ||
-         /*swizzling and rank-reduced and rank >= 2*/
-         (shape == allocShape.take_back(rank) && rank >= 2);
-}
+bool isSimpleSharedMemoryAccess(ArrayRef<int64_t> shape,
+                                ArrayRef<int64_t> allocShape,
+                                triton::gpu::SharedEncodingTrait sharedEnc);
 
-inline llvm::MapVector<StringAttr, int32_t>
-getAllFreeVarMasks(MLIRContext *ctx) {
-  // Mask where all elements are redundant
-  auto kReg = str_attr("reg");
-  auto kLane = str_attr("lane");
-  auto kWarp = str_attr("warp");
-  auto kBlock = str_attr("block");
+llvm::MapVector<StringAttr, int32_t> getAllFreeVarMasks(MLIRContext *ctx);
 
-  int32_t fullMask = -1;
-  llvm::MapVector<StringAttr, int32_t> ret;
-  for (auto dimName : {kReg, kLane, kWarp, kBlock}) {
-    ret[dimName] = fullMask;
-  }
-  return ret;
-}
-
-inline llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks(Type type) {
-  auto ctx = type.getContext();
-  auto tensorTy = dyn_cast<RankedTensorType>(type);
-  if (!tensorTy) {
-    return getAllFreeVarMasks(ctx);
-  }
-  auto ll =
-      triton::gpu::toLinearLayout(tensorTy.getShape(), tensorTy.getEncoding());
-  return ll.getFreeVariableMasks();
-}
+llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks(Type type);
 
 inline bool isCanonicalIndex(unsigned index, unsigned freeVarMask) {
   return (index & freeVarMask) == 0;
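
The inline bodies deleted above do not disappear: per the commit message they move into the commit's second changed file, a source file that is not shown in this excerpt. Below is a minimal sketch of what those out-of-line definitions could look like in such a source file, reusing two of the removed bodies verbatim; the file path, include line, and namespace layout are assumptions, not part of the commit shown here.

// Hypothetical companion source file for this commit (path assumed).
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

// ...reopen the same namespaces that enclose the declarations in Utility.h...

// Constructor moved out of line: capture the base pointer, its element type,
// and the per-dimension offsets (body taken from the lines removed above).
SharedMemoryObject::SharedMemoryObject(Value base, Type baseElemType,
                                       ArrayRef<Value> offsets)
    : base(base), baseElemType(baseElemType),
      offsets(offsets.begin(), offsets.end()) {}

// getElems(): the flattened form is the base pointer followed by the offsets.
SmallVector<Value> SharedMemoryObject::getElems() const {
  SmallVector<Value> elems;
  elems.push_back(base);
  elems.append(offsets.begin(), offsets.end());
  return elems;
}

// dot(): accumulate sum(offsets[i] * strides[i]) as an i32 value using the
// TritonLLVMOpBuilder helpers (body taken from the lines removed above).
Value dot(RewriterBase &rewriter, Location loc, ArrayRef<Value> offsets,
          ArrayRef<Value> strides) {
  assert(offsets.size() == strides.size());
  auto b = TritonLLVMOpBuilder(loc, rewriter);
  Value ret = b.i32_val(0);
  for (auto [offset, stride] : llvm::zip(offsets, strides)) {
    ret = b.add(ret, b.mul(offset, stride));
  }
  return ret;
}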
