Skip to content

Commit 8855028

Browse files
Fix crash on ComplexType in PointwiseToLinalgMapConverter (#2754)
A recent change added code to greedily materialize splat constants, but the code would crash when used with `complex<..>` types.
1 parent 9b65803 commit 8855028

File tree

2 files changed

+48
-5
lines changed

2 files changed

+48
-5
lines changed

stablehlo/conversions/linalg/tests/pointwise.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,29 @@ func.func @float_add(%lhs: tensor<2x2xf32>,
2323

2424
// -----
2525

26+
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
27+
// CHECK-LABEL: func @complex_add_const
28+
// CHECK-PRIMITIVE-LABEL: func @complex_add_const
29+
func.func @complex_add_const(%lhs: tensor<2x2xcomplex<f32>>,
30+
%rhs: tensor<2x2xcomplex<f32>>)
31+
-> tensor<2x2xcomplex<f32>> {
32+
33+
// CHECK: %[[CST:.+]] = complex.constant [1.000000e-01 : f32, 2.000000e-01 : f32] : complex<f32>
34+
// CHECK: linalg.generic
35+
// CHECK: ^bb0(%[[IN:.+]]: complex<f32>, %[[OUT:.+]]: complex<f32>)
36+
// CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = complex.add %[[IN]], %[[CST]]
37+
// CHECK: linalg.yield %[[RESULT]]
38+
39+
// CHECK-PRIMITIVE: linalg.map
40+
// CHECK-PRIMITIVE: complex.add
41+
%cst = stablehlo.constant dense<(0.1, 0.2)> : tensor<2x2xcomplex<f32>>
42+
%0 = "stablehlo.add"(%lhs, %cst) {someattr}
43+
: (tensor<2x2xcomplex<f32>>, tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
44+
func.return %0 : tensor<2x2xcomplex<f32>>
45+
}
46+
47+
// -----
48+
2649
// CHECK-LABEL: func @float_add_dynamic_encoding
2750
// CHECK-PRIMITIVE-LABEL: func @float_add_dynamic_encoding
2851
func.func @float_add_dynamic_encoding(

stablehlo/conversions/linalg/transforms/StablehloToLinalgPointwise.cpp

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,28 @@ FailureOr<PointwiseConversionInfo> checkOperandsAndResults(
114114
return PointwiseConversionInfo{maxRank, resultTy};
115115
}
116116

117+
/// If `input` is a splat constant value, materialize the scalar splat
118+
/// value. Otherwise, return nullopt.
119+
std::optional<Value> materializeSplatScalarConstant(RewriterBase &rewriter,
120+
Location loc, Value input) {
121+
SplatElementsAttr attr;
122+
Type elementType = mlir::getElementTypeOrSelf(input.getType());
123+
if (!matchPattern(input, m_Constant(&attr))) return {};
124+
if (isa<IntegerType, FloatType, IndexType>(elementType)) {
125+
return rewriter
126+
.create<arith::ConstantOp>(loc, elementType,
127+
attr.getSplatValue<TypedAttr>())
128+
.getResult();
129+
}
130+
if (isa<ComplexType>(elementType)) {
131+
return rewriter
132+
.create<complex::ConstantOp>(loc, elementType,
133+
attr.getSplatValue<ArrayAttr>())
134+
.getResult();
135+
}
136+
return {};
137+
}
138+
117139
/// Converts a HLO operation to a linalg.map op that contains the corresponding
118140
/// scalar operations.
119141
template <typename OpTy>
@@ -160,11 +182,9 @@ struct PointwiseToLinalgMapConverter : OpConversionPattern<OpTy> {
160182
SmallVector<Value> mappedInputs;
161183
SmallVector<Value> scalarInputs;
162184
for (Value input : adaptor.getOperands()) {
163-
DenseElementsAttr attr;
164-
if (matchPattern(input, m_Constant(&attr)) && attr.isSplat()) {
165-
scalarInputs.push_back(rewriter.create<arith::ConstantOp>(
166-
loc, cast<ShapedType>(input.getType()).getElementType(),
167-
attr.getSplatValue<TypedAttr>()));
185+
if (std::optional<Value> splatVal =
186+
materializeSplatScalarConstant(rewriter, loc, input)) {
187+
scalarInputs.push_back(*splatVal);
168188
} else if (getRank(input) == maxRank) {
169189
mappedInputs.push_back(coerceTensorShape(
170190
rewriter, loc, cast<TypedValue<ShapedType>>(input),

0 commit comments

Comments
 (0)