@@ -1176,8 +1176,11 @@ template <typename OpTy>
1176
1176
static LogicalResult reduceMatchAndRewriteHelper (OpTy op, uint64_t axis,
1177
1177
PatternRewriter &rewriter) {
1178
1178
auto loc = op->getLoc ();
1179
- auto inputTy = cast<ShapedType>(op->getOperand (0 ).getType ());
1180
- auto resultTy = cast<ShapedType>(op->getResult (0 ).getType ());
1179
+ auto inputTy = dyn_cast<RankedTensorType>(op->getOperand (0 ).getType ());
1180
+ auto resultTy = dyn_cast<RankedTensorType>(op->getResult (0 ).getType ());
1181
+ if (!inputTy || !resultTy)
1182
+ return rewriter.notifyMatchFailure (op, " unranked tensors not supported" );
1183
+
1181
1184
auto elementTy = resultTy.getElementType ();
1182
1185
Value input = op->getOperand (0 );
1183
1186
@@ -2380,11 +2383,9 @@ class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
2380
2383
auto input = adaptor.getOperands ()[0 ];
2381
2384
auto indices = adaptor.getOperands ()[1 ];
2382
2385
2383
- auto valuesTy =
2384
- dyn_cast_or_null<RankedTensorType>(op.getValues ().getType ());
2385
- auto resultTy = cast<ShapedType>(op.getType ());
2386
-
2387
- if (!valuesTy)
2386
+ auto valuesTy = dyn_cast<RankedTensorType>(op.getValues ().getType ());
2387
+ auto resultTy = dyn_cast<RankedTensorType>(op.getType ());
2388
+ if (!valuesTy || !resultTy)
2388
2389
return rewriter.notifyMatchFailure (op, " unranked tensors not supported" );
2389
2390
2390
2391
auto dynamicDims = inferDynamicDimsForGather (
0 commit comments