Skip to content

Commit 2ba0e0a

Browse files
authored
[Codegen] Simplify tensor load/store padding materialization (#21160)
Towards collapsing the padding encoding materialization into the generic encoding materialization: #20160. I don't think the intermediate extract and insert slices are needed and removing them will streamline the implementation with the generic one, so that we can later combine them into a single implementation in follow-up work. Signed-off-by: Jorn Tuyls <[email protected]>
1 parent 1d2db1a commit 2ba0e0a

File tree

1 file changed

+7
-33
lines changed

1 file changed

+7
-33
lines changed

compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPadding.cpp

Lines changed: 7 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,7 @@ struct MaterializePadEncodingTypeConverter final
118118
};
119119

120120
/// Pattern to convert `iree_tensor_ext.dispatch.tensor.load` operation when
121-
/// materializing the encoding. We extract a smaller tensor for the padded
122-
/// source. This way we do not create partial loads prematurely, which would be
123-
/// difficult to undo later on.
121+
/// materializing the encoding.
124122
struct MaterializeFlowDispatchTensorLoadOp final
125123
: OpConversionPattern<IREE::TensorExt::DispatchTensorLoadOp> {
126124
using OpConversionPattern::OpConversionPattern;
@@ -155,28 +153,17 @@ struct MaterializeFlowDispatchTensorLoadOp final
155153
rewriter.getIndexAttr(0));
156154
SmallVector<OpFoldResult> newStrides(newMixedSizes.size(),
157155
rewriter.getIndexAttr(1));
158-
SmallVector<int64_t> newStaticDims;
159-
SmallVector<Value> newDynamicDims;
160-
dispatchIndexOpFoldResults(newMixedSizes, newDynamicDims, newStaticDims);
161-
162-
Location loc = loadOp.getLoc();
163-
Value newLoad = rewriter.create<IREE::TensorExt::DispatchTensorLoadOp>(
164-
loc, adaptor.getSource(), newDynamicDims, newOffsets, newMixedSizes,
165-
newStrides);
166-
auto extractType = RankedTensorType::get(boundTensorType.getShape(),
167-
boundTensorType.getElementType());
168156
SmallVector<OpFoldResult> extractSizes = getMixedValues(
169157
boundTensorType.getShape(), loadOp.getSourceDims(), rewriter);
170-
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
171-
loadOp, extractType, newLoad, newOffsets, extractSizes, newStrides);
158+
rewriter.replaceOpWithNewOp<IREE::TensorExt::DispatchTensorLoadOp>(
159+
loadOp, adaptor.getSource(), loadOp.getSourceDims(), newOffsets,
160+
extractSizes, newStrides);
172161
return success();
173162
}
174163
};
175164

176165
/// Pattern to convert `iree_tensor_ext.dispatch.tensor.store` operation when
177-
/// materializing the encoding. We create a larger empty tensor for the
178-
/// destination and insert the value into it. This way we do not create partial
179-
/// stores prematurely, which would be difficult to undo later on.
166+
/// materializing the encoding.
180167
struct MaterializeFlowDispatchTensorStoreOp final
181168
: OpConversionPattern<IREE::TensorExt::DispatchTensorStoreOp> {
182169
using OpConversionPattern::OpConversionPattern;
@@ -206,28 +193,15 @@ struct MaterializeFlowDispatchTensorStoreOp final
206193
RankedTensorType paddedType = newTargetType.asRankedTensorType();
207194

208195
Location loc = storeOp.getLoc();
209-
SmallVector<Value> dynamicResultSizes{adaptor.getOperands()};
210-
Value empty =
211-
rewriter.create<tensor::EmptyOp>(loc, paddedType, dynamicResultSizes);
212-
213196
SmallVector<OpFoldResult> offsets(paddedType.getRank(),
214197
rewriter.getIndexAttr(0));
215198
SmallVector<OpFoldResult> strides(paddedType.getRank(),
216199
rewriter.getIndexAttr(1));
217200
SmallVector<OpFoldResult> sizes =
218201
tensor::getMixedSizes(rewriter, loc, adaptor.getValue());
219-
Value insertOp = rewriter.create<tensor::InsertSliceOp>(
220-
loc, adaptor.getValue(), empty, offsets, sizes, strides);
221-
222-
SmallVector<OpFoldResult> newMixedSizes = getMixedValues(
223-
paddedType.getShape(), storeOp.getTargetDims(), rewriter);
224-
SmallVector<int64_t> newStaticDims;
225-
SmallVector<Value> newDynamicDims;
226-
dispatchIndexOpFoldResults(newMixedSizes, newDynamicDims, newStaticDims);
227-
228202
rewriter.replaceOpWithNewOp<IREE::TensorExt::DispatchTensorStoreOp>(
229-
storeOp, insertOp, adaptor.getTarget(), newDynamicDims, offsets,
230-
newMixedSizes, strides);
203+
storeOp, adaptor.getValue(), adaptor.getTarget(),
204+
adaptor.getTargetDims(), offsets, sizes, strides);
231205
return success();
232206
}
233207
};

0 commit comments

Comments
 (0)