@@ -338,46 +338,21 @@ using namespace mlir::triton;

class SharedMemoryObject {
public:
-   SharedMemoryObject(Value base, Type baseElemType, ArrayRef<Value> offsets)
-       : base(base), baseElemType(baseElemType),
-         offsets(offsets.begin(), offsets.end()) {}
+   SharedMemoryObject(Value base, Type baseElemType, ArrayRef<Value> offsets);

  SharedMemoryObject(Value base, Type baseElemType, int64_t rank, Location loc,
-                      RewriterBase &rewriter)
-       : base(base), baseElemType(baseElemType) {
-     auto b = TritonLLVMOpBuilder(loc, rewriter);
-     offsets.append(rank, b.i32_val(0));
-   }
+                      RewriterBase &rewriter);

  SmallVector<Value> getOffsets() const { return offsets; }
  Value getBase() const { return base; }
  Type getBaseElemType() const { return baseElemType; }

-   SmallVector<Value> getElems() const {
-     SmallVector<Value> elems;
-     elems.push_back(base);
-     elems.append(offsets.begin(), offsets.end());
-     return elems;
-   }
+   SmallVector<Value> getElems() const;

-   SmallVector<Type> getTypes() const {
-     SmallVector<Type> types;
-     types.push_back(base.getType());
-     types.append(offsets.size(), IntegerType::get(base.getContext(), 32));
-     return types;
-   }
+   SmallVector<Type> getTypes() const;

  SmallVector<Value> getStrides(triton::gpu::MemDescType memDesc, Location loc,
-                                 RewriterBase &rewriter) const {
-     auto allocShape = memDesc.getAllocShape();
-     auto allocShapePerCTA = triton::gpu::getAllocationShapePerCTA(
-         memDesc.getEncoding(), allocShape);
-     auto layoutOrder = triton::gpu::getOrder(memDesc);
-     auto allocStrides = SharedMemoryObject::getStridesForShape(
-         allocShapePerCTA, layoutOrder, loc, rewriter);
-     return SmallVector<Value>(allocStrides.end() - offsets.size(),
-                               allocStrides.end());
-   }
+                                 RewriterBase &rewriter) const;

  // TODO(Keren): deprecate the method once AMD backend has cleaned up
  Value getCSwizzleOffset(int dim) const {
@@ -386,50 +361,16 @@ class SharedMemoryObject {
  }

  // TODO(Keren): deprecate the method once AMD backend has cleaned up
-   Value getBaseBeforeSlice(int dim, Location loc,
-                            RewriterBase &rewriter) const {
-     auto b = TritonLLVMOpBuilder(loc, rewriter);
-     Value cSwizzleOffset = getCSwizzleOffset(dim);
-     Value offset = b.sub(b.i32_val(0), cSwizzleOffset);
-     Type type = base.getType();
-     return b.gep(type, baseElemType, base, offset);
-   }
+   Value getBaseBeforeSlice(int dim, Location loc, RewriterBase &rewriter) const;

private:
-   static SmallVector<unsigned>
-   getOrderForShape(ArrayRef<int64_t> shape, ArrayRef<unsigned> layoutOrder) {
-     SmallVector<unsigned> order(shape.size());
-     // Default minor-to-major order
-     std::iota(order.rbegin(), order.rend(), 0);
-     if (layoutOrder.size() > 0) {
-       // If a layout order is provided, we assume it specifies the order in
-       // which the dimensions are first accessed, and unspecified dimensions
-       // retain the minor-to-major order. For example, if order = [2, 1, 0] and
-       // layoutOrder = [0, 1], we need to shift `layoutOrder`
-       // by -1 (move them right). The resulting order will then be [1, 2, 0].
-       int rankDiff = layoutOrder.size() - shape.size();
-       auto minRank = std::min<size_t>(shape.size(), layoutOrder.size());
-       for (size_t i = 0; i < minRank; ++i)
-         order[i] = layoutOrder[i] - rankDiff;
-     }
-     assert(isPermutationOfIota(order) && "Invalid order");
-     return order;
-   }
+   static SmallVector<unsigned> getOrderForShape(ArrayRef<int64_t> shape,
+                                                 ArrayRef<unsigned> layoutOrder);

  static SmallVector<Value> getStridesForShape(ArrayRef<int64_t> shape,
                                               ArrayRef<unsigned> layoutOrder,
                                               Location loc,
-                                                RewriterBase &rewriter) {
-     SmallVector<Value> strides(shape.size());
-     auto order = SharedMemoryObject::getOrderForShape(shape, layoutOrder);
-     int64_t stride = 1;
-     auto b = TritonLLVMOpBuilder(loc, rewriter);
-     for (auto idx : order) {
-       strides[idx] = b.i32_val(stride);
-       stride *= shape[idx];
-     }
-     return strides;
-   }
+                                                RewriterBase &rewriter);

  Value base; // i32 ptr. The start address of the shared memory object.
  Type baseElemType;
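The constructor bodies removed in this hunk presumably move, unchanged, to the matching implementation file (the target .cpp is not shown in this diff). A minimal sketch of the out-of-line form, reusing the removed bodies:

    // Sketch only: out-of-line definitions as they would appear in the .cpp.
    SharedMemoryObject::SharedMemoryObject(Value base, Type baseElemType,
                                           ArrayRef<Value> offsets)
        : base(base), baseElemType(baseElemType),
          offsets(offsets.begin(), offsets.end()) {}

    SharedMemoryObject::SharedMemoryObject(Value base, Type baseElemType,
                                           int64_t rank, Location loc,
                                           RewriterBase &rewriter)
        : base(base), baseElemType(baseElemType) {
      // Initialize every dimension's offset to the i32 constant 0.
      auto b = TritonLLVMOpBuilder(loc, rewriter);
      offsets.append(rank, b.i32_val(0));
    }

Keeping only declarations in the header spares every including translation unit from re-parsing these bodies, which is presumably the motivation for the de-inlining.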
@@ -486,97 +427,14 @@ inline bool isKernel(FunctionOpInterface funcOp) {
  return funcOp.getVisibility() == SymbolTable::Visibility::Public;
}

- inline Value getStackPointer(RewriterBase &rewriter,
-                              FunctionOpInterface funcOp) {
-   // See NOTE: [Additional Function Arguments]
-   if (!isKernel(funcOp)) {
-     return funcOp.getArgument(funcOp.getNumArguments() - 2);
-   }
-
-   auto mod = funcOp->getParentOfType<ModuleOp>();
-   auto globalBase = dyn_cast<LLVM::GlobalOp>(mod.lookupSymbol("global_smem"));
-   assert(globalBase);
-   return rewriter.create<LLVM::AddressOfOp>(funcOp.getLoc(), globalBase);
- }
-
- inline Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
-                                  const TargetInfoBase &targetInfo,
-                                  FunctionOpInterface funcOp,
-                                  Value allocOffset = {}) {
-   // See NOTE: [Additional Function Arguments]
-   if (!isKernel(funcOp)) {
-     // Base for this function
-     auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
-     if (!allocOffset) {
-       return gmemBase;
-     }
-
-     auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
-     auto b = TritonLLVMOpBuilder(loc, rewriter);
-     return b.gep(ptrTy, i8_ty, gmemBase, allocOffset);
-   }
-
-   // Base for entire kernel
-   auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
-
-   ModuleOp mod = funcOp.getOperation()->getParentOfType<ModuleOp>();
-   auto allocSizeAttr = mod.getOperation()->getAttrOfType<mlir::IntegerAttr>(
-       "ttg.global_scratch_memory_size");
-   if (!allocSizeAttr) {
-     return gmemBase;
-   }
+ Value getStackPointer(RewriterBase &rewriter, FunctionOpInterface funcOp);

-   Value gridIdx[3];
-   Value gridDim[2];
-   for (int k = 0; k < 3; ++k) {
-     gridIdx[k] = rewriter.create<GetProgramIdOp>(loc, k);
-   }
-   for (int k = 0; k < 2; ++k) {
-     gridDim[k] = rewriter.create<GetNumProgramsOp>(loc, k);
-   }
+ Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
+                           const TargetInfoBase &targetInfo,
+                           FunctionOpInterface funcOp, Value allocOffset);

-   auto b = TritonLLVMOpBuilder(loc, rewriter);
-   Value linearId = gridIdx[2];
-   for (int k = 0; k < 2; ++k) {
-     linearId = b.add(gridIdx[1 - k], b.mul(linearId, gridDim[1 - k]));
-   }
-   auto numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
-   if (numCTAs > 1) {
-     linearId = b.mul(linearId, b.i32_val(numCTAs));
-     linearId = b.add(linearId, targetInfo.getClusterCTAId(rewriter, loc));
-   }
-
-   auto allocSize = allocSizeAttr.getValue().getZExtValue();
-
-   Value offset = b.mul(linearId, b.i32_val(allocSize));
-   if (allocOffset) {
-     offset = b.add(offset, allocOffset);
-   }
-
-   auto *ctx = rewriter.getContext();
-   auto res =
-       b.gep(mlir::LLVM::LLVMPointerType::get(ctx, 1), i8_ty, gmemBase, offset);
-   return res;
- }
-
- inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
-                                  const TargetInfoBase &target, Operation *op) {
-   auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(),
-                                           target.getSharedAddressSpace());
-   auto func = op->template getParentOfType<FunctionOpInterface>();
-   if (!func)
-     func = cast<FunctionOpInterface>(op);
-
-   assert(op->hasAttr("allocation.offset"));
-   size_t offset = cast<IntegerAttr>(op->getAttr("allocation.offset"))
-                       .getValue()
-                       .getZExtValue();
-   auto b = TritonLLVMOpBuilder(loc, rewriter);
-   Value offVal = b.i32_val(offset);
-   Value base =
-       b.gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal);
-   return base;
- }
+ Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
+                           const TargetInfoBase &target, Operation *op);

// -----------------------------------------------------------------------
// MXFP utilities
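For reference, the removed getGlobalScratchPtr body derives the per-program scratch offset by linearizing the 3D program id and scaling it by the per-program allocation size. A self-contained sketch of that arithmetic using plain integers instead of MLIR Values (the function name and signature below are illustrative only):

    #include <cstdint>

    // Mirrors the offset math of the removed getGlobalScratchPtr body:
    // linearize the program id, optionally fold in the cluster CTA id,
    // then scale by the per-program scratch size and add allocOffset.
    int64_t globalScratchByteOffset(int64_t pid0, int64_t pid1, int64_t pid2,
                                    int64_t gridDim0, int64_t gridDim1,
                                    int64_t numCTAs, int64_t clusterCTAId,
                                    int64_t allocSizePerProgram,
                                    int64_t allocOffset) {
      int64_t linearId = pid2;
      linearId = pid1 + linearId * gridDim1; // fold in grid dimension 1
      linearId = pid0 + linearId * gridDim0; // fold in grid dimension 0
      if (numCTAs > 1)
        linearId = linearId * numCTAs + clusterCTAId;
      return linearId * allocSizePerProgram + allocOffset;
    }

The original code applies the resulting offset as a byte-wise gep off the kernel-wide global scratch base argument.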
@@ -619,16 +477,8 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;

- inline Value dot(RewriterBase &rewriter, Location loc, ArrayRef<Value> offsets,
-                  ArrayRef<Value> strides) {
-   assert(offsets.size() == strides.size());
-   auto b = TritonLLVMOpBuilder(loc, rewriter);
-   Value ret = b.i32_val(0);
-   for (auto [offset, stride] : llvm::zip(offsets, strides)) {
-     ret = b.add(ret, b.mul(offset, stride));
-   }
-   return ret;
- }
+ Value dot(RewriterBase &rewriter, Location loc, ArrayRef<Value> offsets,
+           ArrayRef<Value> strides);

/// Extend 2d shared object to 3d.
///
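The removed dot body is an inner product of the offsets and strides, emitted as i32 arithmetic; the declaration above keeps that contract. A scalar sketch of the same computation (names are illustrative only):

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Scalar analogue of dot(): sum(offsets[i] * strides[i]), i.e. the linear
    // element offset corresponding to a multi-dimensional index.
    int64_t linearElementOffset(const std::vector<int64_t> &offsets,
                                const std::vector<int64_t> &strides) {
      assert(offsets.size() == strides.size());
      int64_t ret = 0;
      for (std::size_t i = 0; i < offsets.size(); ++i)
        ret += offsets[i] * strides[i];
      return ret;
    }

For example, offsets = {row, col} with strides = {ld, 1} yields row * ld + col.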
@@ -720,91 +570,17 @@ SmallVector<Value> unpackLLVector(Location loc, Value llvmVec,

Value packLLVector(Location loc, ValueRange vals, RewriterBase &rewriter);

- inline std::optional<LLVM::AtomicBinOp> matchAtomicOp(RMWOp atomicOp) {
-   switch (atomicOp) {
-   case RMWOp::AND:
-     return LLVM::AtomicBinOp::_and;
-   case RMWOp::OR:
-     return LLVM::AtomicBinOp::_or;
-   case RMWOp::XOR:
-     return LLVM::AtomicBinOp::_xor;
-   case RMWOp::ADD:
-     return LLVM::AtomicBinOp::add;
-   case RMWOp::FADD:
-     return LLVM::AtomicBinOp::fadd;
-   case RMWOp::MAX:
-     return LLVM::AtomicBinOp::max;
-   case RMWOp::MIN:
-     return LLVM::AtomicBinOp::min;
-   case RMWOp::UMAX:
-     return LLVM::AtomicBinOp::umax;
-   case RMWOp::UMIN:
-     return LLVM::AtomicBinOp::umin;
-   case RMWOp::XCHG:
-     return LLVM::AtomicBinOp::xchg;
-   default:
-     return {};
-   }
- }
+ std::optional<LLVM::AtomicBinOp> matchAtomicOp(RMWOp atomicOp);

- inline std::optional<LLVM::AtomicOrdering>
- getMemoryOrdering(MemSemantic memOrdering) {
-   switch (memOrdering) {
-   case MemSemantic::RELAXED:
-     return LLVM::AtomicOrdering::monotonic;
-   case MemSemantic::ACQUIRE:
-     return LLVM::AtomicOrdering::acquire;
-   case MemSemantic::RELEASE:
-     return LLVM::AtomicOrdering::release;
-   case MemSemantic::ACQUIRE_RELEASE:
-     return LLVM::AtomicOrdering::acq_rel;
-   default:
-     return {};
-   }
- }
+ std::optional<LLVM::AtomicOrdering> getMemoryOrdering(MemSemantic memOrdering);

- inline bool
- isSimpleSharedMemoryAccess(ArrayRef<int64_t> shape,
-                            ArrayRef<int64_t> allocShape,
-                            triton::gpu::SharedEncodingTrait sharedEnc) {
-   auto rank = shape.size();
-   auto swizzledLayout =
-       dyn_cast<triton::gpu::SwizzledSharedEncodingAttr>(sharedEnc);
-   auto nvmmaLayout = dyn_cast<triton::gpu::NVMMASharedEncodingAttr>(sharedEnc);
-   bool noSwizzling = (swizzledLayout && swizzledLayout.getMaxPhase() == 1) ||
-                      (nvmmaLayout && nvmmaLayout.getSwizzlingByteWidth() == 0);
-   return /*no swizzling*/ noSwizzling ||
-          /*swizzling but same shape*/ shape == allocShape ||
-          /*swizzling and rank-reduced and rank >= 2*/
-          (shape == allocShape.take_back(rank) && rank >= 2);
- }
+ bool isSimpleSharedMemoryAccess(ArrayRef<int64_t> shape,
+                                 ArrayRef<int64_t> allocShape,
+                                 triton::gpu::SharedEncodingTrait sharedEnc);

- inline llvm::MapVector<StringAttr, int32_t>
- getAllFreeVarMasks(MLIRContext *ctx) {
-   // Mask where all elements are redundant
-   auto kReg = str_attr("reg");
-   auto kLane = str_attr("lane");
-   auto kWarp = str_attr("warp");
-   auto kBlock = str_attr("block");
+ llvm::MapVector<StringAttr, int32_t> getAllFreeVarMasks(MLIRContext *ctx);

-   int32_t fullMask = -1;
-   llvm::MapVector<StringAttr, int32_t> ret;
-   for (auto dimName : {kReg, kLane, kWarp, kBlock}) {
-     ret[dimName] = fullMask;
-   }
-   return ret;
- }
-
- inline llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks(Type type) {
-   auto ctx = type.getContext();
-   auto tensorTy = dyn_cast<RankedTensorType>(type);
-   if (!tensorTy) {
-     return getAllFreeVarMasks(ctx);
-   }
-   auto ll =
-       triton::gpu::toLinearLayout(tensorTy.getShape(), tensorTy.getEncoding());
-   return ll.getFreeVariableMasks();
- }
+ llvm::MapVector<StringAttr, int32_t> getFreeVariableMasks(Type type);

inline bool isCanonicalIndex(unsigned index, unsigned freeVarMask) {
  return (index & freeVarMask) == 0;
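As a quick illustration of the free-variable-mask convention behind isCanonicalIndex: bits set in the mask mark redundant ("free") index bits, so each redundancy class has exactly one canonical member, the one with all free bits cleared. A self-contained sketch (the mask value below is made up for illustration):

    // Same predicate as isCanonicalIndex in the header.
    constexpr bool isCanonicalIndexSketch(unsigned index, unsigned freeVarMask) {
      return (index & freeVarMask) == 0;
    }

    // Hypothetical mask with bits 1 and 2 free: indices 0b0000, 0b0010, 0b0100
    // and 0b0110 all refer to the same element; only 0b0000 is canonical.
    static_assert(isCanonicalIndexSketch(0b0000, 0b0110), "representative");
    static_assert(!isCanonicalIndexSketch(0b0010, 0b0110), "redundant copy");
    static_assert(!isCanonicalIndexSketch(0b0110, 0b0110), "redundant copy");
    static_assert(isCanonicalIndexSketch(0b1001, 0b0110), "other class");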