@@ -355,34 +355,32 @@ LogicalResult PtrAddOp::inferReturnTypes(
355355 Type offsetType = operands[1 ].getType ();
356356
357357 // If neither are shaped types, result is same as base type.
358- if (!isa<ShapedType>(baseType) && !isa<ShapedType>(offsetType)) {
358+ auto offTy = dyn_cast<ShapedType>(offsetType);
359+ if (!offTy) {
360+ // If the offset isn't shaped, the result is always the base type.
359361 inferredReturnTypes.push_back (baseType);
360362 return success ();
361363 }
362-
363- // Handle cases with shaped types.
364- if (auto baseTy = dyn_cast<ShapedType>(baseType)) {
365- // If both shaped, they must have the same shape.
366- if (auto offTy = dyn_cast<ShapedType>(offsetType)) {
367- if (offTy.getShape () != baseTy.getShape ()) {
368- if (location)
369- mlir::emitError (*location) << " shapes of base and offset must match" ;
370- return failure ();
371- }
372- // Make sure they are the same kind of shaped type.
373- if (baseType.getTypeID () != offsetType.getTypeID ()) {
374- if (location)
375- mlir::emitError (*location) << " the shaped containers type must match" ;
376- return failure ();
377- }
378- }
379- inferredReturnTypes.push_back (baseType);
380- return success ();
364+ auto baseTy = dyn_cast<ShapedType>(baseType);
365+ if (!baseTy) {
366+ // Base isn't shaped, but offset is, use the ShapedType from offset with the base pointer as element type.
367+ inferredReturnTypes.push_back (offsetShapedType.clone (baseType));
368+ return success ();
381369 }
382370
383- // Base is scalar, offset is shaped.
384- auto offsetShapedType = cast<ShapedType>(offsetType);
385- inferredReturnTypes.push_back (offsetShapedType.clone (baseType));
371+ // Both are shaped, their shape must match.
372+ if (offTy.getShape () != baseTy.getShape ()) {
373+ if (location)
374+ mlir::emitError (*location) << " shapes of base and offset must match" ;
375+ return failure ();
376+ }
377+ // Make sure they are the same kind of shaped type.
378+ if (baseType.getTypeID () != offsetType.getTypeID ()) {
379+ if (location)
380+ mlir::emitError (*location) << " the shaped containers type must match" ;
381+ return failure ();
382+ }
383+ inferredReturnTypes.push_back (baseType);
386384 return success ();
387385}
388386
0 commit comments