@@ -16,6 +16,14 @@ namespace mlir::iree_compiler {
1616
1717namespace {
1818
19+ /// Flattens the given value ranges into a single vector of values.
/// The ranges are visited in order and each range's values are appended in
/// order, so the result is the concatenation of all ranges. An empty `values`
/// yields an empty vector.
20+ static SmallVector<Value> flattenValues (ArrayRef<ValueRange> values) {
21+ SmallVector<Value> result;
22+ for (const auto &vals : values)
23+ llvm::append_range (result, vals); // Append this range's values at the end.
24+ return result;
25+ }
26+
1927// %1 = hal.tensor.import %0 : !hal.buffer_view -> tensor<4xf32>
2028// ->
2129// %1 = stream.tensor.import %0 : !hal.buffer_view ->
@@ -24,7 +32,7 @@ struct ConvertTensorImportOp
2432 : public AffinityOpConversionPattern<IREE::HAL::TensorImportOp> {
2533 using AffinityOpConversionPattern::AffinityOpConversionPattern;
2634 LogicalResult matchAndRewriteOnAffinity (
27- IREE::HAL::TensorImportOp op, OpAdaptor adaptor,
35+ IREE::HAL::TensorImportOp op, OneToNOpAdaptor adaptor,
2836 IREE::Stream::AffinityAttr executionAffinityAttr,
2937 ConversionPatternRewriter &rewriter) const override {
3038 auto sourceType = op.getSource ().getType ();
@@ -42,9 +50,9 @@ struct ConvertTensorImportOp
4250 // mistake and it's better to know of a shape mismatch than just buffer
4351 // byte length difference.
4452 if (auto tensorType = llvm::dyn_cast<RankedTensorType>(targetType)) {
45- if (failed (buildEncodingAssertions (op. getLoc (), adaptor. getSource (),
46- op.getNameAttr (), tensorType ,
47- op.getTargetDims (), rewriter))) {
53+ if (failed (buildEncodingAssertions (
54+ op. getLoc (), adaptor. getSource (). front (), op.getNameAttr (),
55+ tensorType, op.getTargetDims (), rewriter))) {
4856 return rewriter.notifyMatchFailure (op, " unsupported tensor type" );
4957 }
5058 }
@@ -55,11 +63,12 @@ struct ConvertTensorImportOp
5563 IREE::Stream::Lifetime::External);
5664 Value resultSize = rewriter.create <IREE::Stream::TensorSizeOfOp>(
5765 op.getLoc (), rewriter.getIndexType (),
58- TypeAttr::get (op.getTarget ().getType ()), adaptor. getTargetDims (),
59- executionAffinityAttr);
66+ TypeAttr::get (op.getTarget ().getType ()),
67+ flattenValues (adaptor. getTargetDims ()), executionAffinityAttr);
6068 Value resource = rewriter.create <IREE::Stream::TensorImportOp>(
61- op.getLoc (), resultType, adaptor.getSource (), TypeAttr::get (targetType),
62- adaptor.getTargetDims (), resultSize, executionAffinityAttr);
69+ op.getLoc (), resultType, adaptor.getSource ().front (),
70+ TypeAttr::get (targetType), flattenValues (adaptor.getTargetDims ()),
71+ resultSize, executionAffinityAttr);
6372
6473 // Await the fence, if needed. When not specified the resource is assumed to
6574 // be immediately available.
@@ -75,10 +84,11 @@ struct ConvertTensorImportOp
7584 }
7685
7786 auto unknownType = rewriter.getType <IREE::Stream::ResourceType>();
78- rewriter.replaceOpWithNewOp <IREE::Stream::AsyncTransferOp>(
79- op, unknownType, resource, resultSize, resultSize,
87+ Value newImport = rewriter.create <IREE::Stream::AsyncTransferOp>(
88+ op. getLoc () , unknownType, resource, resultSize, resultSize,
8089 /* source_affinity=*/ executionAffinityAttr,
8190 /* target_affinity=*/ executionAffinityAttr);
91+ rewriter.replaceOpWithMultiple (op, {{newImport, resultSize}});
8292 return success ();
8393 }
8494
@@ -125,7 +135,7 @@ struct ConvertTensorExportOp
125135 : public AffinityOpConversionPattern<IREE::HAL::TensorExportOp> {
126136 using AffinityOpConversionPattern::AffinityOpConversionPattern;
127137 LogicalResult matchAndRewriteOnAffinity (
128- IREE::HAL::TensorExportOp op, OpAdaptor adaptor,
138+ IREE::HAL::TensorExportOp op, OneToNOpAdaptor adaptor,
129139 IREE::Stream::AffinityAttr executionAffinityAttr,
130140 ConversionPatternRewriter &rewriter) const override {
131141 auto sourceType = op.getSourceEncoding ();
@@ -136,12 +146,12 @@ struct ConvertTensorExportOp
136146 }
137147
138148 auto source =
139- transferTensorOperand (op.getLoc (), op.getSource (), adaptor.getSource (),
140- executionAffinityAttr, rewriter);
149+ transferTensorOperands (op.getLoc (), op.getSource (), adaptor.getSource (),
150+ executionAffinityAttr, rewriter);
141151
142152 // Exporting a produced value - transfer our source value to an externally
143153 // usable resource and directly export it. This will cause an allocation.
144- auto exportSource = adaptor.getSource ();
154+ Value exportSource = adaptor.getSource (). front ();
145155 auto externalType = rewriter.getType <IREE::Stream::ResourceType>(
146156 IREE::Stream::Lifetime::External);
147157 if (source.resource .getType () != externalType) {
@@ -154,7 +164,8 @@ struct ConvertTensorExportOp
154164 // Export (stream resource to buffer view).
155165 rewriter.replaceOpWithNewOp <IREE::Stream::TensorExportOp>(
156166 op, targetType, exportSource, TypeAttr::get (sourceType),
157- adaptor.getSourceDims (), source.resourceSize , executionAffinityAttr);
167+ flattenValues (adaptor.getSourceDims ()), source.resourceSize ,
168+ executionAffinityAttr);
158169 return success ();
159170 }
160171};
@@ -174,19 +185,21 @@ struct ConvertTensorAliasOp
174185 : public AffinityOpConversionPattern<IREE::HAL::TensorAliasOp> {
175186 using AffinityOpConversionPattern::AffinityOpConversionPattern;
176187 LogicalResult matchAndRewriteOnAffinity (
177- IREE::HAL::TensorAliasOp op, OpAdaptor adaptor,
188+ IREE::HAL::TensorAliasOp op, OneToNOpAdaptor adaptor,
178189 IREE::Stream::AffinityAttr executionAffinityAttr,
179190 ConversionPatternRewriter &rewriter) const override {
180191 auto sourceType = op.getSource ().getType ();
181192 auto source =
182- transferTensorOperand (op.getLoc (), op.getSource (), adaptor.getSource (),
183- executionAffinityAttr, rewriter);
193+ transferTensorOperands (op.getLoc (), op.getSource (), adaptor.getSource (),
194+ executionAffinityAttr, rewriter);
184195
185196 // Query the target storage buffer length; we will only populate up to
186197 // what is required for the output.
198+ SmallVector<Value> convertedSourceDims =
199+ flattenValues (adaptor.getSourceDims ());
187200 Value storageSize = rewriter.create <IREE::Stream::TensorSizeOfOp>(
188201 op.getLoc (), rewriter.getIndexType (),
189- TypeAttr::get (op.getSource ().getType ()), adaptor. getSourceDims () ,
202+ TypeAttr::get (op.getSource ().getType ()), convertedSourceDims ,
190203 executionAffinityAttr);
191204
192205 // Import the target storage as a resource that we can use as an update
@@ -195,8 +208,8 @@ struct ConvertTensorAliasOp
195208 auto externalType = rewriter.getType <IREE::Stream::ResourceType>(
196209 IREE::Stream::Lifetime::External);
197210 auto importOp = rewriter.create <IREE::Stream::TensorImportOp>(
198- op.getLoc (), externalType, adaptor.getStorage (),
199- TypeAttr::get (sourceType), adaptor. getSourceDims () , storageSize,
211+ op.getLoc (), externalType, adaptor.getStorage (). front () ,
212+ TypeAttr::get (sourceType), convertedSourceDims , storageSize,
200213 executionAffinityAttr);
201214
202215 // Await the fence, if needed. When not specified the storage is assumed to
@@ -235,7 +248,7 @@ struct ConvertTensorAliasOp
235248 op.getLoc (), source.resource .getType (), result, source.resourceSize ,
236249 source.resourceSize , executionAffinityAttr, executionAffinityAttr);
237250 }
238- rewriter.replaceOp (op, result);
251+ rewriter.replaceOpWithMultiple (op, {{ result, source. resourceSize }} );
239252
240253 return success ();
241254 }
@@ -254,20 +267,22 @@ struct ConvertTensorBarrierOp
254267 : public AffinityAwareConversionPattern<IREE::HAL::TensorBarrierOp> {
255268 using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
256269 LogicalResult
257- matchAndRewrite (IREE::HAL::TensorBarrierOp op, OpAdaptor adaptor,
270+ matchAndRewrite (IREE::HAL::TensorBarrierOp op, OneToNOpAdaptor adaptor,
258271 ConversionPatternRewriter &rewriter) const override {
259272 auto timepointType = rewriter.getType <IREE::Stream::TimepointType>();
260273 IREE::Stream::AffinityAttr anyAffinityAttr;
261274 SmallVector<Value> signaledResources;
275+ SmallVector<Value> signaledResourceSizes;
262276 SmallVector<Value> signaledTimepoints;
263277 for (auto [sourceTensor, sourceResource] :
264278 llvm::zip_equal (op.getSources (), adaptor.getSources ())) {
265- auto source = resolveTensorOperand (op.getLoc (), sourceTensor,
266- sourceResource, rewriter);
279+ auto source = resolveTensorOperands (op.getLoc (), sourceTensor,
280+ sourceResource, rewriter);
267281 auto barrierOp = rewriter.create <IREE::Stream::TimepointBarrierOp>(
268- sourceResource.getLoc (), source.resource .getType (), timepointType ,
269- source.resource , source.resourceSize , source.affinity );
282+ sourceResource.front (). getLoc (), source.resource .getType (),
283+ timepointType, source.resource , source.resourceSize , source.affinity );
270284 signaledResources.push_back (barrierOp.getResult ());
285+ signaledResourceSizes.push_back (source.resourceSize );
271286 signaledTimepoints.push_back (barrierOp.getResultTimepoint ());
272287
273288 // When joining from multiple affinities we need to pick one to perform
@@ -283,7 +298,8 @@ struct ConvertTensorBarrierOp
283298 rewriter.create <IREE::Stream::TimepointChainExternalOp>(
284299 op.getLoc (), joinedTimepoint, ValueRange{adaptor.getSignalFence ()},
285300 anyAffinityAttr);
286- rewriter.replaceOp (op, signaledResources);
301+ replaceOpWithMultiple (op, signaledResources, signaledResourceSizes,
302+ rewriter);
287303 return success ();
288304 }
289305};
0 commit comments