@@ -103,6 +103,23 @@ static LogicalResult transferPreconditions(
103103 return success ();
104104}
105105
106+ static Value createVectorLoadForMaskedLoad (OpBuilder &builder, Location loc,
107+ vector::TransferReadOp readOp,
108+ bool requiresBroadcasting,
109+ VectorType unbroadcastedVectorType) {
110+ Value fill = builder.create <vector::SplatOp>(loc, unbroadcastedVectorType,
111+ readOp.getPadding ());
112+ Value load = builder.create <vector::LoadOp>(
113+ loc, unbroadcastedVectorType, readOp.getSource (), readOp.getIndices ());
114+ Value res = builder.create <arith::SelectOp>(loc, unbroadcastedVectorType,
115+ readOp.getMask (), load, fill);
116+ // Insert a broadcasting op if required.
117+ if (requiresBroadcasting) {
118+ res = builder.create <vector::BroadcastOp>(loc, readOp.getVectorType (), res);
119+ }
120+ return res;
121+ }
122+
106123namespace {
107124
108125struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
@@ -150,14 +167,6 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
150167 stride = rewriter.create <arith::MulIOp>(loc, stride, nextStride);
151168 }
152169
153- // Add vector size offset to linear index
154- VectorType vectorType = readOp.getVectorType ();
155- int64_t vectorSize = vectorType.getNumElements ();
156- Value vectorSizeOffset =
157- rewriter.create <arith::ConstantIndexOp>(loc, vectorSize);
158- Value upperBoundIndex =
159- rewriter.create <arith::AddIOp>(loc, linearIndex, vectorSizeOffset);
160-
161170 Value totalSize = one;
162171 for (size_t i = 0 ; i < shape.size (); ++i) {
163172 Value dimensionSize;
@@ -169,35 +178,48 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
169178 totalSize = rewriter.create <arith::MulIOp>(loc, totalSize, dimensionSize);
170179 }
171180
172- Value isInBounds = rewriter.create <arith::CmpIOp>(
173- loc, arith::CmpIPredicate::ule, upperBoundIndex, totalSize);
181+ // delta = bufferSize - linearizedOffset
182+ // 1) check if delta < vectorSize
183+ VectorType vectorType = readOp.getVectorType ();
184+ int64_t vectorSize = vectorType.getNumElements ();
185+ Value vectorSizeOffset =
186+ rewriter.create <arith::ConstantIndexOp>(loc, vectorSize);
187+ Value delta = rewriter.create <arith::SubIOp>(loc, totalSize, linearIndex);
188+ Value isOutofBounds = rewriter.create <arith::CmpIOp>(
189+ loc, arith::CmpIPredicate::ule, delta, vectorSizeOffset);
190+
191+ // 2) check if (detla(bytes) % (32 / elementBitwidth) != 0)
192+ int64_t elementBitWidth = vectorType.getElementTypeBitWidth ();
193+ Value deltaBytes = rewriter.create <arith::MulIOp>(
194+ loc, delta,
195+ rewriter.create <arith::ConstantIndexOp>(loc, elementBitWidth / 8 ));
196+ Value elementsPerWord = rewriter.create <arith::ConstantIndexOp>(
197+ loc, elementBitWidth < 32 ? 32 / elementBitWidth : 1 );
198+ Value isNotWordAligned = rewriter.create <arith::CmpIOp>(
199+ loc, arith::CmpIPredicate::ne,
200+ rewriter.create <arith::RemUIOp>(loc, deltaBytes, elementsPerWord),
201+ rewriter.create <arith::ConstantIndexOp>(loc, 0 ));
202+
203+ // We take the fallback of transfer_read default lowering only it is both
204+ // out-of-bounds and not word aligned.
205+ Value ifCondition =
206+ rewriter.create <arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
174207
175208 auto thenBuilder = [&](OpBuilder &builder, Location loc) {
176- Value fill = builder.create <vector::SplatOp>(loc, unbroadcastedVectorType,
177- readOp.getPadding ());
178- Value load = builder.create <vector::LoadOp>(loc, unbroadcastedVectorType,
179- readOp.getSource (),
180- readOp.getIndices ());
181- Value res = builder.create <arith::SelectOp>(loc, unbroadcastedVectorType,
182- readOp.getMask (), load, fill);
183-
184- // Insert a broadcasting op if required.
185- if (requiresBroadcasting) {
186- res = builder.create <vector::BroadcastOp>(loc, readOp.getVectorType (),
187- res);
188- }
189- rewriter.create <scf::YieldOp>(loc, res);
190- };
191-
192- auto elseBuilder = [&](OpBuilder &builder, Location loc) {
193209 Operation *read = builder.clone (*readOp.getOperation ());
194210 read->setAttr (" amdgpu.transformed" , builder.getUnitAttr ());
195211 Value readResult = read->getResult (0 );
196212 builder.create <scf::YieldOp>(loc, readResult);
197213 };
198214
215+ auto elseBuilder = [&](OpBuilder &builder, Location loc) {
216+ Value res = createVectorLoadForMaskedLoad (
217+ builder, loc, readOp, requiresBroadcasting, unbroadcastedVectorType);
218+ rewriter.create <scf::YieldOp>(loc, res);
219+ };
220+
199221 auto ifOp =
200- rewriter.create <scf::IfOp>(loc, isInBounds , thenBuilder, elseBuilder);
222+ rewriter.create <scf::IfOp>(loc, ifCondition , thenBuilder, elseBuilder);
201223
202224 rewriter.replaceOp (readOp, ifOp);
203225
0 commit comments