@@ -20,13 +20,26 @@ namespace mlir {
2020
2121using namespace mlir ;
2222
23+ static inline bool isScalarLike (Type t) {
24+ return isa<IntegerType, FloatType, IndexType, ComplexType>(t);
25+ }
26+
2327static bool isElementwiseMappableOpOnRankedTensors (Operation *op) {
2428 if (!OpTrait::hasElementwiseMappableTraits (op))
2529 return false ;
2630
27- // TODO: The conversion pattern can be made to work for `any_of` here, but
28- // it's more complex as it requires tracking which operands are scalars.
29- return llvm::all_of (op->getOperandTypes (), llvm::IsaPred<RankedTensorType>);
31+ auto types = op->getOperandTypes ();
32+
33+ // We want at least one ranked tensor.
34+ bool anyRankedTensor = llvm::any_of (types, llvm::IsaPred<RankedTensorType>);
35+
36+ // No invalid operands (i.e., every operand is a ranked tensor or
37+ // scalar-like).
38+ bool noneInvalid = llvm::none_of (types, [](Type t) {
39+ return !(isa<RankedTensorType>(t) || isScalarLike (t));
40+ });
41+
42+ return anyRankedTensor && noneInvalid;
3043}
3144
3245// / Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over
@@ -81,13 +94,41 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
8194 return rewriter.notifyMatchFailure (
8295 op, " requires elementwise op on ranked tensors" );
8396
84- auto rank = cast<RankedTensorType>(op->getResult (0 ).getType ()).getRank ();
85- SmallVector<AffineMap, 3 > indexingMaps (
86- op->getNumResults () + op->getNumOperands (),
87- rewriter.getMultiDimIdentityMap (rank));
88- SmallVector<utils::IteratorType, 6 > iteratorTypes (
97+ auto resTy = cast<RankedTensorType>(op->getResult (0 ).getType ());
98+ auto rank = resTy.getRank ();
99+
100+ // Maps: identity for tensors (rank > 0), scalar map for scalars.
101+ AffineMap scalarMap = AffineMap::get (/* dimCount=*/ rank, /* symbolCount=*/ 0 ,
102+ /* results=*/ {}, rewriter.getContext ());
103+ AffineMap idMap = rewriter.getMultiDimIdentityMap (rank);
104+
105+ // Match phase.
106+ SmallVector<bool > isScalarOperand;
107+ isScalarOperand.reserve (op->getNumOperands ());
108+ for (Type ty : op->getOperandTypes ()) {
109+ if (isScalarLike (ty))
110+ isScalarOperand.push_back (true );
111+ else if (auto rt = dyn_cast<RankedTensorType>(ty))
112+ isScalarOperand.push_back (false );
113+ else
114+ return rewriter.notifyMatchFailure (
115+ op,
116+ " unsupported operand type (expected scalar-like or ranked tensor)" );
117+ }
118+
119+ // Create indexing maps.
120+ SmallVector<AffineMap> indexingMaps;
121+ indexingMaps.reserve (op->getNumOperands () + op->getNumResults ());
122+
123+ for (bool isScalar : isScalarOperand)
124+ indexingMaps.push_back (isScalar ? scalarMap : idMap);
125+
126+ indexingMaps.append (op->getNumResults (), idMap);
127+
128+ SmallVector<utils::IteratorType> iteratorTypes (
89129 rank, utils::IteratorType::parallel);
90- auto outputs = getOrCreateOperandsMatchingResultTypes (rewriter, op);
130+ SmallVector<Value> outputs =
131+ getOrCreateOperandsMatchingResultTypes (rewriter, op);
91132 rewriter.replaceOpWithNewOp <linalg::GenericOp>(
92133 op, /* resultTensorTypes=*/ op->getResultTypes (),
93134 /* inputs=*/ op->getOperands (),
@@ -96,14 +137,14 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
96137 /* iteratorTypes=*/ iteratorTypes,
97138 /* bodyBuilder=*/
98139 [&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
99- auto resultTypes = llvm::to_vector<6 >(
140+ SmallVector<Type> resultEltTys = llvm::to_vector<6 >(
100141 llvm::map_range (op->getResultTypes (), [](Type type) {
101142 return cast<TensorType>(type).getElementType ();
102143 }));
103- auto *scalarOp =
144+ Operation *scalarOp =
104145 builder.create (loc, op->getName ().getIdentifier (),
105146 regionArgs.take_front (op->getNumOperands ()),
106- resultTypes , op->getAttrs ());
147+ resultEltTys , op->getAttrs ());
107148 linalg::YieldOp::create (builder, loc, scalarOp->getResults ());
108149 });
109150 return success ();
0 commit comments