Skip to content

Commit 6fc320d

Browse files
committed
Address review comments
Utils: - drop duplicated comments on the implementation - rename `from` to `src`. Fusion: - restrict the live range of `droppedDims` - clarify the comment for the rank-reduction check. Test: - use more descriptive SSA and FileCheck variable names - emphasize the rank-reducing extract_slice in the input IR as the key aspect of the test.
1 parent 42d8959 commit 6fc320d

File tree

4 files changed

+41
-36
lines changed

4 files changed

+41
-36
lines changed

mlir/include/mlir/Dialect/Tensor/Utils/Utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ computeTransposedType(RankedTensorType rankedTensorType,
4444
ArrayRef<int64_t> transposeVector);
4545

4646
/// Create tensor.collapse_shape to drop unit dimensions in `dropDims` in tensor
47-
/// `from`.
48-
CollapseShapeOp dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
47+
/// `src`.
48+
CollapseShapeOp dropGivenUnitDims(OpBuilder &b, Location loc, Value src,
4949
const llvm::SmallBitVector &dropDims);
5050

5151
/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the

mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,6 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
256256
<< "\nNot fusable, not an extract_slice op: " << inputTensor);
257257
return failure();
258258
}
259-
llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
260259

261260
// If producer is already in the same block as consumer, we are done.
262261
if (consumerOpOperand.get().getParentBlock() ==
@@ -276,11 +275,14 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
276275
// Replace use.
277276
Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
278277
Type consumerType = consumerOpOperand.get().getType();
279-
// Rank-reduction occurred as part of the extract_slice.
278+
// Check if rank-reduction occurred as part of the extract_slice. If yes,
279+
// collapse the dropped dimensions.
280280
if (cast<ShapedType>(consumerType).getRank() !=
281-
cast<ShapedType>(def.getType()).getRank())
281+
cast<ShapedType>(def.getType()).getRank()) {
282+
llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
282283
def =
283284
tensor::dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
285+
}
284286
// Canonicalizations are not guaranteed to have happened before constructing
285287
// `fusedProducer`. In the tensor case this can result in temporary type
286288
// mismatches. Insert a `tensor.cast` op to propagate the transformation

mlir/lib/Dialect/Tensor/Utils/Utils.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,18 +94,16 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
9494
return transposedTensorType;
9595
}
9696

97-
/// Create tensor.collapse_shape to drop unit dimensions in `dropDims` in tensor
98-
/// `from`.
9997
CollapseShapeOp
100-
mlir::tensor::dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
98+
mlir::tensor::dropGivenUnitDims(OpBuilder &b, Location loc, Value src,
10199
const llvm::SmallBitVector &dropDims) {
102-
auto fromType = cast<ShapedType>(from.getType());
103-
int64_t rank = fromType.getRank();
100+
auto srcType = cast<ShapedType>(src.getType());
101+
int64_t rank = srcType.getRank();
104102
assert(rank == static_cast<int64_t>(dropDims.size()) &&
105-
"dropDims dimension does not match from tensor rank");
103+
"dropDims dimension does not match src tensor rank");
106104
assert(llvm::all_of(
107105
dropDims.set_bits(),
108-
[&](unsigned dim) { return fromType.getShape()[dim] == 1; }) &&
106+
[&](unsigned dim) { return srcType.getShape()[dim] == 1; }) &&
109107
"Dropping non unit dimension");
110108
// Computed reassociation map for the corresponding tensor.collapse_shape.
111109
SmallVector<ReassociationIndices, 2> reassocMaps;
@@ -124,7 +122,7 @@ mlir::tensor::dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
124122
reassocMaps.emplace_back(llvm::make_range(seq.begin(), seq.end()));
125123
nextDimToGroup = setBit + 1;
126124
}
127-
return b.create<tensor::CollapseShapeOp>(loc, from, reassocMaps);
125+
return b.create<tensor::CollapseShapeOp>(loc, src, reassocMaps);
128126
}
129127

