@@ -118,9 +118,7 @@ struct MaterializePadEncodingTypeConverter final
118118};
119119
120120// / Pattern to convert `iree_tensor_ext.dispatch.tensor.load` operation when
121- // / materializing the encoding. We extract a smaller tensor for the padded
122- // / source. This way we do not create partial loads prematurely, which would be
123- // / difficult to undo later on.
121+ // / materializing the encoding.
124122struct MaterializeFlowDispatchTensorLoadOp final
125123 : OpConversionPattern<IREE::TensorExt::DispatchTensorLoadOp> {
126124 using OpConversionPattern::OpConversionPattern;
@@ -155,28 +153,17 @@ struct MaterializeFlowDispatchTensorLoadOp final
155153 rewriter.getIndexAttr (0 ));
156154 SmallVector<OpFoldResult> newStrides (newMixedSizes.size (),
157155 rewriter.getIndexAttr (1 ));
158- SmallVector<int64_t > newStaticDims;
159- SmallVector<Value> newDynamicDims;
160- dispatchIndexOpFoldResults (newMixedSizes, newDynamicDims, newStaticDims);
161-
162- Location loc = loadOp.getLoc ();
163- Value newLoad = rewriter.create <IREE::TensorExt::DispatchTensorLoadOp>(
164- loc, adaptor.getSource (), newDynamicDims, newOffsets, newMixedSizes,
165- newStrides);
166- auto extractType = RankedTensorType::get (boundTensorType.getShape (),
167- boundTensorType.getElementType ());
168156 SmallVector<OpFoldResult> extractSizes = getMixedValues (
169157 boundTensorType.getShape (), loadOp.getSourceDims (), rewriter);
170- rewriter.replaceOpWithNewOp <tensor::ExtractSliceOp>(
171- loadOp, extractType, newLoad, newOffsets, extractSizes, newStrides);
158+ rewriter.replaceOpWithNewOp <IREE::TensorExt::DispatchTensorLoadOp>(
159+ loadOp, adaptor.getSource (), loadOp.getSourceDims (), newOffsets,
160+ extractSizes, newStrides);
172161 return success ();
173162 }
174163};
175164
176165// / Pattern to convert `iree_tensor_ext.dispatch.tensor.store` operation when
177- // / materializing the encoding. We create a larger empty tensor for the
178- // / destination and insert the value into it. This way we do not create partial
179- // / stores prematurely, which would be difficult to undo later on.
166+ // / materializing the encoding.
180167struct MaterializeFlowDispatchTensorStoreOp final
181168 : OpConversionPattern<IREE::TensorExt::DispatchTensorStoreOp> {
182169 using OpConversionPattern::OpConversionPattern;
@@ -206,28 +193,15 @@ struct MaterializeFlowDispatchTensorStoreOp final
206193 RankedTensorType paddedType = newTargetType.asRankedTensorType ();
207194
208195 Location loc = storeOp.getLoc ();
209- SmallVector<Value> dynamicResultSizes{adaptor.getOperands ()};
210- Value empty =
211- rewriter.create <tensor::EmptyOp>(loc, paddedType, dynamicResultSizes);
212-
213196 SmallVector<OpFoldResult> offsets (paddedType.getRank (),
214197 rewriter.getIndexAttr (0 ));
215198 SmallVector<OpFoldResult> strides (paddedType.getRank (),
216199 rewriter.getIndexAttr (1 ));
217200 SmallVector<OpFoldResult> sizes =
218201 tensor::getMixedSizes (rewriter, loc, adaptor.getValue ());
219- Value insertOp = rewriter.create <tensor::InsertSliceOp>(
220- loc, adaptor.getValue (), empty, offsets, sizes, strides);
221-
222- SmallVector<OpFoldResult> newMixedSizes = getMixedValues (
223- paddedType.getShape (), storeOp.getTargetDims (), rewriter);
224- SmallVector<int64_t > newStaticDims;
225- SmallVector<Value> newDynamicDims;
226- dispatchIndexOpFoldResults (newMixedSizes, newDynamicDims, newStaticDims);
227-
228202 rewriter.replaceOpWithNewOp <IREE::TensorExt::DispatchTensorStoreOp>(
229- storeOp, insertOp , adaptor.getTarget (), newDynamicDims, offsets ,
230- newMixedSizes , strides);
203+ storeOp, adaptor. getValue () , adaptor.getTarget (),
204+ adaptor. getTargetDims (), offsets, sizes , strides);
231205 return success ();
232206 }
233207};
0 commit comments