Skip to content

Commit 49deedf

Browse files
committed
Clean up code
dropGivenUnitDims(): - move assert out of loop - rework algorithm to make grouping more explicit and avoid complex nested ifs - fix "occured" typo (→ "occurred") Test: remove all tensor.empty and linalg.fill
1 parent b19b649 commit 49deedf

File tree

2 files changed

+36
-43
lines changed

2 files changed

+36
-43
lines changed

mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -242,32 +242,30 @@ static tensor::CollapseShapeOp
242242
dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
243243
const llvm::SmallBitVector &dropDims) {
244244
auto fromType = cast<ShapedType>(from.getType());
245-
assert(fromType.getRank() == static_cast<int64_t>(dropDims.size()) &&
245+
int64_t rank = fromType.getRank();
246+
assert(rank == static_cast<int64_t>(dropDims.size()) &&
246247
"dropDims dimension does not match from tensor rank");
248+
assert(llvm::all_of(
249+
dropDims.set_bits(),
250+
[&](unsigned dim) { return fromType.getShape()[dim] == 1; }) &&
251+
"Dropping non unit dimension");
247252
// Computed reassociation map for the corresponding tensor.collapse_shape.
248253
SmallVector<ReassociationIndices, 2> reassocMaps;
249254
// Current reassociation group to add dropped dimension to.
250-
ReassociationIndices reassocGroup;
251-
252-
bool foundKeptDim = false;
253-
// Dropped dimensions might be at the beginning or end of the shape so
254-
// combine all contiguous dimensions before and after a given non dropped
255-
// dimension in reassocGroup until another non dropped dimension is found.
256-
// When that happens, add the reassociation indices to the map.
257-
for (int dim = 0; dim < fromType.getRank(); dim++) {
258-
if (dropDims.test(dim))
259-
assert(fromType.getShape()[dim] == 1 && "Dropping non unit dimension");
260-
else {
261-
if (foundKeptDim) {
262-
reassocMaps.push_back(reassocGroup);
263-
reassocGroup.clear();
264-
}
265-
foundKeptDim = true;
266-
}
267-
reassocGroup.push_back(dim);
255+
256+
int64_t nextDimToGroup = 0;
257+
llvm::SmallBitVector keptDims(dropDims);
258+
keptDims.flip();
259+
int64_t lastSetBit = keptDims.find_last();
260+
for(int64_t setBit : keptDims.set_bits()) {
261+
// Group consecutive dropped dimensions with the next non-dropped dimension.
262+
// If this is the last set dimension, also group all subsequent dropped
263+
// dimensions, if any.
264+
int64_t upTo = setBit == lastSetBit ? rank - 1 : setBit;
265+
auto seq = llvm::seq_inclusive(nextDimToGroup, upTo);
266+
reassocMaps.emplace_back(llvm::make_range(seq.begin(), seq.end()));
267+
nextDimToGroup = setBit + 1;
268268
}
269-
if (!reassocGroup.empty())
270-
reassocMaps.push_back(reassocGroup);
271269
return b.create<tensor::CollapseShapeOp>(loc, from, reassocMaps);
272270
}
273271

@@ -311,7 +309,7 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
311309
// Replace use.
312310
Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
313311
Type consumerType = consumerOpOperand.get().getType();
314-
// Rank-reduction occured as part of the extract_slice.
312+
// Rank-reduction occurred as part of the extract_slice.
315313
if (cast<ShapedType>(consumerType).getRank() !=
316314
cast<ShapedType>(def.getType()).getRank())
317315
def = dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);

mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -327,35 +327,32 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
327327
#map3 = affine_map<(d0, d1, d2) -> (d0, d2)>
328328
#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
329329
#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
330-
func.func @rank_reduced_extract_slice(%arg0: tensor<1x6x5xf32>, %arg1: tensor<1x5x6xf32>, %arg2: tensor<4x6xf32>) -> tensor<4x6xf32> {
330+
func.func @rank_reduced_extract_slice(
331+
%arg0: tensor<1x6x5xf32>, %arg1: tensor<1x5x6xf32>, %arg2: tensor<4x6xf32>,
332+
%arg3: tensor<1x6x6xf32>, %arg4: tensor<4x6xf32>, %arg5: tensor<4x2xf32>
333+
) -> tensor<4x6xf32> {
331334
%c0 = arith.constant 0 : index
332335
%c2 = arith.constant 2 : index
333336
%c6 = arith.constant 6 : index
334-
%cst = arith.constant 0.0 : f32
335-
%init1 = tensor.empty() : tensor<1x6x6xf32>
336-
%fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<1x6x6xf32>) -> tensor<1x6x6xf32>
337337
%0 = linalg.generic
338338
{indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
339-
ins(%arg0, %arg1 : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%fill1 : tensor<1x6x6xf32>) {
339+
ins(%arg0, %arg1 : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%arg3 : tensor<1x6x6xf32>) {
340340
^bb0(%in: f32, %in_1: f32, %out: f32):
341341
%10 = arith.mulf %in, %in_1 : f32
342342
%11 = arith.addf %out, %10 : f32
343343
linalg.yield %11 : f32
344344
} -> tensor<1x6x6xf32>
345-
%init2 = tensor.empty() : tensor<4x6xf32>
346-
%1 = scf.for %arg4 = %c0 to %c6 step %c2 iter_args(%arg3 = %init2) -> (tensor<4x6xf32>) {
347-
%2 = tensor.extract_slice %0[0, 0, %arg4] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
348-
%init3 = tensor.empty() : tensor<4x2xf32>
349-
%fill3 = linalg.fill ins(%cst : f32) outs(%init3 : tensor<4x2xf32>) -> tensor<4x2xf32>
345+
%1 = scf.for %arg7 = %c0 to %c6 step %c2 iter_args(%arg6 = %arg4) -> (tensor<4x6xf32>) {
346+
%2 = tensor.extract_slice %0[0, 0, %arg7] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
350347
%3 = linalg.generic
351348
{indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"]}
352-
ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%fill3 : tensor<4x2xf32>) {
349+
ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%arg5 : tensor<4x2xf32>) {
353350
^bb0(%in: f32, %in_1: f32, %out: f32):
354351
%20 = arith.mulf %in, %in_1 : f32
355352
%21 = arith.addf %out, %20 : f32
356353
linalg.yield %21 : f32
357354
} -> tensor<4x2xf32>
358-
%4 = tensor.insert_slice %3 into %arg3[0, %arg4] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
355+
%4 = tensor.insert_slice %3 into %arg6[0, %arg7] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
359356
scf.yield %4 : tensor<4x6xf32>
360357
}
361358
return %1 : tensor<4x6xf32>
@@ -365,24 +362,22 @@ func.func @rank_reduced_extract_slice(%arg0: tensor<1x6x5xf32>, %arg1: tensor<1x
365362
// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: tensor<1x6x5xf32>
366363
// CHECK-SAME: %[[ARG1:[0-9a-z]*]]: tensor<1x5x6xf32>
367364
// CHECK-SAME: %[[ARG2:[0-9a-z]*]]: tensor<4x6xf32>
365+
// CHECK-SAME: %[[ARG3:[0-9a-z]*]]: tensor<1x6x6xf32>
366+
// CHECK-SAME: %[[ARG4:[0-9a-z]*]]: tensor<4x6xf32>
367+
// CHECK-SAME: %[[ARG5:[0-9a-z]*]]: tensor<4x2xf32>
368368

369369
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
370370
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
371371
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
372-
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
373-
// CHECK: %[[EMPTY_PROD:.*]] = tensor.empty() : tensor<1x6x6xf32>
374-
// CHECK-NEXT: %[[FILL_PROD:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY_PROD]] : tensor<1x6x6xf32>) -> tensor<1x6x6xf32>
375-
// CHECK-NEXT: %[[EMPTY_FOR:.*]] = tensor.empty() : tensor<4x6xf32>
376-
// CHECK-NEXT: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<4x2xf32>
377-
// CHECK-NEXT: %[[FILL_CONS:.*]] = linalg.fill ins(%[[CST]] : f32)
378372

379373
// For loop right after tensor alloc & fill, no linalg.generic.
380-
// CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[EMPTY_FOR]])
374+
// CHECK-NOT: linalg.generic
375+
// CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[ARG4]])
381376

382377
// Producer linalg.generic now inside the loop, with tiled args sliced before
383378
// it.
384379
// CHECK-DAG: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[I]]] [1, 5, 2] [1, 1, 1] : tensor<1x5x6xf32> to tensor<1x5x2xf32>
385-
// CHECK-DAG: %[[PROD_SLICE:.*]] = tensor.extract_slice %[[FILL_PROD]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
380+
// CHECK-DAG: %[[PROD_SLICE:.*]] = tensor.extract_slice %[[ARG3]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
386381
// CHECK: %[[MMUL_PROD:.*]] = linalg.generic
387382
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1_SLICE]] : tensor<1x6x5xf32>, tensor<1x5x2xf32>)
388383
// CHECK-SAME: outs(%[[PROD_SLICE]] : tensor<1x6x2xf32>)
@@ -392,7 +387,7 @@ func.func @rank_reduced_extract_slice(%arg0: tensor<1x6x5xf32>, %arg1: tensor<1x
392387
// CHECK: %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0, 1\], \[2\]\]}} : tensor<1x6x2xf32> into tensor<6x2xf32>
393388
// CHECK: %[[MMUL_CONS:.*]] = linalg.generic
394389
// CHECK-SAME: ins(%[[ARG2]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
395-
// CHECK-SAME: outs(%[[FILL_CONS]] : tensor<4x2xf32>)
390+
// CHECK-SAME: outs(%[[ARG5]] : tensor<4x2xf32>)
396391
// CHECK: %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
397392
// CHECK: scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
398393
// CHECK: return %[[FOR]] : tensor<4x6xf32>

0 commit comments

Comments
 (0)