@@ -338,46 +338,21 @@ using namespace mlir::triton;
 
 class SharedMemoryObject {
 public:
-  SharedMemoryObject(Value base, Type baseElemType, ArrayRef<Value> offsets)
-      : base(base), baseElemType(baseElemType),
-        offsets(offsets.begin(), offsets.end()) {}
+  SharedMemoryObject(Value base, Type baseElemType, ArrayRef<Value> offsets);
 
   SharedMemoryObject(Value base, Type baseElemType, int64_t rank, Location loc,
-                     RewriterBase &rewriter)
-      : base(base), baseElemType(baseElemType) {
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-    offsets.append(rank, b.i32_val(0));
-  }
+                     RewriterBase &rewriter);
 
   SmallVector<Value> getOffsets() const { return offsets; }
   Value getBase() const { return base; }
   Type getBaseElemType() const { return baseElemType; }
 
-  SmallVector<Value> getElems() const {
-    SmallVector<Value> elems;
-    elems.push_back(base);
-    elems.append(offsets.begin(), offsets.end());
-    return elems;
-  }
+  SmallVector<Value> getElems() const;
 
-  SmallVector<Type> getTypes() const {
-    SmallVector<Type> types;
-    types.push_back(base.getType());
-    types.append(offsets.size(), IntegerType::get(base.getContext(), 32));
-    return types;
-  }
+  SmallVector<Type> getTypes() const;
 
   SmallVector<Value> getStrides(triton::gpu::MemDescType memDesc, Location loc,
-                                RewriterBase &rewriter) const {
-    auto allocShape = memDesc.getAllocShape();
-    auto allocShapePerCTA = triton::gpu::getAllocationShapePerCTA(
-        memDesc.getEncoding(), allocShape);
-    auto layoutOrder = triton::gpu::getOrder(memDesc);
-    auto allocStrides = SharedMemoryObject::getStridesForShape(
-        allocShapePerCTA, layoutOrder, loc, rewriter);
-    return SmallVector<Value>(allocStrides.end() - offsets.size(),
-                              allocStrides.end());
-  }
+                                RewriterBase &rewriter) const;
 
   // TODO(Keren): deprecate the method once AMD backend has cleaned up
   Value getCSwizzleOffset(int dim) const {
@@ -386,50 +361,16 @@ class SharedMemoryObject {
   }
 
   // TODO(Keren): deprecate the method once AMD backend has cleaned up
-  Value getBaseBeforeSlice(int dim, Location loc,
-                           RewriterBase &rewriter) const {
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-    Value cSwizzleOffset = getCSwizzleOffset(dim);
-    Value offset = b.sub(b.i32_val(0), cSwizzleOffset);
-    Type type = base.getType();
-    return b.gep(type, baseElemType, base, offset);
-  }
+  Value getBaseBeforeSlice(int dim, Location loc, RewriterBase &rewriter) const;
 
 private:
-  static SmallVector<unsigned>
-  getOrderForShape(ArrayRef<int64_t> shape, ArrayRef<unsigned> layoutOrder) {
-    SmallVector<unsigned> order(shape.size());
-    // Default minor-to-major order
-    std::iota(order.rbegin(), order.rend(), 0);
-    if (layoutOrder.size() > 0) {
-      // If a layout order is provided, we assume it specifies the order in
-      // which the dimensions are first accessed, and unspecified dimensions
-      // retain the minor-to-major order. For example, if order = [2, 1, 0] and
-      // layoutOrder = [0, 1], we need to shift `layoutOrder`
-      // by -1 (move them right). The resulting order will then be [1, 2, 0].
-      int rankDiff = layoutOrder.size() - shape.size();
-      auto minRank = std::min<size_t>(shape.size(), layoutOrder.size());
-      for (size_t i = 0; i < minRank; ++i)
-        order[i] = layoutOrder[i] - rankDiff;
-    }
-    assert(isPermutationOfIota(order) && "Invalid order");
-    return order;
-  }
+  static SmallVector<unsigned> getOrderForShape(ArrayRef<int64_t> shape,
+                                                ArrayRef<unsigned> layoutOrder);
 
   static SmallVector<Value> getStridesForShape(ArrayRef<int64_t> shape,
                                                ArrayRef<unsigned> layoutOrder,
                                                Location loc,
-                                               RewriterBase &rewriter) {
-    SmallVector<Value> strides(shape.size());
-    auto order = SharedMemoryObject::getOrderForShape(shape, layoutOrder);
-    int64_t stride = 1;
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-    for (auto idx : order) {
-      strides[idx] = b.i32_val(stride);
-      stride *= shape[idx];
-    }
-    return strides;
-  }
+                                               RewriterBase &rewriter);
 
   Value base; // i32 ptr. The start address of the shared memory object.
   Type baseElemType;
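The order/stride logic moving out of this header is easier to follow with plain integers. Below is a minimal sketch of the same arithmetic as `getOrderForShape` and `getStridesForShape`, with `int64_t` strides standing in for the i32 SSA values built through `TritonLLVMOpBuilder`; the free-function names are illustrative, not Triton's.

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Mirrors getOrderForShape: start from minor-to-major order, then overlay the
// layout order shifted by the rank difference.
std::vector<unsigned> orderForShape(const std::vector<int64_t> &shape,
                                    const std::vector<unsigned> &layoutOrder) {
  std::vector<unsigned> order(shape.size());
  std::iota(order.rbegin(), order.rend(), 0); // rank 3 -> [2, 1, 0]
  if (!layoutOrder.empty()) {
    int rankDiff = (int)layoutOrder.size() - (int)shape.size();
    size_t minRank = std::min(shape.size(), layoutOrder.size());
    for (size_t i = 0; i < minRank; ++i)
      order[i] = layoutOrder[i] - rankDiff;
  }
  return order;
}

// Mirrors getStridesForShape: walk dimensions fastest-to-slowest and assign
// the running product of dimension sizes as each stride.
std::vector<int64_t> stridesForShape(const std::vector<int64_t> &shape,
                                     const std::vector<unsigned> &layoutOrder) {
  std::vector<int64_t> strides(shape.size());
  int64_t stride = 1;
  for (unsigned idx : orderForShape(shape, layoutOrder)) {
    strides[idx] = stride; // order[0] is the fastest dim -> stride 1
    stride *= shape[idx];
  }
  return strides;
}
```

For `shape = {4, 8, 16}` and `layoutOrder = {0, 1}`, `rankDiff` is -1, so the order becomes `[1, 2, 0]` and the strides `[128, 1, 8]`, matching the worked example in the removed comment.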
@@ -486,97 +427,14 @@ inline bool isKernel(FunctionOpInterface funcOp) {
   return funcOp.getVisibility() == SymbolTable::Visibility::Public;
 }
 
-inline Value getStackPointer(RewriterBase &rewriter,
-                             FunctionOpInterface funcOp) {
-  // See NOTE: [Additional Function Arguments]
-  if (!isKernel(funcOp)) {
-    return funcOp.getArgument(funcOp.getNumArguments() - 2);
-  }
-
-  auto mod = funcOp->getParentOfType<ModuleOp>();
-  auto globalBase = dyn_cast<LLVM::GlobalOp>(mod.lookupSymbol("global_smem"));
-  assert(globalBase);
-  return rewriter.create<LLVM::AddressOfOp>(funcOp.getLoc(), globalBase);
-}
-
-inline Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
-                                 const TargetInfoBase &targetInfo,
-                                 FunctionOpInterface funcOp,
-                                 Value allocOffset = {}) {
-  // See NOTE: [Additional Function Arguments]
-  if (!isKernel(funcOp)) {
-    // Base for this function
-    auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
-    if (!allocOffset) {
-      return gmemBase;
-    }
-
-    auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
-    auto b = TritonLLVMOpBuilder(loc, rewriter);
-    return b.gep(ptrTy, i8_ty, gmemBase, allocOffset);
-  }
-
-  // Base for entire kernel
-  auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
-
-  ModuleOp mod = funcOp.getOperation()->getParentOfType<ModuleOp>();
-  auto allocSizeAttr = mod.getOperation()->getAttrOfType<mlir::IntegerAttr>(
-      "ttg.global_scratch_memory_size");
-  if (!allocSizeAttr) {
-    return gmemBase;
-  }
+Value getStackPointer(RewriterBase &rewriter, FunctionOpInterface funcOp);
 
-  Value gridIdx[3];
-  Value gridDim[2];
-  for (int k = 0; k < 3; ++k) {
-    gridIdx[k] = rewriter.create<GetProgramIdOp>(loc, k);
-  }
-  for (int k = 0; k < 2; ++k) {
-    gridDim[k] = rewriter.create<GetNumProgramsOp>(loc, k);
-  }
+Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
+                          const TargetInfoBase &targetInfo,
+                          FunctionOpInterface funcOp, Value allocOffset);
 
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  Value linearId = gridIdx[2];
-  for (int k = 0; k < 2; ++k) {
-    linearId = b.add(gridIdx[1 - k], b.mul(linearId, gridDim[1 - k]));
-  }
-  auto numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
-  if (numCTAs > 1) {
-    linearId = b.mul(linearId, b.i32_val(numCTAs));
-    linearId = b.add(linearId, targetInfo.getClusterCTAId(rewriter, loc));
-  }
-
-  auto allocSize = allocSizeAttr.getValue().getZExtValue();
-
-  Value offset = b.mul(linearId, b.i32_val(allocSize));
-  if (allocOffset) {
-    offset = b.add(offset, allocOffset);
-  }
-
-  auto *ctx = rewriter.getContext();
-  auto res =
-      b.gep(mlir::LLVM::LLVMPointerType::get(ctx, 1), i8_ty, gmemBase, offset);
-  return res;
-}
-
-inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
-                                 const TargetInfoBase &target, Operation *op) {
-  auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(),
-                                          target.getSharedAddressSpace());
-  auto func = op->template getParentOfType<FunctionOpInterface>();
-  if (!func)
-    func = cast<FunctionOpInterface>(op);
-
-  assert(op->hasAttr("allocation.offset"));
-  size_t offset = cast<IntegerAttr>(op->getAttr("allocation.offset"))
-                      .getValue()
-                      .getZExtValue();
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  Value offVal = b.i32_val(offset);
-  Value base =
-      b.gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal);
-  return base;
-}
+Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
+                          const TargetInfoBase &target, Operation *op);
 
 // -----------------------------------------------------------------------
 // MXFP utilities
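The `getGlobalScratchPtr` definition being moved out computes a per-program slice of the global scratch buffer: it linearizes the 3d program id with x fastest, optionally appends the cluster CTA id, and scales by the per-program allocation size. A hypothetical host-side mirror of just that index arithmetic (plain integers, no MLIR; names are ours):

```cpp
#include <cstdint>

// Mirrors the offset arithmetic in getGlobalScratchPtr: every program (and
// each CTA within a cluster) owns an allocSize-byte slice of global scratch.
int64_t scratchOffset(const int64_t gridIdx[3], const int64_t gridDim[2],
                      int64_t numCTAs, int64_t ctaId, int64_t allocSize,
                      int64_t allocOffset) {
  // Linearize (x, y, z) with x fastest: x + gridDim[0] * (y + gridDim[1] * z).
  int64_t linearId = gridIdx[2];
  for (int k = 0; k < 2; ++k)
    linearId = gridIdx[1 - k] + linearId * gridDim[1 - k];
  // Multi-CTA clusters subdivide each program's slice per CTA.
  if (numCTAs > 1)
    linearId = linearId * numCTAs + ctaId;
  return linearId * allocSize + allocOffset;
}
```

The IR version emits the same expression with `b.mul`/`b.add` on i32 values and finishes with a `gep` off the kernel's global scratch base argument.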
@@ -619,16 +477,8 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
 using ::mlir::triton::gpu::SliceEncodingAttr;
 
-inline Value dot(RewriterBase &rewriter, Location loc, ArrayRef<Value> offsets,
-                 ArrayRef<Value> strides) {
-  assert(offsets.size() == strides.size());
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-  Value ret = b.i32_val(0);
-  for (auto [offset, stride] : llvm::zip(offsets, strides)) {
-    ret = b.add(ret, b.mul(offset, stride));
-  }
-  return ret;
-}
+Value dot(RewriterBase &rewriter, Location loc, ArrayRef<Value> offsets,
+          ArrayRef<Value> strides);
 
 /// Extend 2d shared object to 3d.
 ///
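`dot` is the standard offsets-times-strides reduction, just emitted as i32 arithmetic. A plain-integer analogue (illustrative helper, not part of the API):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Linearize a multi-dimensional offset: sum over offsets[i] * strides[i].
int64_t linearOffset(const std::vector<int64_t> &offsets,
                     const std::vector<int64_t> &strides) {
  assert(offsets.size() == strides.size());
  int64_t ret = 0;
  for (size_t i = 0; i < offsets.size(); ++i)
    ret += offsets[i] * strides[i];
  return ret;
}
// e.g. offsets {2, 3} with strides {16, 1}: element (2, 3) of a row-major
// 2d buffer with 16-wide rows sits at linear offset 2 * 16 + 3 = 35.
```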
@@ -720,91 +570,17 @@ SmallVector<Value> unpackLLVector(Location loc, Value llvmVec,
 
 Value packLLVector(Location loc, ValueRange vals, RewriterBase &rewriter);
 
-inline std::optional<LLVM::AtomicBinOp> matchAtomicOp(RMWOp atomicOp) {
-  switch (atomicOp) {
-  case RMWOp::AND:
-    return LLVM::AtomicBinOp::_and;
-  case RMWOp::OR:
-    return LLVM::AtomicBinOp::_or;
-  case RMWOp::XOR:
-    return LLVM::AtomicBinOp::_xor;
-  case RMWOp::ADD:
-    return LLVM::AtomicBinOp::add;
-  case RMWOp::FADD:
-    return LLVM::AtomicBinOp::fadd;
-  case RMWOp::MAX:
-    return LLVM::AtomicBinOp::max;
-  case RMWOp::MIN:
-    return LLVM::AtomicBinOp::min;
-  case RMWOp::UMAX:
-    return LLVM::AtomicBinOp::umax;
-  case RMWOp::UMIN:
-    return LLVM::AtomicBinOp::umin;
-  case RMWOp::XCHG:
-    return LLVM::AtomicBinOp::xchg;
-  default:
-    return {};
-  }
-}
+std::optional<LLVM::AtomicBinOp> matchAtomicOp(RMWOp atomicOp);
 
-inline std::optional<LLVM::AtomicOrdering>
-getMemoryOrdering(MemSemantic memOrdering) {
-  switch (memOrdering) {
-  case MemSemantic::RELAXED:
-    return LLVM::AtomicOrdering::monotonic;
-  case MemSemantic::ACQUIRE:
-    return LLVM::AtomicOrdering::acquire;
-  case MemSemantic::RELEASE:
-    return LLVM::AtomicOrdering::release;
-  case MemSemantic::ACQUIRE_RELEASE:
-    return LLVM::AtomicOrdering::acq_rel;
-  default:
-    return {};
-  }
-}
+std::optional<LLVM::AtomicOrdering> getMemoryOrdering(MemSemantic memOrdering);
 
-inline bool
-isSimpleSharedMemoryAccess(ArrayRef<int64_t> shape,
-                           ArrayRef<int64_t> allocShape,
-                           triton::gpu::SharedEncodingTrait sharedEnc) {
-  auto rank = shape.size();
-  auto swizzledLayout =
-      dyn_cast<triton::gpu::SwizzledSharedEncodingAttr>(sharedEnc);
-  auto nvmmaLayout = dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(sharedEnc);
-  bool noSwizzling = (swizzledLayout && swizzledLayout.getMaxPhase() == 1) ||
-                     (nvmmaLayout && nvmmaLayout.getSwizzlingByteWidth() == 0);
-  return /*no swizzling*/ noSwizzling ||
-         /*swizzling but same shape*/ shape == allocShape ||
-         /*swizzling and rank-reduced and rank >= 2*/
-         (shape == allocShape.take_back(rank) && rank >= 2);
-}
+bool isSimpleSharedMemoryAccess(ArrayRef<int64_t> shape,
+                                ArrayRef<int64_t> allocShape,
+                                triton::gpu::SharedEncodingTrait sharedEnc);
 
-inline llvm::MapVector<StringAttr, int32_t>
-getAllFreeVarMasks(MLIRContext *ctx) {
-  // Mask where all elements are redundant
-  auto kReg = str_attr("reg");
-  auto kLane = str_attr("lane");
-  auto kWarp = str_attr("warp");
-  auto kBlock = str_attr("block");
+llvm::MapVector<StringAttr, int32_t> getAllFreeVarMasks(MLIRContext *ctx);
 
-  int32_t fullMask = -1;
-  llvm::MapVector<StringAttr, int32_t> ret;
-  for (auto dimName : {kReg, kLane, kWarp, kBlock}) {
-    ret[dimName] = fullMask;
-  }
-  return ret;
-}
-
-inline llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks(Type type) {
-  auto ctx = type.getContext();
-  auto tensorTy = dyn_cast<RankedTensorType>(type);
-  if (!tensorTy) {
-    return getAllFreeVarMasks(ctx);
-  }
-  auto ll =
-      triton::gpu::toLinearLayout(tensorTy.getShape(), tensorTy.getEncoding());
-  return ll.getFreeVariableMasks();
-}
+llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks(Type type);
 
 inline bool isCanonicalIndex(unsigned index, unsigned freeVarMask) {
   return (index & freeVarMask) == 0;
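The free-variable masks returned above drive `isCanonicalIndex`: a set bit marks an index bit whose value does not change which tensor element is addressed, so keeping only indices with all free bits cleared visits each element exactly once. A small self-contained illustration with an assumed mask value:

```cpp
#include <cstdio>

// A set bit in freeVarMask is "free": flipping it in the index addresses the
// same element. Canonical indices have every free bit cleared.
bool isCanonicalIndex(unsigned index, unsigned freeVarMask) {
  return (index & freeVarMask) == 0;
}

int main() {
  const unsigned freeVarMask = 0b100; // assumed example: bit 2 is free
  for (unsigned i = 0; i < 8; ++i)
    std::printf("%u -> %s\n", i,
                isCanonicalIndex(i, freeVarMask) ? "canonical" : "duplicate");
  // 0..3 are canonical; 4..7 duplicate them (index 5 aliases index 1).
  return 0;
}
```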