Skip to content

Commit 60bdbf9

Browse files
fabianmcgjoker-eph
andauthored
Apply suggestion from @joker-eph
Co-authored-by: Mehdi Amini <[email protected]>
1 parent efd30ca commit 60bdbf9

File tree

1 file changed

+21
-23
lines changed

1 file changed

+21
-23
lines changed

mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -355,34 +355,32 @@ LogicalResult PtrAddOp::inferReturnTypes(
355355
Type offsetType = operands[1].getType();
356356

357357
// If neither are shaped types, result is same as base type.
358-
if (!isa<ShapedType>(baseType) && !isa<ShapedType>(offsetType)) {
358+
auto offTy = dyn_cast<ShapedType>(offsetType);
359+
if (!offTy) {
360+
// If the offset isn't shaped, the result is always the base type.
359361
inferredReturnTypes.push_back(baseType);
360362
return success();
361363
}
362-
363-
// Handle cases with shaped types.
364-
if (auto baseTy = dyn_cast<ShapedType>(baseType)) {
365-
// If both shaped, they must have the same shape.
366-
if (auto offTy = dyn_cast<ShapedType>(offsetType)) {
367-
if (offTy.getShape() != baseTy.getShape()) {
368-
if (location)
369-
mlir::emitError(*location) << "shapes of base and offset must match";
370-
return failure();
371-
}
372-
// Make sure they are the same kind of shaped type.
373-
if (baseType.getTypeID() != offsetType.getTypeID()) {
374-
if (location)
375-
mlir::emitError(*location) << "the shaped containers type must match";
376-
return failure();
377-
}
378-
}
379-
inferredReturnTypes.push_back(baseType);
380-
return success();
364+
auto baseTy = dyn_cast<ShapedType>(baseType);
365+
if (!baseTy) {
366+
// Base isn't shaped, but offset is, use the ShapedType from offset with the base pointer as element type.
367+
inferredReturnTypes.push_back(offsetShapedType.clone(baseType));
368+
return success();
381369
}
382370

383-
// Base is scalar, offset is shaped.
384-
auto offsetShapedType = cast<ShapedType>(offsetType);
385-
inferredReturnTypes.push_back(offsetShapedType.clone(baseType));
371+
// Both are shaped, their shape must match.
372+
if (offTy.getShape() != baseTy.getShape()) {
373+
if (location)
374+
mlir::emitError(*location) << "shapes of base and offset must match";
375+
return failure();
376+
}
377+
// Make sure they are the same kind of shaped type.
378+
if (baseType.getTypeID() != offsetType.getTypeID()) {
379+
if (location)
380+
mlir::emitError(*location) << "the shaped containers type must match";
381+
return failure();
382+
}
383+
inferredReturnTypes.push_back(baseType);
386384
return success();
387385
}
388386

0 commit comments

Comments
 (0)