
Commit b19b649

Address comments

- Rename `collapseTo` to better reflect its usage.
- Assert it only collapses unit dimensions.
- Rename the ReassociationIndices-using variables to `reassocGroup` and `reassocMaps`, the same terminology used in the tensor.collapse_shape documentation.
- Use a more representative test, with comments, to better explain what the patch does.
1 parent: ce2ef6a

2 files changed (+87, -54 lines)

mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp

Lines changed: 18 additions & 16 deletions
```diff
@@ -236,37 +236,39 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
   return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
 }
 
-/// Create tensor.collapse_shape to drop dimensions in `dropDims` in tensor
+/// Create tensor.collapse_shape to drop unit dimensions in `dropDims` in tensor
 /// `from`.
 static tensor::CollapseShapeOp
-collapseTo(OpBuilder &b, Location loc, Value from,
-           const llvm::SmallBitVector &dropDims) {
+dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
+                  const llvm::SmallBitVector &dropDims) {
   auto fromType = cast<ShapedType>(from.getType());
-  assert(fromType.getRank() == dropDims.size() &&
+  assert(fromType.getRank() == static_cast<int64_t>(dropDims.size()) &&
          "dropDims dimension does not match from tensor rank");
   // Computed reassociation map for the corresponding tensor.collapse_shape.
-  SmallVector<ReassociationIndices, 2> reassocIdxsVec;
-  // Current reassociation indices to add dropped dimension to.
-  ReassociationIndices reassocIdxs;
+  SmallVector<ReassociationIndices, 2> reassocMaps;
+  // Current reassociation group to add dropped dimension to.
+  ReassociationIndices reassocGroup;
 
   bool foundKeptDim = false;
   // Dropped dimensions might be at the beginning or end of the shape so
   // combine all contiguous dimensions before and after a given non dropped
-  // dimension in reassocIdxs until another non dropped dimension is found.
+  // dimension in reassocGroup until another non dropped dimension is found.
   // When that happens, add the reassociation indices to the map.
   for (int dim = 0; dim < fromType.getRank(); dim++) {
-    if (!dropDims.test(dim)) {
+    if (dropDims.test(dim))
+      assert(fromType.getShape()[dim] == 1 && "Dropping non unit dimension");
+    else {
       if (foundKeptDim) {
-        reassocIdxsVec.push_back(reassocIdxs);
-        reassocIdxs.clear();
+        reassocMaps.push_back(reassocGroup);
+        reassocGroup.clear();
       }
       foundKeptDim = true;
     }
-    reassocIdxs.push_back(dim);
+    reassocGroup.push_back(dim);
   }
-  if (!reassocIdxs.empty())
-    reassocIdxsVec.push_back(reassocIdxs);
-  return b.create<tensor::CollapseShapeOp>(loc, from, reassocIdxsVec);
+  if (!reassocGroup.empty())
+    reassocMaps.push_back(reassocGroup);
+  return b.create<tensor::CollapseShapeOp>(loc, from, reassocMaps);
 }
 
 FailureOr<FusionInfo>
@@ -312,7 +314,7 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
   // Rank-reduction occured as part of the extract_slice.
   if (cast<ShapedType>(consumerType).getRank() !=
       cast<ShapedType>(def.getType()).getRank())
-    def = collapseTo(b, fusedProducer.getLoc(), def, droppedDims);
+    def = dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
   // Canonicalizations are not guaranteed to have happened before constructing
   // `fusedProducer`. In the tensor case this can result in temporary type
   // mismatches. Insert a `tensor.cast` op to propagate the transformation
```
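
For intuition about the grouping loop, here is a minimal standalone sketch, assuming a hypothetical `computeReassociation` helper with plain std containers in place of `llvm::SmallBitVector` and `mlir::ReassociationIndices`: every dropped unit dimension is folded into the reassociation group of the nearest kept dimension, so the collapse only ever removes unit extents.

```cpp
// Minimal sketch (hypothetical helper, std containers instead of
// llvm::SmallBitVector / mlir::ReassociationIndices) of how the reassociation
// groups for tensor.collapse_shape are formed.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

using ReassocGroup = std::vector<int64_t>;

std::vector<ReassocGroup>
computeReassociation(const std::vector<int64_t> &shape,
                     const std::vector<bool> &dropDims) {
  assert(shape.size() == dropDims.size() && "dropDims must match tensor rank");
  std::vector<ReassocGroup> reassocMaps;
  ReassocGroup reassocGroup;
  bool foundKeptDim = false;
  for (int64_t dim = 0, rank = shape.size(); dim < rank; dim++) {
    if (dropDims[dim])
      assert(shape[dim] == 1 && "Dropping non unit dimension");
    else {
      // Each kept dimension after the first one starts a new group; dropped
      // dimensions seen so far stay in the group being built.
      if (foundKeptDim) {
        reassocMaps.push_back(reassocGroup);
        reassocGroup.clear();
      }
      foundKeptDim = true;
    }
    reassocGroup.push_back(dim);
  }
  // Trailing dropped dimensions fold into the last group.
  if (!reassocGroup.empty())
    reassocMaps.push_back(reassocGroup);
  return reassocMaps;
}

int main() {
  // tensor<1x6x2xf32> with the leading unit dim dropped collapses with
  // reassociation [[0, 1], [2]], i.e. to tensor<6x2xf32>.
  for (const ReassocGroup &group :
       computeReassociation({1, 6, 2}, {true, false, false})) {
    for (int64_t d : group)
      std::cout << d << ' ';
    std::cout << '\n';
  }
}
```

Running it prints the groups `[0, 1]` and `[2]`, the same reassociation the updated test below expects from the generated `tensor.collapse_shape`.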

mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir

Lines changed: 69 additions & 38 deletions
```diff
@@ -321,47 +321,78 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
 
 // -----
 
-func.func @rank_reduced_extract_slice(%cond : i1) -> tensor<6x2xf32> {
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @rank_reduced_extract_slice(%arg0: tensor<1x6x5xf32>, %arg1: tensor<1x5x6xf32>, %arg2: tensor<4x6xf32>) -> tensor<4x6xf32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c6 = arith.constant 6 : index
   %cst = arith.constant 0.0 : f32
-  %cst1 = arith.constant 1.0 : f32
-
-  %empty1 = tensor.empty() : tensor<6x6x1x1x1x1xf32>
-  %init1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} outs(%empty1 : tensor<6x6x1x1x1x1xf32>) {
-  ^bb0(%out: f32):
-    linalg.yield %cst : f32
-  } -> tensor<6x6x1x1x1x1xf32>
-
-  %if = scf.if %cond -> tensor<6x2xf32> {
-    %extract0 = tensor.extract_slice %init1[0, 0, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
-
-    %init2 = tensor.empty() : tensor<6x2xf32>
-    %add1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extract0 : tensor<6x2xf32>) outs(%init2 : tensor<6x2xf32>) {
-    ^bb0(%in: f32, %out: f32):
-      %add = arith.addf %in, %cst1 : f32
-      linalg.yield %add : f32
-    } -> tensor<6x2xf32>
-    scf.yield %add1 : tensor<6x2xf32>
-  } else {
-    %extract2 = tensor.extract_slice %init1[0, 2, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
-    scf.yield %extract2 : tensor<6x2xf32>
+  %init1 = tensor.empty() : tensor<1x6x6xf32>
+  %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<1x6x6xf32>) -> tensor<1x6x6xf32>
+  %0 = linalg.generic
+    {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+    ins(%arg0, %arg1 : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%fill1 : tensor<1x6x6xf32>) {
+  ^bb0(%in: f32, %in_1: f32, %out: f32):
+    %10 = arith.mulf %in, %in_1 : f32
+    %11 = arith.addf %out, %10 : f32
+    linalg.yield %11 : f32
+  } -> tensor<1x6x6xf32>
+  %init2 = tensor.empty() : tensor<4x6xf32>
+  %1 = scf.for %arg4 = %c0 to %c6 step %c2 iter_args(%arg3 = %init2) -> (tensor<4x6xf32>) {
+    %2 = tensor.extract_slice %0[0, 0, %arg4] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
+    %init3 = tensor.empty() : tensor<4x2xf32>
+    %fill3 = linalg.fill ins(%cst : f32) outs(%init3 : tensor<4x2xf32>) -> tensor<4x2xf32>
+    %3 = linalg.generic
+      {indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"]}
+      ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%fill3 : tensor<4x2xf32>) {
+    ^bb0(%in: f32, %in_1: f32, %out: f32):
+      %20 = arith.mulf %in, %in_1 : f32
+      %21 = arith.addf %out, %20 : f32
+      linalg.yield %21 : f32
+    } -> tensor<4x2xf32>
+    %4 = tensor.insert_slice %3 into %arg3[0, %arg4] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+    scf.yield %4 : tensor<4x6xf32>
   }
-
-  return %if : tensor<6x2xf32>
+  return %1 : tensor<4x6xf32>
 }
 
 // CHECK: func @rank_reduced_extract_slice(
-// CHECK-SAME: %[[COND:[0-9a-z]*]]: i1
-
-// CHECK: %[[EMPTY_PROD:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
-// CHECK: %[[FILL_PROD:.*]] = linalg.generic
-// CHECK-SAME: outs(%[[EMPTY_PROD]] : tensor<6x6x1x1x1x1xf32>)
+// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: tensor<1x6x5xf32>
+// CHECK-SAME: %[[ARG1:[0-9a-z]*]]: tensor<1x5x6xf32>
+// CHECK-SAME: %[[ARG2:[0-9a-z]*]]: tensor<4x6xf32>
 
-// CHECK: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<6x2xf32>
-// CHECK: %[[EXTRACT_SLICE_CONS:.*]] = tensor.extract_slice %[[EMPTY_PROD]][0, 0, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
-
-// CHECK: %[[FILL_CONS:.*]] = linalg.generic
-// CHECK-SAME: outs(%[[EXTRACT_SLICE_CONS]] : tensor<6x2x1x1x1x1xf32>)
-// CHECK: %[[CONS_COLLAPSE:.*]] = tensor.collapse_shape %[[FILL_CONS]] {{\[\[0\], \[1, 2, 3, 4, 5\]\]}} : tensor<6x2x1x1x1x1xf32> into tensor<6x2xf32>
-// CHECK: %[[ADD1_CONS:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[CONS_COLLAPSE]] : tensor<6x2xf32>)
-// CHECK-SAME: outs(%[[EMPTY_CONS]] : tensor<6x2xf32>)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EMPTY_PROD:.*]] = tensor.empty() : tensor<1x6x6xf32>
+// CHECK-NEXT: %[[FILL_PROD:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY_PROD]] : tensor<1x6x6xf32>) -> tensor<1x6x6xf32>
+// CHECK-NEXT: %[[EMPTY_FOR:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<4x2xf32>
+// CHECK-NEXT: %[[FILL_CONS:.*]] = linalg.fill ins(%[[CST]] : f32)
+
+// For loop right after tensor alloc & fill, no linalg.generic.
+// CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[EMPTY_FOR]])
+
+// Producer linalg.generic now inside the loop, with tiled args sliced before
+// it.
+// CHECK-DAG: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[I]]] [1, 5, 2] [1, 1, 1] : tensor<1x5x6xf32> to tensor<1x5x2xf32>
+// CHECK-DAG: %[[PROD_SLICE:.*]] = tensor.extract_slice %[[FILL_PROD]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
+// CHECK: %[[MMUL_PROD:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1_SLICE]] : tensor<1x6x5xf32>, tensor<1x5x2xf32>)
+// CHECK-SAME: outs(%[[PROD_SLICE]] : tensor<1x6x2xf32>)
+//
+// Consumer uses a rank-reduced version of producer result so a collapse_shape
+// is generated.
+// CHECK: %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0, 1\], \[2\]\]}} : tensor<1x6x2xf32> into tensor<6x2xf32>
+// CHECK: %[[MMUL_CONS:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG2]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
+// CHECK-SAME: outs(%[[FILL_CONS]] : tensor<4x2xf32>)
+// CHECK: %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+// CHECK: scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
+// CHECK: return %[[FOR]] : tensor<4x6xf32>
```
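
For comparison with the removed variant of this test (and reusing the hypothetical `computeReassociation` sketch above), trailing unit dimensions fold into the last kept dimension's group the same way the leading unit dimension folds into the first:

```cpp
// Appended inside main() of the hypothetical sketch above: the old test's
// tensor<6x2x1x1x1x1xf32> dropped trailing unit dims 2..5, which all fold
// into the second group, matching its old CHECK reassociation
// [[0], [1, 2, 3, 4, 5]].
assert((computeReassociation({6, 2, 1, 1, 1, 1},
                             {false, false, true, true, true, true}) ==
        std::vector<ReassocGroup>{{0}, {1, 2, 3, 4, 5}}));
```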
