
Commit 7f76ec9

Provide tile-and-fuse for SoftMax (#540)
1 parent: fe7b7b7

File tree: 3 files changed, +83 -12 lines


mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 2 additions & 1 deletion
@@ -101,7 +101,8 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
       ["getIterationDomain",
        "getLoopIteratorTypes",
        "getResultTilePosition",
-       "getTiledImplementation"]>]> {
+       "getTiledImplementation",
+       "generateResultTileValue"]>]> {
   let summary = "Softmax operator";
   let description = [{
     linalg.softmax computes a numerically stable version of softmax.
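Adding "generateResultTileValue" to the DeclareOpInterfaceMethods list makes ODS emit a corresponding member declaration on SoftmaxOp, which the C++ change below then defines. As a rough sketch, the declaration this expands to should have the following shape (the exact output comes from mlir-tblgen, so treat this as an assumption):

// Assumed shape of the tblgen-declared member on SoftmaxOp; the
// definition is provided in LinalgOps.cpp (next file in this commit).
FailureOr<TilingResult>
generateResultTileValue(OpBuilder &builder, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes);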

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 28 additions & 11 deletions
@@ -2648,39 +2648,56 @@ SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
   return iteratorTypes;
 }
 
-FailureOr<TilingResult>
-SoftmaxOp::getTiledImplementation(OpBuilder &builder,
-                                  ArrayRef<OpFoldResult> offsets,
-                                  ArrayRef<OpFoldResult> sizes) {
-  int64_t rank = getInputOperandRank();
+static FailureOr<TilingResult>
+implementTiledSoftMax(SoftmaxOp &op, OpBuilder &builder,
+                      ArrayRef<OpFoldResult> offsets,
+                      ArrayRef<OpFoldResult> sizes) {
+  int64_t rank = op.getInputOperandRank();
   auto oneAttr = builder.getI64IntegerAttr(1);
   SmallVector<OpFoldResult> strides(rank, oneAttr);
   SmallVector<Value> tiledOperands;
   Operation *inputSlice =
-      getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
+      getSlice(builder, op.getLoc(), op.getInput(), offsets, sizes, strides);
   if (!inputSlice) {
-    return emitOpError("failed to compute input slice");
+    return op.emitOpError("failed to compute input slice");
   }
   tiledOperands.emplace_back(inputSlice->getResult(0));
   Operation *outputSlice =
-      getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
+      getSlice(builder, op.getLoc(), op.getOutput(), offsets, sizes, strides);
   if (!outputSlice) {
-    return emitOpError("failed to compute output slice");
+    return op.emitOpError("failed to compute output slice");
   }
   tiledOperands.emplace_back(outputSlice->getResult(0));
 
   SmallVector<Type, 4> resultTypes;
-  if (hasPureTensorSemantics())
+  if (op.hasPureTensorSemantics())
     resultTypes.push_back(tiledOperands[1].getType());
   Operation *tiledOp =
-      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+      mlir::clone(builder, op.getOperation(), resultTypes, tiledOperands);
 
   return TilingResult{
       {tiledOp},
       SmallVector<Value>(tiledOp->getResults()),
       llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
 }
 
+FailureOr<TilingResult>
+SoftmaxOp::getTiledImplementation(OpBuilder &builder,
+                                  ArrayRef<OpFoldResult> offsets,
+                                  ArrayRef<OpFoldResult> sizes) {
+  return implementTiledSoftMax(*this, builder, offsets, sizes);
+}
+
+FailureOr<TilingResult>
+SoftmaxOp::generateResultTileValue(OpBuilder &builder, unsigned resultNumber,
+                                   ArrayRef<OpFoldResult> offsets,
+                                   ArrayRef<OpFoldResult> sizes) {
+  if (resultNumber != 0)
+    return failure();
+
+  return implementTiledSoftMax(*this, builder, offsets, sizes);
+}
+
 LogicalResult SoftmaxOp::getResultTilePosition(
     OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
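For context on how this hook gets exercised: generateResultTileValue is the TilingInterface entry point that consumer-driven tile-and-fuse drivers call when they need the tile of a producer's result matching a consumer's tensor.extract_slice. A minimal sketch of such a call site, assuming softmaxOp, builder, and the slice's offsets/sizes are already in hand (these names are illustrative, not part of this commit):

// Hypothetical fusion-driver snippet (illustrative names). Ask the
// softmax producer for the tile of its result that covers the
// consumer's slice; per the implementation above, only result
// number 0 succeeds.
auto iface = cast<TilingInterface>(softmaxOp.getOperation());
FailureOr<TilingResult> tile =
    iface.generateResultTileValue(builder, /*resultNumber=*/0, offsets, sizes);
if (failed(tile))
  return failure();
// The driver then rewires the consumer to use the tiled producer value.
Value fusedProducerTile = tile->tiledValues[0];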

mlir/test/Dialect/Linalg/tile-softmax.mlir

Lines changed: 53 additions & 0 deletions
@@ -153,3 +153,56 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// Check that tile-and-fuse works with SoftMax, i.e. that tiling of a SoftMax op
+// from a tile of its consumer is correct.
+// For this, use the FuseAndYield transform op.
+
+// CHECK-LABEL: @softmax_tile_from_consumer(
+// CHECK-SAME:    %[[ARG0:[^:]*]]: tensor<16x64x256xf32>
+
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[C256:.*]] = arith.constant 256 : index
+
+// CHECK: %[[EMPTY:[^ ]*]] = tensor.empty() : tensor<16x64x256xf32>
+
+// CHECK: scf.for %[[I0:[^ ]*]] = %[[C0]] to %[[C16]] step %[[C2]]
+// CHECK:   scf.for %[[I1:[^ ]*]] = %[[C0]] to %[[C256]] step %[[C8]]
+
+// CHECK: %[[SOFTMAX_IN:[^ ]*]] = tensor.extract_slice %[[ARG0]][%[[I0]], 0, %[[I1]]] [2, 64, 8] [1, 1, 1]
+// CHECK: %[[SOFTMAX:[^ ]*]] = linalg.softmax dimension(1) ins(%[[SOFTMAX_IN]] : tensor<2x64x8xf32>)
+// CHECK: %[[GENERIC:[^ ]*]] = linalg.generic {{.*}} ins(%[[SOFTMAX]] : tensor<2x64x8xf32>)
+
+func.func @softmax_tile_from_consumer(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %empty0 = tensor.empty() : tensor<16x64x256xf32>
+  %1 = linalg.softmax
+         dimension(1) ins(%arg0 : tensor<16x64x256xf32>) outs(%empty0 : tensor<16x64x256xf32>) -> tensor<16x64x256xf32>
+  %empty1 = tensor.empty() : tensor<16x64x256xf32>
+  %eltwise = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+     iterator_types = ["parallel", "parallel", "parallel"]
+    }
+    ins(%1 : tensor<16x64x256xf32>)
+    outs(%empty1 : tensor<16x64x256xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):
+    %arg2Plus1 = arith.addf %arg2, %cst : f32
+    linalg.yield %arg2Plus1 : f32
+  } -> tensor<16x64x256xf32>
+
+  return %eltwise : tensor<16x64x256xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop:3 = transform.test.fuse_and_yield %0 [2, 64, 8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
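A note on the chosen tile sizes: [2, 64, 8] tiles the two parallel dimensions (0 and 2), while dimension 1, the axis the softmax normalizes over, is tiled at its full extent of 64. Each fused softmax tile therefore still sees the entire reduction axis, which keeps the per-tile softmax identical to the untiled one; the CHECK lines reflect this, with linalg.softmax dimension(1) applied to tensor<2x64x8xf32> slices inside the loops over dimensions 0 and 2.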
