1313#include " iree/compiler/Dialect/Stream/IR/StreamOps.h"
1414#include " mlir/Dialect/Arith/IR/Arith.h"
1515#include " mlir/Dialect/Tensor/IR/Tensor.h"
16+ #include " mlir/IR/BuiltinDialect.h"
1617#include " mlir/IR/IRMapping.h"
1718#include " mlir/Interfaces/FunctionInterfaces.h"
1819
1920namespace mlir ::iree_compiler {
2021
2122namespace {
2223
// Flattens a list of value ranges into a single vector: the values of each
// ValueRange in `values` are appended in order, so the result is the
// concatenation of all ranges (and is empty when `values` is empty).
// Used in this file to turn the ranges returned by OneToNOpAdaptor accessors
// into the flat value lists passed when creating the Stream ops.
24+ static SmallVector<Value> flattenValues (ArrayRef<ValueRange> values) {
25+ SmallVector<Value> vec;
26+ for (auto v : values) {
// Append every value of this range, keeping its original order.
27+ vec.append (v.begin (), v.end ());
28+ }
29+ return vec;
30+ }
31+
2332// Inserts a sizeof calculation for the given tensor value type and dims.
2433// This should only be used to produce sizes for values produced by an op; the
2534// size of operands must be queried from the input resource.
@@ -142,6 +151,33 @@ struct ConvertTensorCastLikeOp
142151 }
143152};
144153
// Conversion pattern for cast-like tensor ops (`CastOpTy`) that uses the
// OneToNOpAdaptor form of matchAndRewrite. The matched op is replaced with an
// IREE::Stream::TensorCloneOp and the pattern always returns success().
// `CastOpTy` must provide getSource(), getSourceDims(), getResult() and
// getResultDims(); in this file it is instantiated for
// IREE::Flow::TensorReshapeOp and IREE::Flow::TensorBitCastOp.
154+ template <typename CastOpTy>
155+ struct ConvertOneToNTensorCastLikeOp
156+ : public AffinityAwareConversionPattern<CastOpTy> {
157+ using AffinityAwareConversionPattern<
158+ CastOpTy>::AffinityAwareConversionPattern;
159+ LogicalResult matchAndRewrite (
160+ CastOpTy op,
161+ typename OpConversionPattern<CastOpTy>::OneToNOpAdaptor adaptor,
162+ ConversionPatternRewriter &rewriter) const override {
// Affinity looked up for the op result; it is reused for the operand
// transfer, the result size computation and the replacement op.
163+ auto resultAffinityAttr = this ->lookupResultAffinity (op.getResult ());
// Reduce the converted source operand to a single Value. The helper is
// defined outside this view; presumably it picks the stream resource out of
// the converted range - TODO confirm.
164+ Value convertedSource =
165+ getStreamResourceFromOneToNOpOperandAdaptor (adaptor.getSource ());
166+ auto source = this ->transferTensorOperand (op.getLoc (), op.getSource (),
167+ convertedSource,
168+ resultAffinityAttr, rewriter);
// The result size is built from the original op's result value and dims.
169+ auto resultSize =
170+ buildResultSizeOf (op.getLoc (), op.getResult (), op.getResultDims (),
171+ resultAffinityAttr, rewriter);
// Resource type created without any type parameters.
172+ auto unknownType = rewriter.getType <IREE::Stream::ResourceType>();
// Source dims are taken from the original op, while result dims come from
// the adaptor and are flattened into a single value list.
173+ rewriter.replaceOpWithNewOp <IREE::Stream::TensorCloneOp>(
174+ op, unknownType, source.resource , op.getSource ().getType (),
175+ op.getSourceDims (), source.resourceSize , op.getResult ().getType (),
176+ flattenValues (adaptor.getResultDims ()), resultSize, resultAffinityAttr);
177+ return success ();
178+ }
179+ };
180+
145181struct ConvertTensorAllocaOp
146182 : public AffinityOpConversionPattern<IREE::Flow::TensorAllocaOp> {
147183 using AffinityOpConversionPattern::AffinityOpConversionPattern;
@@ -237,46 +273,55 @@ struct ConvertTensorTransferOp
237273};
238274
// Converts IREE::Flow::TensorSliceOp into IREE::Stream::TensorSliceOp on the
// execution affinity passed to matchAndRewriteOnAffinity; always returns
// success().
// Diff view: lines marked '-' are the previous version (OpAdaptor based) and
// lines marked '+' are the replacement (OneToNOpAdaptor based).
239275struct ConvertTensorSliceOp
240- : public AffinityOpConversionPattern <IREE::Flow::TensorSliceOp> {
241- using AffinityOpConversionPattern::AffinityOpConversionPattern ;
276+ : public AffinityOneToNOpConversionPattern <IREE::Flow::TensorSliceOp> {
277+ using AffinityOneToNOpConversionPattern::AffinityOneToNOpConversionPattern ;
242278 LogicalResult matchAndRewriteOnAffinity (
243- IREE::Flow::TensorSliceOp op, OpAdaptor adaptor,
279+ IREE::Flow::TensorSliceOp op, OneToNOpAdaptor adaptor,
244280 IREE::Stream::AffinityAttr executionAffinityAttr,
245281 ConversionPatternRewriter &rewriter) const override {
// New: reduce the converted source operand to a single Value before it is
// handed to transferTensorOperand. The helper is defined outside this view;
// presumably it extracts the stream resource - TODO confirm.
282+ Value convertedSource =
283+ getStreamResourceFromOneToNOpOperandAdaptor (adaptor.getSource ());
246284 auto source =
247- transferTensorOperand (op.getLoc (), op.getSource (), adaptor. getSource () ,
285+ transferTensorOperand (op.getLoc (), op.getSource (), convertedSource ,
248286 executionAffinityAttr, rewriter);
// The result size is built from the original op's result value and dims.
249287 auto resultSize =
250288 buildResultSizeOf (op.getLoc (), op.getResult (), op.getResultDims (),
251289 executionAffinityAttr, rewriter);
// Resource type created without any type parameters.
252290 auto unknownType = rewriter.getType <IREE::Stream::ResourceType>();
// Source dims come from the original op in both versions. In the new version
// the start indices, lengths and result dims from the adaptor are passed
// through flattenValues instead of being forwarded directly.
253291 rewriter.replaceOpWithNewOp <IREE::Stream::TensorSliceOp>(
254292 op, unknownType, source.resource , op.getSource ().getType (),
255- op.getSourceDims (), source.resourceSize , adaptor.getStartIndices (),
256- adaptor.getLengths (), op.getResult ().getType (), adaptor.getResultDims (),
257- resultSize, executionAffinityAttr);
293+ op.getSourceDims (), source.resourceSize ,
294+ flattenValues (adaptor.getStartIndices ()),
295+ flattenValues (adaptor.getLengths ()), op.getResult ().getType (),
296+ flattenValues (adaptor.getResultDims ()), resultSize,
297+ executionAffinityAttr);
258298 return success ();
259299 }
260300};
261301
// Converts IREE::Flow::TensorUpdateOp into IREE::Stream::TensorUpdateOp on the
// execution affinity passed to matchAndRewriteOnAffinity; always returns
// success().
// Diff view: lines marked '-' are the previous version (OpAdaptor based) and
// lines marked '+' are the replacement (OneToNOpAdaptor based).
262302struct ConvertTensorUpdateOp
263- : public AffinityOpConversionPattern <IREE::Flow::TensorUpdateOp> {
264- using AffinityOpConversionPattern::AffinityOpConversionPattern ;
303+ : public AffinityOneToNOpConversionPattern <IREE::Flow::TensorUpdateOp> {
304+ using AffinityOneToNOpConversionPattern::AffinityOneToNOpConversionPattern ;
265305 LogicalResult matchAndRewriteOnAffinity (
266- IREE::Flow::TensorUpdateOp op, OpAdaptor adaptor,
306+ IREE::Flow::TensorUpdateOp op, OneToNOpAdaptor adaptor,
267307 IREE::Stream::AffinityAttr executionAffinityAttr,
268308 ConversionPatternRewriter &rewriter) const override {
// New: reduce the converted target operand to a single Value before it is
// handed to transferTensorOperand. The helper is defined outside this view;
// presumably it extracts the stream resource - TODO confirm.
309+ Value convertedTarget =
310+ getStreamResourceFromOneToNOpOperandAdaptor (adaptor.getTarget ());
269311 auto target =
270- transferTensorOperand (op.getLoc (), op.getTarget (), adaptor. getTarget () ,
312+ transferTensorOperand (op.getLoc (), op.getTarget (), convertedTarget ,
271313 executionAffinityAttr, rewriter);
// New: the update operand gets the same treatment as the target operand.
314+ Value convertedUpdate =
315+ getStreamResourceFromOneToNOpOperandAdaptor (adaptor.getUpdate ());
272316 auto update =
273- transferTensorOperand (op.getLoc (), op.getUpdate (), adaptor. getUpdate () ,
317+ transferTensorOperand (op.getLoc (), op.getUpdate (), convertedUpdate ,
274318 executionAffinityAttr, rewriter);
// The result type is the type of the transferred target resource. Target dims
// and start indices come from the adaptor (flattened in the new version),
// while update dims are taken from the original op in both versions.
275319 rewriter.replaceOpWithNewOp <IREE::Stream::TensorUpdateOp>(
276320 op, target.resource .getType (), target.resource ,
277- op.getTarget ().getType (), adaptor.getTargetDims (), target.resourceSize ,
278- adaptor.getStartIndices (), update.resource , op.getUpdate ().getType (),
279- op.getUpdateDims (), update.resourceSize , executionAffinityAttr);
321+ op.getTarget ().getType (), flattenValues (adaptor.getTargetDims ()),
322+ target.resourceSize , flattenValues (adaptor.getStartIndices ()),
323+ update.resource , op.getUpdate ().getType (), op.getUpdateDims (),
324+ update.resourceSize , executionAffinityAttr);
280325 return success ();
281326 }
282327};
@@ -296,10 +341,12 @@ struct ConvertTensorLoadOp
296341 : public AffinityAwareConversionPattern<IREE::Flow::TensorLoadOp> {
297342 using AffinityAwareConversionPattern::AffinityAwareConversionPattern;
298343 LogicalResult
299- matchAndRewrite (IREE::Flow::TensorLoadOp op, OpAdaptor adaptor,
344+ matchAndRewrite (IREE::Flow::TensorLoadOp op, OneToNOpAdaptor adaptor,
300345 ConversionPatternRewriter &rewriter) const override {
346+ Value convertedSource =
347+ getStreamResourceFromOneToNOpOperandAdaptor (adaptor.getSource ());
301348 auto source = resolveTensorOperand (op.getLoc (), op.getSource (),
302- adaptor. getSource () , rewriter);
349+ convertedSource , rewriter);
303350
304351 // If the source is not a staging resource then we need to transfer it to
305352 // a staging resource. We slice out just what is being loaded so that we
@@ -311,10 +358,13 @@ struct ConvertTensorLoadOp
311358 auto stagingType = rewriter.getType <IREE::Stream::ResourceType>(
312359 IREE::Stream::Lifetime::Staging);
313360 auto resultType = getTypeConverter ()->convertType (op.getResult ().getType ());
361+ SmallVector<Value> convertedSourceDims =
362+ flattenValues (adaptor.getSourceDims ());
363+ SmallVector<Value> convertedIndices = flattenValues (adaptor.getIndices ());
314364 if (source.resource .getType () == stagingType) {
315365 rewriter.replaceOpWithNewOp <IREE::Stream::TensorLoadOp>(
316366 op, resultType, source.resource , op.getSource ().getType (),
317- adaptor. getSourceDims () , source.resourceSize , adaptor. getIndices () );
367+ convertedSourceDims , source.resourceSize , convertedIndices );
318368 return success ();
319369 }
320370
@@ -328,19 +378,18 @@ struct ConvertTensorLoadOp
328378 /* result_affinity=*/ source.affinity );
329379 rewriter.replaceOpWithNewOp <IREE::Stream::TensorLoadOp>(
330380 op, resultType, transferOp.getResult (), sourceEncoding,
331- adaptor.getSourceDims (), transferOp.getResultSize (),
332- adaptor.getIndices ());
381+ convertedSourceDims, transferOp.getResultSize (), convertedIndices);
333382 return success ();
334383 }
335384
336385 // Slice out the individual element value.
337386 IndexSet indexSet (op.getLoc (), rewriter);
338- indexSet.populate (adaptor. getIndices () );
387+ indexSet.populate (convertedIndices );
339388 SmallVector<Value> sliceIndices;
340389 SmallVector<Value> sliceLengths;
341390 SmallVector<Value> loadIndices;
342391 SmallVector<int64_t > resultDims;
343- for (auto index : adaptor. getIndices () ) {
392+ for (auto index : convertedIndices ) {
344393 // TODO(benvanik): support larger buffer slices.
345394 sliceIndices.push_back (index);
346395 sliceLengths.push_back (indexSet.get (1 ));
@@ -354,9 +403,8 @@ struct ConvertTensorLoadOp
354403 op.getLoc (), resultEncoding, ValueRange{}, source.affinity );
355404 auto sliceOp = rewriter.create <IREE::Stream::TensorSliceOp>(
356405 op.getLoc (), source.resource .getType (), source.resource , sourceEncoding,
357- adaptor.getSourceDims (), source.resourceSize , sliceIndices,
358- sliceLengths, resultEncoding, ValueRange{}, resultSize,
359- source.affinity );
406+ convertedSourceDims, source.resourceSize , sliceIndices, sliceLengths,
407+ resultEncoding, ValueRange{}, resultSize, source.affinity );
360408 auto transferOp = rewriter.create <IREE::Stream::AsyncTransferOp>(
361409 op.getLoc (), stagingType, sliceOp.getResult (), sliceOp.getResultSize (),
362410 sliceOp.getResultSize (),
@@ -713,10 +761,10 @@ struct ConvertCollectiveSendRecvOp
713761};
714762
715763struct ConvertDispatchOp
716- : public AffinityOpConversionPattern <IREE::Flow::DispatchOp> {
717- using AffinityOpConversionPattern::AffinityOpConversionPattern ;
764+ : public AffinityOneToNOpConversionPattern <IREE::Flow::DispatchOp> {
765+ using AffinityOneToNOpConversionPattern::AffinityOneToNOpConversionPattern ;
718766 LogicalResult matchAndRewriteOnAffinity (
719- IREE::Flow::DispatchOp op, OpAdaptor adaptor,
767+ IREE::Flow::DispatchOp op, OneToNOpAdaptor adaptor,
720768 IREE::Stream::AffinityAttr executionAffinityAttr,
721769 ConversionPatternRewriter &rewriter) const override {
722770 // Zero is going to be used for each operand to start.
@@ -729,8 +777,11 @@ struct ConvertDispatchOp
729777 SmallVector<Value> dispatchOperandEnds;
730778 SmallVector<Value> dispatchOperandLengths;
731779 SmallVector<Value> operandSizes;
780+
781+ SmallVector<Value> convertedArguments =
782+ getStreamResourcesFromOneToNOpOperandAdaptors (adaptor.getArguments ());
732783 for (auto [oldOperand, newOperand] :
733- llvm::zip_equal (op.getArguments (), adaptor. getArguments () )) {
784+ llvm::zip_equal (op.getArguments (), convertedArguments )) {
734785 if (llvm::isa<ShapedType>(oldOperand.getType ())) {
735786 auto newOperandCast =
736787 transferTensorOperand (op.getLoc (), oldOperand, newOperand,
@@ -774,10 +825,10 @@ struct ConvertDispatchOp
774825 }
775826
776827 auto newOp = rewriter.replaceOpWithNewOp <IREE::Stream::AsyncDispatchOp>(
777- op, resultTypes, adaptor.getWorkload (), adaptor. getEntryPointsAttr ( ),
778- dispatchOperands, dispatchOperandSizes, dispatchOperandOffsets ,
779- dispatchOperandEnds, dispatchOperandLengths, resultSizes ,
780- adaptor.getTiedOperandsAttr (), executionAffinityAttr);
828+ op, resultTypes, flattenValues ( adaptor.getWorkload ()),
829+ adaptor. getEntryPointsAttr (), dispatchOperands, dispatchOperandSizes ,
830+ dispatchOperandOffsets, dispatchOperandEnds, dispatchOperandLengths ,
831+ resultSizes, adaptor.getTiedOperandsAttr (), executionAffinityAttr);
781832 newOp->setDialectAttrs (op->getDialectAttrs ());
782833 return success ();
783834 }
@@ -1105,8 +1156,8 @@ void populateFlowToStreamConversionPatterns(
11051156 RewritePatternSet &patterns) {
11061157 patterns
11071158 .insert <ConvertTensorConstantOp, ConvertTensorDynamicConstantOp,
1108- ConvertTensorCastLikeOp <IREE::Flow::TensorReshapeOp>,
1109- ConvertTensorCastLikeOp <IREE::Flow::TensorBitCastOp>,
1159+ ConvertOneToNTensorCastLikeOp <IREE::Flow::TensorReshapeOp>,
1160+ ConvertOneToNTensorCastLikeOp <IREE::Flow::TensorBitCastOp>,
11101161 ConvertTensorAllocaOp, ConvertTensorEmptyOp, ConvertTensorSplatOp,
11111162 ConvertTensorCloneOp, ConvertTensorTransferOp,
11121163 ConvertTensorSliceOp, ConvertTensorUpdateOp, ConvertTensorLoadOp,
0 commit comments