Skip to content

Commit 9801a0f

Browse files
authored
[mlir] Add helper to check elementwise-mappable ops with tensors and scalars (#154872)
This patch introduces a more general helper for identifying elementwise-mappable operations. The existing utility, `isElementwiseMappableOpOnRankedTensors`, only accepted operations when all operands were ranked tensors. In practice, many elementwise operations in MLIR allow mixing tensor operands with scalars. The new helper relaxes the restriction by accepting operands that are either ranked tensors or “scalar-like” types.
1 parent a3e2b64 commit 9801a0f

File tree

2 files changed

+119
-12
lines changed

2 files changed

+119
-12
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,26 @@ namespace mlir {
2020

2121
using namespace mlir;
2222

23+
/// Returns true if `t` is "scalar-like": a built-in integer, float, index, or
/// complex type. Such types may appear as non-tensor operands of an
/// elementwise-mappable op.
static inline bool isScalarLike(Type t) {
  if (isa<IntegerType>(t) || isa<FloatType>(t))
    return true;
  return isa<IndexType>(t) || isa<ComplexType>(t);
}
26+
2327
/// Returns true if `op` can be converted to a `linalg.generic`: it must carry
/// the elementwise-mappable traits, every operand must be either a ranked
/// tensor or scalar-like (see `isScalarLike`), and at least one operand must
/// be a ranked tensor (to anchor the iteration space).
static bool isElementwiseMappableOpOnRankedTensors(Operation *op) {
  if (!OpTrait::hasElementwiseMappableTraits(op))
    return false;

  auto types = op->getOperandTypes();

  // Every operand must be a ranked tensor or scalar-like. Phrased positively
  // with all_of rather than a double-negated none_of, and checked first so we
  // bail out before scanning for a tensor.
  if (!llvm::all_of(types, [](Type t) {
        return isa<RankedTensorType>(t) || isScalarLike(t);
      }))
    return false;

  // We want at least one ranked tensor.
  return llvm::any_of(types, llvm::IsaPred<RankedTensorType>);
}
3144

3245
/// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
@@ -81,13 +94,41 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
8194
return rewriter.notifyMatchFailure(
8295
op, "requires elementwise op on ranked tensors");
8396

84-
auto rank = cast<RankedTensorType>(op->getResult(0).getType()).getRank();
85-
SmallVector<AffineMap, 3> indexingMaps(
86-
op->getNumResults() + op->getNumOperands(),
87-
rewriter.getMultiDimIdentityMap(rank));
88-
SmallVector<utils::IteratorType, 6> iteratorTypes(
97+
auto resTy = cast<RankedTensorType>(op->getResult(0).getType());
98+
auto rank = resTy.getRank();
99+
100+
// Maps: identity for tensors (rank > 0), scalar map for scalars.
101+
AffineMap scalarMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0,
102+
/*results=*/{}, rewriter.getContext());
103+
AffineMap idMap = rewriter.getMultiDimIdentityMap(rank);
104+
105+
// Match phase.
106+
SmallVector<bool> isScalarOperand;
107+
isScalarOperand.reserve(op->getNumOperands());
108+
for (Type ty : op->getOperandTypes()) {
109+
if (isScalarLike(ty))
110+
isScalarOperand.push_back(true);
111+
else if (auto rt = dyn_cast<RankedTensorType>(ty))
112+
isScalarOperand.push_back(false);
113+
else
114+
return rewriter.notifyMatchFailure(
115+
op,
116+
"unsupported operand type (expected scalar-like or ranked tensor)");
117+
}
118+
119+
// Create indexing maps.
120+
SmallVector<AffineMap> indexingMaps;
121+
indexingMaps.reserve(op->getNumOperands() + op->getNumResults());
122+
123+
for (bool isScalar : isScalarOperand)
124+
indexingMaps.push_back(isScalar ? scalarMap : idMap);
125+
126+
indexingMaps.append(op->getNumResults(), idMap);
127+
128+
SmallVector<utils::IteratorType> iteratorTypes(
89129
rank, utils::IteratorType::parallel);
90-
auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op);
130+
SmallVector<Value> outputs =
131+
getOrCreateOperandsMatchingResultTypes(rewriter, op);
91132
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
92133
op, /*resultTensorTypes=*/op->getResultTypes(),
93134
/*inputs=*/op->getOperands(),
@@ -96,14 +137,14 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
96137
/*iteratorTypes=*/iteratorTypes,
97138
/*bodyBuilder=*/
98139
[&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
99-
auto resultTypes = llvm::to_vector<6>(
140+
SmallVector<Type> resultEltTys = llvm::to_vector<6>(
100141
llvm::map_range(op->getResultTypes(), [](Type type) {
101142
return cast<TensorType>(type).getElementType();
102143
}));
103-
auto *scalarOp =
144+
Operation *scalarOp =
104145
builder.create(loc, op->getName().getIdentifier(),
105146
regionArgs.take_front(op->getNumOperands()),
106-
resultTypes, op->getAttrs());
147+
resultEltTys, op->getAttrs());
107148
linalg::YieldOp::create(builder, loc, scalarOp->getResults());
108149
});
109150
return success();

