@@ -491,20 +491,37 @@ static void printDimAndSymbolList(Operation::operand_iterator begin,
491491 printer << ' [' << operands.drop_front (numDims) << ' ]' ;
492492}
493493
494- // / Parses dimension and symbol list and returns true if parsing failed.
495- ParseResult mlir::affine::parseDimAndSymbolList (
496- OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
497- SmallVector<OpAsmParser::UnresolvedOperand, 8 > opInfos;
494+ // / Parse dimension and symbol list, but not resolve yet, as we may not know the
495+ // / operands types.
496+ static ParseResult parseDimAndSymbolListImpl (
497+ OpAsmParser &parser,
498+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &opInfos,
499+ unsigned &numDims) {
498500 if (parser.parseOperandList (opInfos, OpAsmParser::Delimiter::Paren))
499501 return failure ();
502+
500503 // Store number of dimensions for validation by caller.
501504 numDims = opInfos.size ();
502505
503506 // Parse the optional symbol operands.
507+ if (parser.parseOperandList (opInfos, OpAsmParser::Delimiter::OptionalSquare))
508+ return failure ();
509+
510+ return success ();
511+ }
512+
513+ // / Parses dimension and symbol list and returns true if parsing failed.
514+ ParseResult mlir::affine::parseDimAndSymbolList (
515+ OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
516+ SmallVector<OpAsmParser::UnresolvedOperand, 8 > opInfos;
517+ if (parseDimAndSymbolListImpl (parser, opInfos, numDims))
518+ return failure ();
519+
504520 auto indexTy = parser.getBuilder ().getIndexType ();
505- return failure (parser.parseOperandList (
506- opInfos, OpAsmParser::Delimiter::OptionalSquare) ||
507- parser.resolveOperands (opInfos, indexTy, operands));
521+ if (parser.resolveOperands (opInfos, indexTy, operands))
522+ return failure ();
523+
524+ return success ();
508525}
509526
510527// / Utility function to verify that a set of operands are valid dimension and
@@ -538,14 +555,25 @@ AffineValueMap AffineApplyOp::getAffineValueMap() {
538555
539556ParseResult AffineApplyOp::parse (OpAsmParser &parser, OperationState &result) {
540557 auto &builder = parser.getBuilder ();
541- auto indexTy = builder.getIndexType ();
542558
543559 AffineMapAttr mapAttr;
544560 unsigned numDims;
561+ SmallVector<OpAsmParser::UnresolvedOperand, 8 > opInfos;
545562 if (parser.parseAttribute (mapAttr, " map" , result.attributes ) ||
546- parseDimAndSymbolList (parser, result. operands , numDims) ||
563+ parseDimAndSymbolListImpl (parser, opInfos , numDims) ||
547564 parser.parseOptionalAttrDict (result.attributes ))
548565 return failure ();
566+
567+ Type type;
568+ if (parser.parseOptionalColon ()) {
569+ type = builder.getIndexType ();
570+ } else if (parser.parseType (type)) {
571+ return failure ();
572+ }
573+
574+ if (parser.resolveOperands (opInfos, type, result.operands ))
575+ return failure ();
576+
549577 auto map = mapAttr.getValue ();
550578
551579 if (map.getNumDims () != numDims ||
@@ -554,7 +582,7 @@ ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
554582 " dimension or symbol index mismatch" );
555583 }
556584
557- result.types .append (map.getNumResults (), indexTy );
585+ result.types .append (map.getNumResults (), type );
558586 return success ();
559587}
560588
@@ -563,9 +591,18 @@ void AffineApplyOp::print(OpAsmPrinter &p) {
563591 printDimAndSymbolList (operand_begin (), operand_end (),
564592 getAffineMap ().getNumDims (), p);
565593 p.printOptionalAttrDict ((*this )->getAttrs (), /* elidedAttrs=*/ {" map" });
594+ Type resType = getType ();
595+ if (!isa<IndexType>(resType))
596+ p << " :" << resType;
566597}
567598
568599LogicalResult AffineApplyOp::verify () {
600+ // Check all operand and result types are the same.
601+ // We cannot use `SameOperandsAndResultType` as it expects at least 1 operand.
602+ if (!llvm::all_equal (
603+ llvm::concat<Type>(getOperandTypes (), (*this )->getResultTypes ())))
604+ return emitOpError (" requires the same type for all operands and results" );
605+
569606 // Check input and output dimensions match.
570607 AffineMap affineMap = getMap ();
571608
0 commit comments