Skip to content

Commit 9808285

Browse files
allanrenucciGoogle-ML-Automation
authored andcommitted
[Mosaic GPU][NFC] Remove redundant mlir::cast.
PiperOrigin-RevId: 835200204
1 parent 1b9951e commit 9808285

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -522,8 +522,8 @@ llvm::LogicalResult BroadcastInDimOp::verify() {
522522
return emitOpError(llvm::formatv(params...));
523523
};
524524

525-
auto operand_type = mlir::cast<mlir::VectorType>(getOperand().getType());
526-
auto result_type = mlir::cast<mlir::VectorType>(getResult().getType());
525+
mlir::VectorType operand_type = getOperand().getType();
526+
mlir::VectorType result_type = getResult().getType();
527527

528528
if (operand_type.getRank() == 0) {
529529
return error("The input vector must have rank > 0.");
@@ -559,15 +559,12 @@ llvm::LogicalResult BroadcastInDimOp::verify() {
559559
}
560560

561561
llvm::LogicalResult ReturnOp::verify() {
562-
auto custom_primitive_op =
563-
mlir::cast<CustomPrimitiveOp>((*this)->getParentOp());
564-
565562
// The operand number and types must match the custom primitive signature.
566-
const auto& results = custom_primitive_op->getResultTypes();
563+
const auto& results = getParentOp()->getResultTypes();
567564
if (getNumOperands() != results.size())
568565
return emitOpError("has ")
569566
<< getNumOperands() << " operands, but enclosing custom_primitive (@"
570-
<< custom_primitive_op->getName() << ") returns " << results.size();
567+
<< getParentOp()->getName() << ") returns " << results.size();
571568

572569
for (unsigned i = 0, e = results.size(); i != e; ++i)
573570
if (getOperand(i).getType() != results[i])
@@ -576,7 +573,7 @@ llvm::LogicalResult ReturnOp::verify() {
576573
<< ") doesn't match the result type (" << results[i]
577574
<< ")"
578575
<< " in custom_primitive @"
579-
<< custom_primitive_op->getName();
576+
<< getParentOp()->getName();
580577

581578
return llvm::success();
582579
}

0 commit comments

Comments
 (0)