@@ -481,20 +481,37 @@ static void printDimAndSymbolList(Operation::operand_iterator begin,
481481 printer << ' [' << operands.drop_front (numDims) << ' ]' ;
482482}
483483
484- // / Parses dimension and symbol list and returns true if parsing failed.
485- ParseResult mlir::affine::parseDimAndSymbolList (
486- OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
487- SmallVector<OpAsmParser::UnresolvedOperand, 8 > opInfos;
484+ // / Parse dimension and symbol list, but not resolve yet, as we may not know the
485+ // / operands types.
486+ static ParseResult parseDimAndSymbolListImpl (
487+ OpAsmParser &parser,
488+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &opInfos,
489+ unsigned &numDims) {
488490 if (parser.parseOperandList (opInfos, OpAsmParser::Delimiter::Paren))
489491 return failure ();
492+
490493 // Store number of dimensions for validation by caller.
491494 numDims = opInfos.size ();
492495
493496 // Parse the optional symbol operands.
497+ if (parser.parseOperandList (opInfos, OpAsmParser::Delimiter::OptionalSquare))
498+ return failure ();
499+
500+ return success ();
501+ }
502+
503+ // / Parses dimension and symbol list and returns true if parsing failed.
504+ ParseResult mlir::affine::parseDimAndSymbolList (
505+ OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
506+ SmallVector<OpAsmParser::UnresolvedOperand, 8 > opInfos;
507+ if (parseDimAndSymbolListImpl (parser, opInfos, numDims))
508+ return failure ();
509+
494510 auto indexTy = parser.getBuilder ().getIndexType ();
495- return failure (parser.parseOperandList (
496- opInfos, OpAsmParser::Delimiter::OptionalSquare) ||
497- parser.resolveOperands (opInfos, indexTy, operands));
511+ if (parser.resolveOperands (opInfos, indexTy, operands))
512+ return failure ();
513+
514+ return success ();
498515}
499516
500517// / Utility function to verify that a set of operands are valid dimension and
@@ -528,14 +545,25 @@ AffineValueMap AffineApplyOp::getAffineValueMap() {
528545
529546ParseResult AffineApplyOp::parse (OpAsmParser &parser, OperationState &result) {
530547 auto &builder = parser.getBuilder ();
531- auto indexTy = builder.getIndexType ();
532548
533549 AffineMapAttr mapAttr;
534550 unsigned numDims;
551+ SmallVector<OpAsmParser::UnresolvedOperand, 8 > opInfos;
535552 if (parser.parseAttribute (mapAttr, " map" , result.attributes ) ||
536- parseDimAndSymbolList (parser, result. operands , numDims) ||
553+ parseDimAndSymbolListImpl (parser, opInfos , numDims) ||
537554 parser.parseOptionalAttrDict (result.attributes ))
538555 return failure ();
556+
557+ Type type;
558+ if (parser.parseOptionalColon ()) {
559+ type = builder.getIndexType ();
560+ } else if (parser.parseType (type)) {
561+ return failure ();
562+ }
563+
564+ if (parser.resolveOperands (opInfos, type, result.operands ))
565+ return failure ();
566+
539567 auto map = mapAttr.getValue ();
540568
541569 if (map.getNumDims () != numDims ||
@@ -544,7 +572,7 @@ ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
544572 " dimension or symbol index mismatch" );
545573 }
546574
547- result.types .append (map.getNumResults (), indexTy );
575+ result.types .append (map.getNumResults (), type );
548576 return success ();
549577}
550578
@@ -553,9 +581,18 @@ void AffineApplyOp::print(OpAsmPrinter &p) {
553581 printDimAndSymbolList (operand_begin (), operand_end (),
554582 getAffineMap ().getNumDims (), p);
555583 p.printOptionalAttrDict ((*this )->getAttrs (), /* elidedAttrs=*/ {" map" });
584+ Type resType = getType ();
585+ if (!isa<IndexType>(resType))
586+ p << " :" << resType;
556587}
557588
558589LogicalResult AffineApplyOp::verify () {
590+ // Check all operand and result types are the same.
591+ // We cannot use `SameOperandsAndResultType` as it expects at least 1 operand.
592+ if (!llvm::all_equal (
593+ llvm::concat<Type>(getOperandTypes (), (*this )->getResultTypes ())))
594+ return emitOpError (" requires the same type for all operands and results" );
595+
559596 // Check input and output dimensions match.
560597 AffineMap affineMap = getMap ();
561598
0 commit comments