+#include "../TritonAMDGPUToLLVM/Utility.h"
 #include "Dialect/TritonAMDGPU/IR/Dialect.h"
 #include "TritonAMDGPUToLLVM/GCNAsmFormat.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
@@ -49,6 +50,7 @@ using namespace mlir::triton;
 // clang-format on

 namespace {
+
 struct ExtractSliceOpConversion
     : public ConvertOpToLLVMPattern<amdgpu::ExtractSliceOp> {
   explicit ExtractSliceOpConversion(LLVMTypeConverter &typeConverter,
@@ -60,61 +62,61 @@ struct ExtractSliceOpConversion
                               ConversionPatternRewriter &rewriter) const {
     Location loc = op->getLoc();
     auto srcTy = cast<RankedTensorType>(op.getSource().getType());
-    auto srcLayout = srcTy.getEncoding();
+    auto dstTy = cast<RankedTensorType>(op.getType());
     auto srcShape = srcTy.getShape();
-    auto resultTy = cast<RankedTensorType>(op.getType());
-    auto vals = unpackLLElements(loc, adaptor.getSource(), rewriter);
-    auto elemsPerThread = triton::gpu::getElemsPerThread(srcTy);
-    auto contigPerThread = triton::gpu::getContigPerThread(srcTy);
-    auto totalContigPerThread = product<unsigned>(contigPerThread);
-    auto order = triton::gpu::getOrder(srcTy);
+    auto dstShape = dstTy.getShape();

-    // Calculate valid total number of workers in each dimension
+    auto vals = unpackLLElements(loc, adaptor.getSource(), rewriter);
     auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcTy);
-    shapePerCTATile[0] =
-        std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
-    shapePerCTATile[1] =
-        std::min(static_cast<unsigned>(srcShape[1]), shapePerCTATile[1]);
-
-    // Rank == 2 checked in the verifier
-    SmallVector<int64_t, 2> sizes;
-    for (auto i = 0; i < 2; ++i) {
-      sizes.push_back(resultTy.getDimSize(i));
-    }
+    auto srcCTAShape = LLVM::AMD::multiDimElementwise<int64_t, unsigned>(
+        srcShape, shapePerCTATile, std::divides<unsigned>());
+    auto dstCTAShape = LLVM::AMD::multiDimElementwise<int64_t, unsigned>(
+        dstShape, shapePerCTATile, std::divides<unsigned>());

+    auto numCTATiles = std::accumulate(dstCTAShape.begin(), dstCTAShape.end(),
+                                       1, std::multiplies<>());
     auto offsets = op.getStaticOffsets();
+    auto firstTileCoordinate =
+        LLVM::AMD::multiDimElementwise<int64_t, unsigned>(
+            offsets, shapePerCTATile, std::divides<unsigned>());

-    // Calculate offsets and sizes in terms of CTA units.
-    std::array<int64_t, 2> CTAOffsets{offsets[0] / shapePerCTATile[0],
-                                      offsets[1] / shapePerCTATile[1]};
-    std::array<int64_t, 2> CTASizes{sizes[0] / shapePerCTATile[0],
-                                    sizes[1] / shapePerCTATile[1]};
-    std::array<int64_t, 2> CTAPerShape{srcShape[0] / shapePerCTATile[0],
-                                       srcShape[1] / shapePerCTATile[1]};
-
-    // The diagram above illustrates the graphical representation of the
-    // skipElems, tensorStride, and lastIdx variables.
-    auto skipElems = CTAOffsets[order[1]] * (elemsPerThread[order[0]] *
-                                             contigPerThread[order[1]]) +
-                     CTAOffsets[order[0]] * totalContigPerThread;
-    auto tensorStride =
-        (CTAPerShape[order[0]] - CTASizes[order[0]]) * totalContigPerThread;
-    auto lastIdx =
-        (CTAOffsets[order[1]] + CTASizes[order[1]] - 1) *
-            elemsPerThread[order[0]] * contigPerThread[order[1]] +
-        (CTAOffsets[order[0]] + CTASizes[order[0]]) * totalContigPerThread;
-
-    assert(lastIdx <= vals.size());
+    Attribute srcEncoding = srcTy.getEncoding();
+    Attribute dstEncoding = dstTy.getEncoding();
+    auto linearLayoutSrc = triton::gpu::toLinearLayout(srcShape, srcEncoding);
+    auto linearLayoutDst = triton::gpu::toLinearLayout(dstShape, dstEncoding);

+    auto srcCTAOrder =
+        LLVM::AMD::getCTATileOrder(srcTy.getContext(), linearLayoutSrc);
+    auto dstCTAOrder =
+        LLVM::AMD::getCTATileOrder(srcTy.getContext(), linearLayoutDst);
+
+    unsigned elemsPerThreadPerCTA =
+        triton::gpu::getTotalElemsPerThread(srcTy) /
+        std::accumulate(srcCTAShape.begin(), srcCTAShape.end(), 1,
+                        std::multiplies<>());
+
+    // 1. Process CTA tiles in the destination tensor according to the
+    //    destination's linear layout order of CTA tiles.
+    // 2. For each tile position in the destination tensor, compute its
+    //    corresponding position in the source tensor.
+    // 3. Copy the values from the source tile to the destination slice.
     SmallVector<Value> resultVals;
-    for (int i = skipElems; i < lastIdx; i += tensorStride) {
-      for (int j = 0; j < totalContigPerThread * CTASizes[order[0]]; ++j, ++i) {
-        assert(i < lastIdx);
-        resultVals.push_back(vals[i]);
+    for (size_t i = 0; i < numCTATiles; i++) {
+      auto coordInDstTensor =
+          mlir::LLVM::delinearize(i, dstCTAShape, dstCTAOrder);
+      auto coordInSrcTensor =
+          LLVM::AMD::multiDimElementwise<unsigned, unsigned>(
+              coordInDstTensor, firstTileCoordinate, std::plus<unsigned>());
+      auto linearIdxInSrcTensor =
+          mlir::LLVM::linearize(coordInSrcTensor, srcCTAShape, srcCTAOrder);
+
+      for (size_t j = 0; j < elemsPerThreadPerCTA; j++) {
+        resultVals.push_back(
+            vals[linearIdxInSrcTensor * elemsPerThreadPerCTA + j]);
       }
     }
     Value ret = packLLElements(loc, this->getTypeConverter(), resultVals,
-                               rewriter, resultTy);
+                               rewriter, dstTy);

     rewriter.replaceOp(op, ret);
     return success();
@@ -124,11 +126,7 @@ struct ExtractSliceOpConversion
   matchAndRewrite(amdgpu::ExtractSliceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto srcTy = op.getSource().getType();
-    if (isa<BlockedEncodingAttr, AMDMfmaEncodingAttr>(
-            op.getSource().getType().getEncoding())) {
-      return processLayout(op, adaptor, rewriter);
-    }
-    return failure();
+    return processLayout(op, adaptor, rewriter);
   }
 };
 } // namespace
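
For reference, here is a minimal standalone sketch (not the Triton implementation) of the CTA-tile remapping that the new lowering performs per the three-step comment above: walk the slice tile by tile in the destination's tile order, shift each destination tile coordinate by the tile coordinate of the slice offset, re-linearize it in the source's tile order, and copy that tile's per-thread values. The fixed row-major 2-D order and the helpers `delinearize2D`/`linearize2D` are illustrative stand-ins for `mlir::LLVM::delinearize`/`linearize` with layout-derived orders.

```cpp
#include <array>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical row-major 2-D stand-ins for mlir::LLVM::delinearize/linearize.
static std::array<int64_t, 2> delinearize2D(int64_t linear,
                                            std::array<int64_t, 2> shape) {
  return {linear / shape[1], linear % shape[1]};
}
static int64_t linearize2D(std::array<int64_t, 2> coord,
                           std::array<int64_t, 2> shape) {
  return coord[0] * shape[1] + coord[1];
}

int main() {
  // Source tensor split into 4x4 CTA tiles; the slice starts at tile (1, 2)
  // and covers 2x2 tiles. Each thread holds 8 values per CTA tile.
  const std::array<int64_t, 2> srcCTAShape = {4, 4};
  const std::array<int64_t, 2> dstCTAShape = {2, 2};
  const std::array<int64_t, 2> firstTileCoordinate = {1, 2};
  const int64_t elemsPerThreadPerCTA = 8;

  // Per-thread register values of the source tensor: one contiguous run of
  // elemsPerThreadPerCTA values per CTA tile, in tile-linear order.
  std::vector<int64_t> vals(srcCTAShape[0] * srcCTAShape[1] *
                            elemsPerThreadPerCTA);
  for (size_t i = 0; i < vals.size(); ++i)
    vals[i] = static_cast<int64_t>(i);

  std::vector<int64_t> resultVals;
  const int64_t numCTATiles = dstCTAShape[0] * dstCTAShape[1];
  for (int64_t i = 0; i < numCTATiles; ++i) {
    // Tile coordinate inside the slice, shifted into the source tensor.
    auto coordInDstTensor = delinearize2D(i, dstCTAShape);
    std::array<int64_t, 2> coordInSrcTensor = {
        coordInDstTensor[0] + firstTileCoordinate[0],
        coordInDstTensor[1] + firstTileCoordinate[1]};
    int64_t linearIdxInSrcTensor = linearize2D(coordInSrcTensor, srcCTAShape);
    // Copy that source tile's per-thread values into the slice.
    for (int64_t j = 0; j < elemsPerThreadPerCTA; ++j)
      resultVals.push_back(
          vals[linearIdxInSrcTensor * elemsPerThreadPerCTA + j]);
  }
  std::cout << "copied " << resultVals.size() << " values\n"; // prints 32
  return 0;
}
```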