@@ -90,6 +90,9 @@ struct DecomposeProjectedPermutation : public OpRewritePattern<GenericOp> {
9090std::pair<SmallVector<int64_t >, SmallVector<int64_t >>
9191computeTransposeBroadcast (AffineMap &map) {
9292 assert (map.isProjectedPermutation (false ) && " not a projection" );
93+
94+ // As the map is a projection it likely operates on a smaller set of
95+ // dimensions as far as the transpose is concerned (rest are broadcast).
9396 int64_t minorSize = map.getNumResults ();
9497
9598 SmallVector<int64_t > minorResult;
@@ -116,13 +119,13 @@ computeTransposeBroadcast(AffineMap &map) {
116119
117120 SmallVector<int64_t > permutation;
118121 if (hasTranspose) {
119- // / Consider an operand `x : tensor<7x8x9>` of a genericOp that has
120- // / affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`
121- // / `x`s access is both transposed and brodcast . But when specifying
122- // / the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
123- // / specified as `affine_map<(d0,d1,d2) -> (d1, d2, d0)` instead of
124- // / refering to d3, d4. Therefore, re-base the transpose dimensions so
125- // / that they start from d0.
122+ // Consider an operand `x : tensor<7x8x9>` of a genericOp that has
123+ // affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`
124+ // `x`s access is both transposed and broadcast . But when specifying
125+ // the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
126+ // specified as `affine_map<(d0,d1,d2) -> (d1, d2, d0)` instead of
127 // referring to d3, d4. Therefore, re-base the transpose dimensions so
128+ // that they start from d0.
126129 permutation.resize (minorSize);
127130 std::map<int64_t , int64_t > minorMap;
128131 for (int64_t i = 0 ; i < minorSize; ++i)
@@ -147,14 +150,19 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
147150 op.isSingleYieldOp () || !op.isAllParallelLoops ())
148151 return failure ();
149152
150- // All maps need to be projected permutations.
153+ // If the map of an operand is not a `projected permutation` then
154+ // it cannot be decomposed to mere transpose and broadcast.
155+ // The requirement that all maps be `projected permutation` may be
156+ // over-restrictive but since we need to determine shape of the
157+ // iteration space as well, reject if any map violates assumption.
151158 for (auto &opOperand : op->getOpOperands ()) {
152159 auto map = op.getMatchingIndexingMap (&opOperand);
153160 if (!map.isProjectedPermutation (false ))
154161 return failure ();
155162 }
156163
157- // Currently we handle only static shapes.
164+ // Decomposing linalg.generic involves creating `tensor.empty`
165 // which cannot have dynamic shapes.
158166 for (auto &operand : op->getOpOperands ()) {
159167 auto opType = cast<RankedTensorType>(operand.get ().getType ());
160168 for (auto size : opType.getShape ())
0 commit comments