@@ -566,45 +566,6 @@ static RankedTensorType getExpandedType(RankedTensorType originalType,
566
566
return RankedTensorType::get (expandedShape, originalType.getElementType ());
567
567
}
568
568
569
- // / Get the value to use for the output in the expanded operation given the
570
- // / `indexingMap` for the output in the original op. Creates an
571
- // / `linalg.init_tensor` operation to materialize the tensor that carries the
572
- // / shape information. This is only used when the tensor_reshape is expanding
573
- // / and is a consumer. In such cases, the tensor_reshape op semantics gaurantees
574
- // / that the shape of the output is computable from the shape of the input since
575
- // / at most one of the expanded dims can be dynamic.
576
- static Value getOutputValueForExpandedOp (OpBuilder &builder, Location loc,
577
- AffineMap indexingMap, Value result,
578
- const ExpansionInfo &expansionInfo) {
579
- SmallVector<Value, 4 > dynamicDims;
580
- SmallVector<int64_t , 4 > staticDims;
581
- ShapedType resultType = result.getType ().cast <ShapedType>();
582
- ArrayRef<int64_t > origShape = resultType.getShape ();
583
- for (AffineExpr expr : indexingMap.getResults ()) {
584
- unsigned origDimPos = expr.cast <AffineDimExpr>().getPosition ();
585
- bool foundDynamic = false ;
586
- int64_t linearizedShape = 1 ;
587
- for (int64_t extent : expansionInfo.getExpandedShapeOfDim (origDimPos)) {
588
- if (ShapedType::isDynamic (extent)) {
589
- assert (!foundDynamic &&
590
- " Expanded dimensions of reshape can have only one dynamic dim" );
591
- staticDims.push_back (ShapedType::kDynamicSize );
592
- foundDynamic = true ;
593
- continue ;
594
- }
595
- staticDims.push_back (extent);
596
- linearizedShape *= extent;
597
- }
598
- if (ShapedType::isDynamic (origShape[origDimPos])) {
599
- Value origDim = builder.create <DimOp>(loc, result, origDimPos);
600
- dynamicDims.push_back (builder.create <UnsignedDivIOp>(
601
- loc, origDim, builder.create <ConstantIndexOp>(loc, linearizedShape)));
602
- }
603
- }
604
- return builder.create <linalg::InitTensorOp>(loc, dynamicDims, staticDims,
605
- resultType.getElementType ());
606
- }
607
-
608
569
// / Returns the reassociation maps to use in the `linalg.tensor_reshape`
609
570
// / operation to convert the operands of the origial operation to operands of
610
571
// / the expanded operation. The same method is used to compute the
@@ -734,8 +695,16 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
734
695
SmallVector<Value, 1 > outputs;
735
696
for (auto result : llvm::enumerate (linalgOp.getOutputs ())) {
736
697
AffineMap indexingMap = linalgOp.getOutputIndexingMap (result.index ());
737
- outputs.push_back (getOutputValueForExpandedOp (
738
- rewriter, loc, indexingMap, result.value (), expansionInfo));
698
+ RankedTensorType expandedOutputType =
699
+ getExpandedType (result.value ().getType ().cast <RankedTensorType>(),
700
+ indexingMap, expansionInfo);
701
+ if (expandedOutputType != result.value ().getType ()) {
702
+ SmallVector<ReassociationIndices, 4 > reassociation =
703
+ getReassociationForExpansion (indexingMap, expansionInfo);
704
+ outputs.push_back (rewriter.create <TensorReshapeOp>(
705
+ linalgOp.getLoc (), expandedOutputType, result.value (),
706
+ reassociation));
707
+ }
739
708
}
740
709
741
710
// The iterator types of the expanded op are all parallel.
@@ -779,47 +748,6 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
779
748
return resultVals;
780
749
}
781
750
782
- static Value
783
- getOutputValueForLinearization (OpBuilder &builder, Location loc,
784
- Value origOutput,
785
- ArrayRef<AffineMap> reassociationMaps) {
786
- SmallVector<Value, 4 > dynamicDims;
787
- SmallVector<int64_t , 4 > staticDims;
788
- auto shapedType = origOutput.getType ().cast <ShapedType>();
789
- ArrayRef<int64_t > origShape = shapedType.getShape ();
790
- for (auto map : reassociationMaps) {
791
- Optional<Value> dynamicDim;
792
- int64_t staticLinearizedShape = 1 ;
793
- for (AffineDimExpr expr :
794
- llvm::map_range (map.getResults (), [](AffineExpr e) {
795
- return e.cast <AffineDimExpr>();
796
- })) {
797
- unsigned pos = expr.getPosition ();
798
- if (ShapedType::isDynamic (origShape[pos])) {
799
- Value dim = builder.create <DimOp>(loc, origOutput, pos);
800
- if (dynamicDim) {
801
- dynamicDim = builder.create <MulIOp>(loc, dynamicDim.getValue (), dim);
802
- } else {
803
- dynamicDim = dim;
804
- }
805
- } else {
806
- staticLinearizedShape *= origShape[pos];
807
- }
808
- }
809
- if (dynamicDim) {
810
- dynamicDim = builder.create <MulIOp>(
811
- loc, dynamicDim.getValue (),
812
- builder.create <ConstantIndexOp>(loc, staticLinearizedShape));
813
- dynamicDims.push_back (dynamicDim.getValue ());
814
- staticDims.push_back (ShapedType::kDynamicSize );
815
- } else {
816
- staticDims.push_back (staticLinearizedShape);
817
- }
818
- }
819
- return builder.create <InitTensorOp>(loc, dynamicDims, staticDims,
820
- shapedType.getElementType ());
821
- }
822
-
823
751
namespace {
824
752
825
753
// / Pattern to fold tensor_reshape op with its consumer by using the source of
@@ -973,7 +901,7 @@ struct FoldConsumerReshapeOpByLinearization
973
901
reshapeOp.getReassociationMaps ());
974
902
for (AffineExpr expr : modifiedMap.getResults ()) {
975
903
if (!expr.isPureAffine ())
976
- return reshapeOp .emitRemark (" fused op indexing map is not affine" );
904
+ return producer .emitRemark (" fused op indexing map is not affine" );
977
905
}
978
906
fusedIndexMaps.back () = modifiedMap;
979
907
@@ -983,9 +911,8 @@ struct FoldConsumerReshapeOpByLinearization
983
911
return reshapeOp.emitRemark (" fused op loop bound computation failed" );
984
912
985
913
Location loc = producer.getLoc ();
986
- Value output =
987
- getOutputValueForLinearization (rewriter, loc, producer.getOutputs ()[0 ],
988
- reshapeOp.getReassociationMaps ());
914
+ Value output = rewriter.create <TensorReshapeOp>(
915
+ loc, producer.getOutputs ()[0 ], reshapeOp.getReassociationExprs ());
989
916
LinalgOp fusedOp = createLinalgOpOfSameType (
990
917
producer, rewriter, loc, reshapeOp.getResultType (),
991
918
/* inputs=*/ producer.getInputs (),
0 commit comments