mlir/test/Dialect/Linalg/convert-elementwise-to-linalg.mlir

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,69 @@ func.func @cmpf(%arg0: tensor<4x?x?x8x2x?xf32>, %arg1: tensor<4x?x?x8x2x?xf32>)
108108
return %0 : tensor<4x?x?x8x2x?xi1>
109109
}
110110

111+
// -----
112+
113+
// Check a mix of scalar and tensor input.
// The scalar operand is broadcast via the zero-result map #MAP1
// (affine_map<(d0, d1) -> ()>); the tensor operand and the output share the
// rank-2 identity map #MAP2.
114+
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> ()>
115+
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
116+
// CHECK-LABEL: func @scalar_plus_tensor
117+
func.func @scalar_plus_tensor(%arg0: f32, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
118+
// CHECK: %[[GEN:.*]] = linalg.generic
119+
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
120+
// CHECK-SAME: ins(%[[S:.*]], %[[T:.*]] : f32, tensor<?x?xf32>)
121+
// CHECK-SAME: outs(%[[T]] : tensor<?x?xf32>)
122+
// CHECK: ^bb0(%[[SB:.*]]: f32, %[[TB:.*]]: f32, %[[OB:.*]]: f32):
123+
// CHECK: "test.elementwise_mappable"(%[[SB]], %[[TB]]) : (f32, f32) -> f32
124+
// CHECK: linalg.yield {{.*}} : f32
125+
// CHECK: } -> tensor<?x?xf32>
126+
%0 = "test.elementwise_mappable"(%arg0, %arg1)
127+
: (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
128+
return %0 : tensor<?x?xf32>
129+
}
130+
131+
// -----
132+
// This test exercises the case where an elementwise op has two scalar-like
133+
// operands and one ranked tensor operand. In this example, we chain two
134+
// `test.elementwise_mappable` calls:
135+
// %0 = f(%s1, %t)
136+
// %1 = f(%s2, %0)
// Both resulting generics should assign the scalar map (zero results) to the
// f32 operand and the rank-2 identity map to the tensor operand and output.
137+
// CHECK-DAG: #[[$SC2:[A-Za-z0-9_]+]] = affine_map<(d0, d1) -> ()>
138+
// CHECK-DAG: #[[$ID2:[A-Za-z0-9_]+]] = affine_map<(d0, d1) -> (d0, d1)>
139+
// CHECK-LABEL: func @scalar_tensor_scalar
140+
func.func @scalar_tensor_scalar(%s1: f32, %t: tensor<?x?xf32>, %s2: f32) -> tensor<?x?xf32> {
141+
// First generic.
142+
// CHECK: %[[GEN0:.*]] = linalg.generic
143+
// CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]]
144+
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
145+
// CHECK-SAME: ins(%[[S1:[^,]+]], %[[T0:[^)]*]] : f32, tensor<?x?xf32>)
146+
// CHECK-SAME: outs(%[[T0]] : tensor<?x?xf32>)
147+
// CHECK: ^bb0(%[[S1E:.*]]: f32, %[[T0E:.*]]: f32, %[[O0E:.*]]: f32):
148+
// CHECK: %[[APPLY0:.*]] = "test.elementwise_mappable"(%[[S1E]], %[[T0E]]) : (f32, f32) -> f32
149+
// CHECK: linalg.yield %[[APPLY0]] : f32
150+
// CHECK: } -> tensor<?x?xf32>
151+
152+
// Second generic, consuming the first generic's result as its tensor operand.
153+
// CHECK: %[[GEN1:.*]] = linalg.generic
154+
// CHECK-SAME: indexing_maps = [#[[$SC2]], #[[$ID2]], #[[$ID2]]]
155+
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
156+
// CHECK-SAME: ins(%[[S2:[^,]+]], %[[GEN0]] : f32, tensor<?x?xf32>)
157+
// CHECK-SAME: outs(%[[GEN0]] : tensor<?x?xf32>)
158+
// CHECK: ^bb0(%[[S2E:.*]]: f32, %[[G0E:.*]]: f32, %[[O1E:.*]]: f32):
159+
// CHECK: %[[APPLY1:.*]] = "test.elementwise_mappable"(%[[S2E]], %[[G0E]]) : (f32, f32) -> f32
160+
// CHECK: linalg.yield %[[APPLY1]] : f32
161+
// CHECK: } -> tensor<?x?xf32>
162+
// CHECK: return %[[GEN1]] : tensor<?x?xf32>
163+
%0 = "test.elementwise_mappable"(%s1, %t)
164+
: (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
165+
%1 = "test.elementwise_mappable"(%s2, %0)
166+
: (f32, tensor<?x?xf32>) -> tensor<?x?xf32>
167+
return %1 : tensor<?x?xf32>
168+
}
169+
170+
// -----
// Negative test: an elementwise op with only scalar operands (no ranked
// tensor) must NOT be converted to linalg. Note the split marker above must
// be five dashes (`// -----`); the original four-dash `// ----` is not
// recognized by -split-input-file, so this case would not get its own module.
// CHECK-LABEL: func @negative_scalar_only_eltwise
// CHECK-NOT: linalg
func.func @negative_scalar_only_eltwise(%a: f32, %b: f32) -> f32 {
  %0 = arith.addf %a, %b : f32
  return %0 : f32
}

0 commit comments

Comments
 (0)