@@ -114,6 +114,28 @@ FailureOr<PointwiseConversionInfo> checkOperandsAndResults(
114114 return PointwiseConversionInfo{maxRank, resultTy};
115115}
116116
117+ // / If `input` is a splat constant value, materialize the scalar splat
118+ // / value. Otherwise, return nullopt.
119+ std::optional<Value> materializeSplatScalarConstant (RewriterBase &rewriter,
120+ Location loc, Value input) {
121+ SplatElementsAttr attr;
122+ Type elementType = mlir::getElementTypeOrSelf (input.getType ());
123+ if (!matchPattern (input, m_Constant (&attr))) return {};
124+ if (isa<IntegerType, FloatType, IndexType>(elementType)) {
125+ return rewriter
126+ .create <arith::ConstantOp>(loc, elementType,
127+ attr.getSplatValue <TypedAttr>())
128+ .getResult ();
129+ }
130+ if (isa<ComplexType>(elementType)) {
131+ return rewriter
132+ .create <complex ::ConstantOp>(loc, elementType,
133+ attr.getSplatValue <ArrayAttr>())
134+ .getResult ();
135+ }
136+ return {};
137+ }
138+
117139// / Converts a HLO operation to a linalg.map op that contains the corresponding
118140// / scalar operations.
119141template <typename OpTy>
@@ -160,11 +182,9 @@ struct PointwiseToLinalgMapConverter : OpConversionPattern<OpTy> {
160182 SmallVector<Value> mappedInputs;
161183 SmallVector<Value> scalarInputs;
162184 for (Value input : adaptor.getOperands ()) {
163- DenseElementsAttr attr;
164- if (matchPattern (input, m_Constant (&attr)) && attr.isSplat ()) {
165- scalarInputs.push_back (rewriter.create <arith::ConstantOp>(
166- loc, cast<ShapedType>(input.getType ()).getElementType (),
167- attr.getSplatValue <TypedAttr>()));
185+ if (std::optional<Value> splatVal =
186+ materializeSplatScalarConstant (rewriter, loc, input)) {
187+ scalarInputs.push_back (*splatVal);
168188 } else if (getRank (input) == maxRank) {
169189 mappedInputs.push_back (coerceTensorShape (
170190 rewriter, loc, cast<TypedValue<ShapedType>>(input),
0 commit comments