@@ -706,110 +706,6 @@ lowerLocalLdSt(Location loc, MLIRContext *ctx,
                    maybeMaxVecElems, localLoadOp);
 }
 
-bool emitTransferBetweenRegistersAndShared(
-    LinearLayout &regLayout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy,
-    std::optional<int32_t> maxVecElems, const SharedMemoryObject &smemObj,
-    Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
-    Value laneId, Value warpId,
-    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
-  MLIRContext *ctx = rewriter.getContext();
-  auto b = TritonLLVMOpBuilder(loc, rewriter);
-
-  StringAttr kBlock = str_attr("block");
-  StringAttr kRegister = str_attr("register");
-  StringAttr kLane = str_attr("lane");
-  StringAttr kWarp = str_attr("warp");
-  StringAttr kOffset = str_attr("offset");
-
-  auto shape = sharedTy.getShape();
-  auto paddedEnc =
-      dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(sharedTy.getEncoding());
-  LinearLayout regToSharedLayout = LinearLayout::empty();
-  if (paddedEnc) {
-    const auto &sharedLL = paddedEnc.getLinearComponent();
-    regToSharedLayout = regLayout.invertAndCompose(sharedLL);
-  } else {
-    auto sharedLL = triton::gpu::toLinearLayout(sharedTy);
-    regToSharedLayout = regLayout.invertAndCompose(sharedLL);
-  }
-
-  // TODO(jlebar): We don't currently support loading from shared memory in a
-  // different CTA. We'd need to emit `mapa.shared::cluster` instructions.
-  if (regToSharedLayout.hasInDim(kBlock) &&
-      regToSharedLayout.hasOutDim(kBlock) &&
-      !regToSharedLayout.isTrivialOver({kBlock})) {
-    return false;
-  }
-
-  // Determine how many consecutive registers map to consecutive shmem elements
-  // in out-dimension offsetN. This is our load instruction's vector width.
-  //
-  // It's OK if the vector width we choose here is wider than the hardware
-  // supports; LLVM will legalize it.
-  int vecElems =
-      std::min({regToSharedLayout.getNumConsecutiveInOut(),
-                maxVecElems.value_or(std::numeric_limits<int>::max())});
-  if (paddedEnc) {
-    vecElems = std::min(vecElems, int(paddedEnc.getMinInterval()));
-  }
-
-  auto withCTAOffset = triton::gpu::getNumCTAs(sharedTy.getEncoding()) > 1;
-  Value blockId =
-      withCTAOffset ? target.getClusterCTAId(rewriter, loc) : b.i32_val(0);
-
-  int numElems = regToSharedLayout.getInDimSize(kRegister);
-  auto vecTy = vec_ty(elemLlvmTy, vecElems);
-  SmallVector<uint32_t> regIds;
-  for (int i = 0; i < numElems / vecElems; i++) {
-    regIds.push_back(i * vecElems);
-  }
-
-  auto smemBase = smemObj.getBase();
-
-  auto indicesVec = applyLinearLayoutVec(loc, rewriter, regToSharedLayout,
-                                         {{kRegister, b.i32_val(0)},
-                                          {kLane, laneId},
-                                          {kWarp, warpId},
-                                          {kBlock, blockId}},
-                                         regIds);
-
-  // Compute affine offset given by memdesc_subslice
-  auto offset = smemObj.getShmemOffset(loc, rewriter, sharedTy);
-  SmallVector<Value> vecAddrVec;
-  for (auto &indices : indicesVec) {
-    Value smemOffset = indices[0].second;
-    smemOffset = b.xor_(smemOffset, offset);
-    if (paddedEnc) {
-      // Apply the offset needed for padding.
-      auto bitwidth = elemLlvmTy.getIntOrFloatBitWidth();
-      Value padOffset = emitPadding(loc, rewriter, paddedEnc, bitwidth,
-                                    smemOffset, /*offsetInBytes=*/false);
-      smemOffset = b.add(smemOffset, padOffset);
-    }
-    auto vecAddr = b.gep(smemBase.getType(), elemLlvmTy, smemBase, smemOffset,
-                         LLVM::GEPNoWrapFlags::inbounds);
-    vecAddrVec.push_back(vecAddr);
-  }
-
-  for (Value &vecAddr : vecAddrVec) {
-    perVectorCallback(vecTy, vecAddr);
-  }
-  return true;
-}
-
-bool emitTransferBetweenRegistersAndShared(
-    RankedTensorType registerTy, triton::gpu::MemDescType sharedTy,
-    Type elemLlvmTy, std::optional<int32_t> maxVecElems,
-    const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
-    const TargetInfoBase &target,
-    std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
-  auto regLayout = triton::gpu::toLinearLayout(registerTy);
-  auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
-  return emitTransferBetweenRegistersAndShared(
-      regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter,
-      target, laneId, warpId, perVectorCallback);
-}
-
 SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
                                     RewriterBase &rewriter) {
   assert(bool(llvmStruct) && "can not unpack null values");
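For context on how the deleted overloads were driven at call sites: a minimal sketch, assuming `registerTy`, `sharedTy`, `elemLlvmTy`, `smemObj`, `target`, `loc`, and `rewriter` are already in scope inside a conversion pattern. The `loadedVecs` vector and the lambda body are illustrative assumptions; only the helper's signature and the per-vector callback shape come from the removed code above.

// Sketch only (hypothetical caller): collect one vectorized shared-memory
// load per callback invocation issued by the removed helper.
SmallVector<Value> loadedVecs;
bool ok = emitTransferBetweenRegistersAndShared(
    registerTy, sharedTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj,
    loc, rewriter, target, [&](VectorType vecTy, Value shmemAddr) {
      // Each invocation receives the address of one contiguous vector in smem.
      Value vec = rewriter.create<LLVM::LoadOp>(loc, vecTy, shmemAddr);
      loadedVecs.push_back(vec);
    });
assert(ok && "register<->shared transfer could not be emitted");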