130128
bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {

mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -328,66 +328,71 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
328328
#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
329329
#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
330330
func.func @rank_reduced_extract_slice(
331-
%arg0: tensor<1x6x5xf32>, %arg1: tensor<1x5x6xf32>, %arg2: tensor<4x6xf32>,
332-
%arg3: tensor<1x6x6xf32>, %arg4: tensor<4x6xf32>, %arg5: tensor<4x2xf32>
331+
%prod_in: tensor<1x6x5xf32>, %prod_weight: tensor<1x5x6xf32>,
332+
%cons_in: tensor<4x6xf32>, %prod_init: tensor<1x6x6xf32>,
333+
%for_iv_init: tensor<4x6xf32>, %cons_init: tensor<4x2xf32>
333334
) -> tensor<4x6xf32> {
334335
%c0 = arith.constant 0 : index
335336
%c2 = arith.constant 2 : index
336337
%c6 = arith.constant 6 : index
337-
%0 = linalg.generic
338+
%mmul_prod = linalg.generic
338339
{indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
339-
ins(%arg0, %arg1 : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%arg3 : tensor<1x6x6xf32>) {
340+
ins(%prod_in, %prod_weight : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%prod_init : tensor<1x6x6xf32>) {
340341
^bb0(%in: f32, %in_1: f32, %out: f32):
341342
%10 = arith.mulf %in, %in_1 : f32
342343
%11 = arith.addf %out, %10 : f32
343344
linalg.yield %11 : f32
344345
} -> tensor<1x6x6xf32>
345-
%1 = scf.for %arg7 = %c0 to %c6 step %c2 iter_args(%arg6 = %arg4) -> (tensor<4x6xf32>) {
346-
%2 = tensor.extract_slice %0[0, 0, %arg7] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
347-
%3 = linalg.generic
346+
%for = scf.for %arg7 = %c0 to %c6 step %c2 iter_args(%arg6 = %for_iv_init) -> (tensor<4x6xf32>) {
347+
348+
// Extract slice with rank-reduced result type. When fused in the loop
349+
// with sliced operands, the producer linalg must have its now sliced
350+
// result be rank-reduced as well to match consumer's use type.
351+
%prod_slice = tensor.extract_slice %mmul_prod[0, 0, %arg7] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
352+
%mmul_cons = linalg.generic
348353
{indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"]}
349-
ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%arg5 : tensor<4x2xf32>) {
354+
ins(%cons_in, %prod_slice : tensor<4x6xf32>, tensor<6x2xf32>) outs(%cons_init : tensor<4x2xf32>) {
350355
^bb0(%in: f32, %in_1: f32, %out: f32):
351356
%20 = arith.mulf %in, %in_1 : f32
352357
%21 = arith.addf %out, %20 : f32
353358
linalg.yield %21 : f32
354359
} -> tensor<4x2xf32>
355-
%4 = tensor.insert_slice %3 into %arg6[0, %arg7] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
360+
%4 = tensor.insert_slice %mmul_cons into %arg6[0, %arg7] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
356361
scf.yield %4 : tensor<4x6xf32>
357362
}
358-
return %1 : tensor<4x6xf32>
363+
return %for : tensor<4x6xf32>
359364
}
360365

361366
// CHECK: func @rank_reduced_extract_slice(
362-
// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: tensor<1x6x5xf32>
363-
// CHECK-SAME: %[[ARG1:[0-9a-z]*]]: tensor<1x5x6xf32>
364-
// CHECK-SAME: %[[ARG2:[0-9a-z]*]]: tensor<4x6xf32>
365-
// CHECK-SAME: %[[ARG3:[0-9a-z]*]]: tensor<1x6x6xf32>
366-
// CHECK-SAME: %[[ARG4:[0-9a-z]*]]: tensor<4x6xf32>
367-
// CHECK-SAME: %[[ARG5:[0-9a-z]*]]: tensor<4x2xf32>
367+
// CHECK-SAME: %[[PROD_IN:[0-9a-z]*]]: tensor<1x6x5xf32>
368+
// CHECK-SAME: %[[PROD_WEIGHT:[0-9a-z]*]]: tensor<1x5x6xf32>
369+
// CHECK-SAME: %[[CONS_IN:[0-9a-z]*]]: tensor<4x6xf32>
370+
// CHECK-SAME: %[[PROD_INIT:[0-9a-z]*]]: tensor<1x6x6xf32>
371+
// CHECK-SAME: %[[FOR_IV_INIT:[0-9a-z]*]]: tensor<4x6xf32>
372+
// CHECK-SAME: %[[CONS_INIT:[0-9a-z]*]]: tensor<4x2xf32>
368373

369374
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
370375
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
371376
// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
372377

373378
// For loop right after tensor alloc & fill, no linalg.generic.
374379
// CHECK-NOT: linalg.generic
375-
// CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[ARG4]])
380+
// CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[FOR_IV_INIT]])
376381

377382
// Producer linalg.generic now inside the loop, with tiled args sliced before
378383
// it.
379-
// CHECK-DAG: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[I]]] [1, 5, 2] [1, 1, 1] : tensor<1x5x6xf32> to tensor<1x5x2xf32>
380-
// CHECK-DAG: %[[PROD_SLICE:.*]] = tensor.extract_slice %[[ARG3]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
384+
// CHECK-DAG: %[[PROD_WEIGHT_SLICE:.*]] = tensor.extract_slice %[[PROD_WEIGHT]][0, 0, %[[I]]] [1, 5, 2] [1, 1, 1] : tensor<1x5x6xf32> to tensor<1x5x2xf32>
385+
// CHECK-DAG: %[[PROD_INIT_SLICE:.*]] = tensor.extract_slice %[[PROD_INIT]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
381386
// CHECK: %[[MMUL_PROD:.*]] = linalg.generic
382-
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1_SLICE]] : tensor<1x6x5xf32>, tensor<1x5x2xf32>)
383-
// CHECK-SAME: outs(%[[PROD_SLICE]] : tensor<1x6x2xf32>)
387+
// CHECK-SAME: ins(%[[PROD_IN]], %[[PROD_WEIGHT_SLICE]] : tensor<1x6x5xf32>, tensor<1x5x2xf32>)
388+
// CHECK-SAME: outs(%[[PROD_INIT_SLICE]] : tensor<1x6x2xf32>)
384389
//
385390
// Consumer uses a rank-reduced version of producer result so a collapse_shape
386391
// is generated.
387392
// CHECK: %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0, 1\], \[2\]\]}} : tensor<1x6x2xf32> into tensor<6x2xf32>
388393
// CHECK: %[[MMUL_CONS:.*]] = linalg.generic
389-
// CHECK-SAME: ins(%[[ARG2]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
390-
// CHECK-SAME: outs(%[[ARG5]] : tensor<4x2xf32>)
394+
// CHECK-SAME: ins(%[[CONS_IN]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
395+
// CHECK-SAME: outs(%[[CONS_INIT]] : tensor<4x2xf32>)
391396
// CHECK: %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
392397
// CHECK: scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
393398
// CHECK: return %[[FOR]] : tensor<4x6xf32>

0 commit comments

Comments
 (0)