@@ -811,24 +811,12 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
811811}
812812
813813// Create an expanded transpose op.
814- // For sinking a collapse : transpose(collapse_shape),
815- // all expanded groups are permuted together. We just permute the reassocation
816- // map of the collapse and flatten it. For example,
817- //
818- // reassociation_map = [[0], [1, 2, 3], [4, 5]]
819- // permutation = [2, 0, 1]
820- //
821- // Becomes
822- //
823- // permutation = [4, 5, 0 , 1, 2, 3]
824- //
825- // For bubbling an expand : expand_shape(transpose),
826814// the reassociation map is already permuted hence we inverse permute and then
827815// flatten it. Then we inverse permute it again to get the final expanded
828816// transpose permutation. For example,
829817//
830818// permutation = [2, 0, 1]
831- // reassociation_map = [[0, 1], [2], [3, 4, 5]]
819+ // reassociation_map for expansion = [[0, 1], [2], [3, 4, 5]]
832820//
833821// inverse permutation = [1, 2, 0]
834822// applied to reassocation_map and then flattened becomes
@@ -839,25 +827,19 @@ validateDynamicDimExpansion(LinalgOp linalgOp,
839827//
840828// permutation=[4, 5, 0, 1, 2, 3]
841829
842- static Operation *
843- createExpandedTransposeOp (PatternRewriter &rewriter, TransposeOp transposeOp,
844- SmallVector<ReassociationIndices> reassociation,
845- Value expandedInput, Value output, bool isExpanding) {
846- ArrayRef<int64_t > permutation =
847- isExpanding ? invertPermutationVector (transposeOp.getPermutation ())
848- : transposeOp.getPermutation ();
849- applyPermutationToVector (reassociation, permutation);
830+ static Operation *createExpandedTransposeOp (PatternRewriter &rewriter,
831+ TransposeOp transposeOp,
832+ Value expandedInput, Value output,
833+ ExpansionInfo &expansionInfo) {
850834 SmallVector<int64_t > newPerm;
851- for (const auto &reassoc : reassociation) {
852- for (auto dim : reassoc) {
835+ for (int64_t perm : invertPermutationVector (transposeOp.getPermutation ())) {
836+ auto reassoc = expansionInfo.getExpandedDims (perm);
837+ for (int64_t dim : reassoc) {
853838 newPerm.push_back (dim);
854839 }
855840 }
856- if (isExpanding) {
857- newPerm = invertPermutationVector (newPerm);
858- }
859841 return rewriter.create <TransposeOp>(transposeOp.getLoc (), expandedInput,
860- output, newPerm);
842+ output, invertPermutationVector ( newPerm) );
861843}
862844
863845// Create an expanded generic op.
@@ -891,17 +873,18 @@ static Operation *createExpandedGenericOp(
891873// Create an expanded fused op that retains the name for certain ops
892874// such as fill, copy and transpose and produce a generic op for
893875// rest of linalg ops.
894- static Operation *createExpandedOp (
895- PatternRewriter &rewriter, LinalgOp linalgOp, TypeRange resultTypes,
896- ArrayRef<Value> expandedOpOperands, ArrayRef<Value> outputs,
897- ArrayRef<AffineMap> expandedOpIndexingMaps, ExpansionInfo &expansionInfo,
898- SmallVector<ReassociationIndices> reassociation, bool isExpanding) {
876+ static Operation *createExpandedOp (PatternRewriter &rewriter, LinalgOp linalgOp,
877+ TypeRange resultTypes,
878+ ArrayRef<Value> expandedOpOperands,
879+ ArrayRef<Value> outputs,
880+ ArrayRef<AffineMap> expandedOpIndexingMaps,
881+ ExpansionInfo &expansionInfo) {
899882
900883 return TypeSwitch<Operation *, Operation *>(linalgOp.getOperation ())
901884 .Case <TransposeOp>([&](TransposeOp transposeOp) {
902- return createExpandedTransposeOp (rewriter, transposeOp, reassociation,
885+ return createExpandedTransposeOp (rewriter, transposeOp,
903886 expandedOpOperands[0 ], outputs[0 ],
904- isExpanding );
887+ expansionInfo );
905888 })
906889 .Case <FillOp, CopyOp>([&](Operation *op) {
907890 return clone (rewriter, linalgOp, resultTypes,
@@ -1021,13 +1004,9 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
10211004 }
10221005
10231006 TypeRange resultTypes = ValueRange (outputs).getTypes ();
1024- SmallVector<ReassociationIndices> reassociationBeforeExpansion =
1025- isExpanding ? expandingReshapeOp.getReassociationIndices ()
1026- : collapsingReshapeOp.getReassociationIndices ();
10271007 Operation *fusedOp =
10281008 createExpandedOp (rewriter, linalgOp, resultTypes, expandedOpOperands,
1029- outputs, expandedOpIndexingMaps, expansionInfo,
1030- reassociationBeforeExpansion, isExpanding);
1009+ outputs, expandedOpIndexingMaps, expansionInfo);
10311010 // Reshape the result values to their original shape if this is a collapsing
10321011 // reshape folded into its consumer.
10331012 SmallVector<Value> resultVals;
0 commit comments