@@ -1143,15 +1143,18 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11431143 ArrayRef<int64_t > innerDimsPos = packOp.getInnerDimsPos ();
11441144 auto outerDimsPerm = packOp.getOuterDimsPerm ();
11451145
1146- // Verify that there are no non-unit un-tiled outer dims that are permuted.
1147- // Supporting such cases will require refining the logic to generate the
1148- // Transpose Op.
1146+ // Verify that there are no:
1147+ // * non-unit + un-tiled-outer-dims,
1148+ // that are permuted. Supporting such cases would require refining the logic
1149+ // that generates the Transpose Op.
11491150 if (!llvm::all_of (outerDimsPerm, [&innerDimsPos, &packOp](int64_t dim) {
11501151 static int prev = 0 ;
1151- // Tiled dims are not relevant here .
1152+ // Skip tiled dims - these can be permuted .
11521153 if (llvm::is_contained (innerDimsPos, dim))
11531154 return true ;
1154- // Was this dim permuted? Note, permuting unit dims is fine.
1155+
1156+ // Check whether this dim has been permuted. Permuting unit dims is fine
1157+ // as that's effectively a no-op.
11551158 if (dim < prev && (packOp.getType ().getShape ()[prev] != 1 ||
11561159 packOp.getType ().getShape ()[dim] != 1 ))
11571160 return false ;
@@ -1182,8 +1185,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11821185 // Assumptions made:
11831186 // - All tiled outer dims are 1 - the corresponding transposition order
11841187 // doesn't matter, but requires all dim indices to be present.
1185- // - Un-tiled outer dims remain un-permuted. (TODO: Fail when this does not
1186- // hold)
1188+ // - Un-tiled outer dims remain un-permuted.
11871189
11881190 // 2.1 Get the permutation for linalg.transpose:
11891191 // [ untiled-dims, inner-dims-pos ]
@@ -1240,16 +1242,15 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
12401242 auto transposedOp = linalg::TransposeOp::create (rewriter, loc, input, empty,
12411243 srcPermForTranspose);
12421244
1243- // 3. Insert the inner tile to the destination:
1245+ // 3. Insert the inner tile into the destination tensor :
12441246 // %inserted_tile = tensor.insert_slice(%transposed_tile)
12451247
12461248 // Compute the sizes attribute:
12471249 // [ outer-dims, tile-sizes ]
12481250 // Note that the output from the transpose Op excludes the tiled outer dims.
1249- // Given the assumptions (all tiled outer dims == 1), we can safely use a
1250- // rank-expanding tensor.insert_slice. Rather than manually computing where to
1251- // insert new unit dims (resulting from the expansion), use the Pack op
1252- // attributes.
1251+ // However, given the assumption that:
1252+ // * all tiled outer dims == 1,
1253+ // we can just use a rank-expanding tensor.insert_slice.
12531254 SmallVector<OpFoldResult> writeSizes;
12541255 for (auto size : packOp.getAllOuterDims ()) {
12551256 writeSizes.push_back (rewriter.getIndexAttr (size));
@@ -1261,10 +1262,11 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
12611262 writeSizes.push_back (tileSizeOfr);
12621263 }
12631264
1265+ // TODO: Add a constructor for tensor.insert_slice that doesn't require
1266+ // strides nor offsets.
12641267 SmallVector<OpFoldResult> writeStrides (destRank, oneIdxAttr);
12651268 SmallVector<OpFoldResult> writeOffsets (destRank, zeroIdxAttr);
12661269
1267- // TODO: A constructor that doesn't require strides nor offsets.
12681270 auto insert = tensor::InsertSliceOp::create (
12691271 rewriter, loc, transposedOp.getResult ()[0 ], packOp.getDest (),
12701272 writeOffsets, writeSizes, writeStrides);
0 commit comments