@@ -166,7 +166,14 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
   // out which operand can supply that runtime-value (tensor.dim).
   // Leaving it as a future TODO.
   if (llvm::any_of(op->getOpOperands(), [](OpOperand &oper) {
-        auto opType = cast<RankedTensorType>(oper.get().getType());
+        // Allow scalar values as these can be broadcasted on the input.
+        if (oper.get().getType().isIntOrFloat())
+          return false;
+        // If any of the operands is not a RankedTensorType, then we should
+        // return early. The pattern has been built with RankedTensors in mind.
+        if (!isa<RankedTensorType>(oper.get().getType()))
+          return true;
+        auto opType = cast<ShapedType>(oper.get().getType());
         return ShapedType::isDynamicShape(opType.getShape());
       }))
     return failure();
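
For reference, the relaxed operand check reads naturally as a standalone predicate. Below is a minimal sketch, assuming the usual MLIR headers; the helper name `hasUnsupportedOperand` is illustrative and not part of the patch:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Returns true if any operand rules out the decomposition: operands that are
// neither int/float scalars nor statically shaped ranked tensors bail out.
static bool hasUnsupportedOperand(Operation *op) {
  return llvm::any_of(op->getOpOperands(), [](OpOperand &oper) {
    Type type = oper.get().getType();
    // Scalars are accepted; they can later be materialized with linalg.fill.
    if (type.isIntOrFloat())
      return false;
    // Only ranked tensors are handled by the pattern.
    auto rankedType = dyn_cast<RankedTensorType>(type);
    if (!rankedType)
      return true;
    // Dynamic shapes are still rejected (see the TODO above).
    return ShapedType::isDynamicShape(rankedType.getShape());
  });
}
```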
@@ -181,10 +188,27 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
   // Walk over each input operand and unfold if it is transposed, broadcast
   // or mix of two via operand's affine-map.
   for (int64_t i = 0; i < op.getNumDpsInputs(); ++i) {
-    auto &map = newMap[i];
-    auto inputRTType = cast<RankedTensorType>(newInitValues[i].getType());
-    auto elType = inputRTType.getElementType();
+    auto inputType = newInitValues[i].getType();
+    SmallVector<int64_t> inputShape =
+        llvm::TypeSwitch<Type, SmallVector<int64_t>>(inputType)
+            .Case([](RankedTensorType tensor) { return tensor.getShape(); })
+            .Case([](FloatType scalar) { return SmallVector<int64_t>({1}); })
+            .Case([](IntegerType scalar) { return SmallVector<int64_t>({1}); })
+            .Default([](Type) { return SmallVector<int64_t>(); });
+
+    Type elType = llvm::TypeSwitch<Type, Type>(inputType)
+                      .Case([](RankedTensorType tensor) {
+                        return tensor.getElementType();
+                      })
+                      .Case([](FloatType scalar) { return scalar; })
+                      .Case([](IntegerType scalar) { return scalar; })
+                      .Default([](Type) { return Type(); });
+
+    // If we were not able to resolve the information, skip.
+    if (inputShape.empty() || !elType)
+      continue;
 
+    auto &map = newMap[i];
     /// Nothing to do if map is already an identity.
     if (map.isIdentity())
       continue;
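
The two TypeSwitches could also be folded into a single helper that resolves shape and element type together. A sketch under the same assumptions (only ranked tensors and int/float scalars are supported); the helper name `resolveShapeAndElementType` is hypothetical:

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include <optional>
#include <utility>

using namespace mlir;

// Resolve a static shape and element type for an operand. Scalars are treated
// as a single-element shape so the later size computations keep working.
static std::optional<std::pair<SmallVector<int64_t>, Type>>
resolveShapeAndElementType(Type type) {
  if (auto tensorType = dyn_cast<RankedTensorType>(type)) {
    ArrayRef<int64_t> shape = tensorType.getShape();
    return std::make_pair(SmallVector<int64_t>(shape.begin(), shape.end()),
                          tensorType.getElementType());
  }
  if (type.isIntOrFloat())
    return std::make_pair(SmallVector<int64_t>{1}, type);
  // Anything else is left for other patterns to handle.
  return std::nullopt;
}
```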
@@ -197,7 +221,7 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
       /// rule: dim(result, i) = dim(input, permutation[i])
       SmallVector<int64_t> transposedShape(map.getNumResults());
       for (int64_t i = 0; i < map.getNumResults(); ++i)
-        transposedShape[i] = inputRTType.getShape()[permutation[i]];
+        transposedShape[i] = inputShape[permutation[i]];
 
       Value emptyTensor =
           rewriter.create<tensor::EmptyOp>(loc, transposedShape, elType);
@@ -211,13 +235,23 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite(
     // Does it require broadcast?
     if (!broadcastedDims.empty()) {
       assert(broadcastedDims.size() && "should have non size broadcast");
-      Value emptyTensor = rewriter.create<tensor::EmptyOp>(
-          loc, outputShape, inputRTType.getElementType());
+      Value emptyTensor =
+          rewriter.create<tensor::EmptyOp>(loc, outputShape, elType);
 
-      auto broadcastOp = rewriter.create<linalg::BroadcastOp>(
-          loc, newInitValues[i], emptyTensor, broadcastedDims);
+      Value source = newInitValues[i];
+      Value result;
+      // If a scalar is being broadcasted we can simply use a fill operation.
+      if (source.getType().isIntOrFloat()) {
+        result = rewriter.create<linalg::FillOp>(loc, source, emptyTensor)
+                     ->getResult(0);
+      } else {
+        result = rewriter
+                     .create<linalg::BroadcastOp>(loc, source, emptyTensor,
+                                                  broadcastedDims)
+                     ->getResult(0);
+      }
 
-      newInitValues[i] = broadcastOp->getResult(0);
+      newInitValues[i] = result;
       isChanged = true;
     }
     newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
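
Condensed, the broadcast step now picks between linalg.fill for scalar sources and linalg.broadcast for ranked-tensor sources. A minimal sketch of that choice, assuming the same `rewriter`, `loc`, and operand values as in the patch; the helper name is hypothetical:

```cpp
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Materialize `source` into `emptyTensor`: a scalar becomes a constant-filled
// tensor via linalg.fill, a ranked tensor is expanded with linalg.broadcast.
static Value materializeBroadcast(RewriterBase &rewriter, Location loc,
                                  Value source, Value emptyTensor,
                                  ArrayRef<int64_t> broadcastedDims) {
  if (source.getType().isIntOrFloat())
    return rewriter.create<linalg::FillOp>(loc, source, emptyTensor)
        ->getResult(0);
  return rewriter
      .create<linalg::BroadcastOp>(loc, source, emptyTensor, broadcastedDims)
      ->getResult(0);
}
```

linalg.fill writes the scalar into every element of the destination, which is exactly the broadcast of a scalar value and avoids having to wrap it in a tensor just to feed linalg.broadcast.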