 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h"

 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/PatternMatch.h"
@@ -108,6 +110,8 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {

   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
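+    // Bail out if this read was already rewritten by this pattern (it is
+    // cloned and tagged in the else branch below), so the pattern does not
+    // recurse on its own output.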
+    if (readOp->hasAttr("amdgpu.transformed"))
+      return failure();

     bool requiresBroadcasting = false;
     VectorType unbroadcastedVectorType;
@@ -117,20 +121,85 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
     }

     Location loc = readOp.getLoc();
-    Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
-                                                  readOp.getPadding());
-    Value load = rewriter.create<vector::LoadOp>(
-        loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
-    Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
-                                                 readOp.getMask(), load, fill);
-
-    // Insert a broadcasting op if required.
-    if (requiresBroadcasting) {
-      res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
-                                                 res);
+    Value src = readOp.getSource();
+    MemRefType memRefType = cast<MemRefType>(src.getType());
+    ArrayRef<int64_t> shape = memRefType.getShape();
+
+    Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value stride = one;
+
+    // Compute the linear index via linearIndex += indices[i] * stride.
+    for (int i = shape.size() - 1; i >= 0; --i) {
+      Value currentIndex = readOp.getIndices()[i];
+      Value strideIndexed =
+          rewriter.create<arith::MulIOp>(loc, currentIndex, stride);
+      linearIndex =
+          rewriter.create<arith::AddIOp>(loc, linearIndex, strideIndexed);
+
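+      // Once the outermost dimension has been accumulated, no further stride
+      // is needed.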
+      if (i == 0)
+        break;
+
+      // Update the stride for the next (outer) dimension.
+      Value nextStride;
+      if (shape[i] != ShapedType::kDynamic) {
+        nextStride = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
+      } else {
+        nextStride = rewriter.create<memref::DimOp>(loc, src, i);
+      }
+      stride = rewriter.create<arith::MulIOp>(loc, stride, nextStride);
+    }
+
+    // Add the vector size offset to the linear index to form the upper bound.
+    VectorType vectorType = readOp.getVectorType();
+    int64_t vectorSize = vectorType.getNumElements();
+    Value vectorSizeOffset =
+        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+    Value upperBoundIndex =
+        rewriter.create<arith::AddIOp>(loc, linearIndex, vectorSizeOffset);
+
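+    // Compute the total number of elements in the source memref.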
+    Value totalSize = one;
+    for (size_t i = 0; i < shape.size(); ++i) {
+      Value dimensionSize;
+      if (shape[i] != ShapedType::kDynamic) {
+        dimensionSize = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
+      } else {
+        dimensionSize = rewriter.create<memref::DimOp>(loc, src, i);
+      }
+      totalSize = rewriter.create<arith::MulIOp>(loc, totalSize, dimensionSize);
     }

-    rewriter.replaceOp(readOp, res);
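+    // The access is in-bounds if the last element read does not go past the
+    // total number of elements in the memref.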
+    Value isInBounds = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ule, upperBoundIndex, totalSize);
+
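+    // In-bounds case: lower to a plain vector.load and select the padding
+    // value through the mask, as before.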
+    auto thenBuilder = [&](OpBuilder &builder, Location loc) {
+      Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
+                                                   readOp.getPadding());
+      Value load = builder.create<vector::LoadOp>(loc, unbroadcastedVectorType,
+                                                  readOp.getSource(),
+                                                  readOp.getIndices());
+      Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
+                                                  readOp.getMask(), load, fill);
+
+      // Insert a broadcasting op if required.
+      if (requiresBroadcasting) {
+        res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
+                                                  res);
+      }
+      builder.create<scf::YieldOp>(loc, res);
+    };
+
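+    // Out-of-bounds case: keep the original transfer_read, cloned and tagged
+    // so this pattern does not match it again.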
+    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
+      Operation *read = builder.clone(*readOp.getOperation());
+      read->setAttr("amdgpu.transformed", builder.getUnitAttr());
+      Value readResult = read->getResult(0);
+      builder.create<scf::YieldOp>(loc, readResult);
+    };
+
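+    // Guard the lowered load with the bounds check and replace the original
+    // transfer_read with the result of the scf.if.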
+    auto ifOp =
+        rewriter.create<scf::IfOp>(loc, isInBounds, thenBuilder, elseBuilder);
+
+    rewriter.replaceOp(readOp, ifOp);

     return success();
   }