@@ -522,8 +522,8 @@ llvm::LogicalResult BroadcastInDimOp::verify() {
522522 return emitOpError (llvm::formatv (params...));
523523 };
524524
525- auto operand_type = mlir::cast<mlir::VectorType>( getOperand ().getType () );
526- auto result_type = mlir::cast<mlir::VectorType>( getResult ().getType () );
525+ mlir::VectorType operand_type = getOperand ().getType ();
526+ mlir::VectorType result_type = getResult ().getType ();
527527
528528 if (operand_type.getRank () == 0 ) {
529529 return error (" The input vector must have rank > 0." );
@@ -559,15 +559,12 @@ llvm::LogicalResult BroadcastInDimOp::verify() {
559559}
560560
561561llvm::LogicalResult ReturnOp::verify () {
562- auto custom_primitive_op =
563- mlir::cast<CustomPrimitiveOp>((*this )->getParentOp ());
564-
565562 // The operand number and types must match the custom primitive signature.
566- const auto & results = custom_primitive_op ->getResultTypes ();
563+ const auto & results = getParentOp () ->getResultTypes ();
567564 if (getNumOperands () != results.size ())
568565 return emitOpError (" has " )
569566 << getNumOperands () << " operands, but enclosing custom_primitive (@"
570- << custom_primitive_op ->getName () << " ) returns " << results.size ();
567+ << getParentOp () ->getName () << " ) returns " << results.size ();
571568
572569 for (unsigned i = 0 , e = results.size (); i != e; ++i)
573570 if (getOperand (i).getType () != results[i])
@@ -576,7 +573,7 @@ llvm::LogicalResult ReturnOp::verify() {
576573 << " ) doesn't match the result type (" << results[i]
577574 << " )"
578575 << " in custom_primitive @"
579- << custom_primitive_op ->getName ();
576+ << getParentOp () ->getName ();
580577
581578 return llvm::success ();
582579}
0 commit comments