@@ -4748,6 +4748,167 @@ void fir::BoxTotalElementsOp::getCanonicalizationPatterns(
47484748 patterns.add <SimplifyBoxTotalElementsOp>(context);
47494749}
47504750
4751+ // ===----------------------------------------------------------------------===//
4752+ // DoConcurrentOp
4753+ // ===----------------------------------------------------------------------===//
4754+
4755+ llvm::LogicalResult fir::DoConcurrentOp::verify () {
4756+ mlir::Block *body = getBody ();
4757+
4758+ if (body->empty ())
4759+ return emitOpError (" body cannot be empty" );
4760+
4761+ if (!body->mightHaveTerminator () ||
4762+ !mlir::isa<fir::DoConcurrentLoopOp>(body->getTerminator ()))
4763+ return emitOpError (" must be terminated by 'fir.do_concurrent.loop'" );
4764+
4765+ return mlir::success ();
4766+ }
4767+
4768+ // ===----------------------------------------------------------------------===//
4769+ // DoConcurrentLoopOp
4770+ // ===----------------------------------------------------------------------===//
4771+
4772+ mlir::ParseResult fir::DoConcurrentLoopOp::parse (mlir::OpAsmParser &parser,
4773+ mlir::OperationState &result) {
4774+ auto &builder = parser.getBuilder ();
4775+ // Parse an opening `(` followed by induction variables followed by `)`
4776+ llvm::SmallVector<mlir::OpAsmParser::Argument, 4 > ivs;
4777+ if (parser.parseArgumentList (ivs, mlir::OpAsmParser::Delimiter::Paren))
4778+ return mlir::failure ();
4779+
4780+ // Parse loop bounds.
4781+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4 > lower;
4782+ if (parser.parseEqual () ||
4783+ parser.parseOperandList (lower, ivs.size (),
4784+ mlir::OpAsmParser::Delimiter::Paren) ||
4785+ parser.resolveOperands (lower, builder.getIndexType (), result.operands ))
4786+ return mlir::failure ();
4787+
4788+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4 > upper;
4789+ if (parser.parseKeyword (" to" ) ||
4790+ parser.parseOperandList (upper, ivs.size (),
4791+ mlir::OpAsmParser::Delimiter::Paren) ||
4792+ parser.resolveOperands (upper, builder.getIndexType (), result.operands ))
4793+ return mlir::failure ();
4794+
4795+ // Parse step values.
4796+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4 > steps;
4797+ if (parser.parseKeyword (" step" ) ||
4798+ parser.parseOperandList (steps, ivs.size (),
4799+ mlir::OpAsmParser::Delimiter::Paren) ||
4800+ parser.resolveOperands (steps, builder.getIndexType (), result.operands ))
4801+ return mlir::failure ();
4802+
4803+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> reduceOperands;
4804+ llvm::SmallVector<mlir::Type> reduceArgTypes;
4805+ if (succeeded (parser.parseOptionalKeyword (" reduce" ))) {
4806+ // Parse reduction attributes and variables.
4807+ llvm::SmallVector<fir::ReduceAttr> attributes;
4808+ if (failed (parser.parseCommaSeparatedList (
4809+ mlir::AsmParser::Delimiter::Paren, [&]() {
4810+ if (parser.parseAttribute (attributes.emplace_back ()) ||
4811+ parser.parseArrow () ||
4812+ parser.parseOperand (reduceOperands.emplace_back ()) ||
4813+ parser.parseColonType (reduceArgTypes.emplace_back ()))
4814+ return mlir::failure ();
4815+ return mlir::success ();
4816+ })))
4817+ return mlir::failure ();
4818+ // Resolve input operands.
4819+ for (auto operand_type : llvm::zip (reduceOperands, reduceArgTypes))
4820+ if (parser.resolveOperand (std::get<0 >(operand_type),
4821+ std::get<1 >(operand_type), result.operands ))
4822+ return mlir::failure ();
4823+ llvm::SmallVector<mlir::Attribute> arrayAttr (attributes.begin (),
4824+ attributes.end ());
4825+ result.addAttribute (getReduceAttrsAttrName (result.name ),
4826+ builder.getArrayAttr (arrayAttr));
4827+ }
4828+
4829+ // Now parse the body.
4830+ mlir::Region *body = result.addRegion ();
4831+ for (auto &iv : ivs)
4832+ iv.type = builder.getIndexType ();
4833+ if (parser.parseRegion (*body, ivs))
4834+ return mlir::failure ();
4835+
4836+ // Set `operandSegmentSizes` attribute.
4837+ result.addAttribute (DoConcurrentLoopOp::getOperandSegmentSizeAttr (),
4838+ builder.getDenseI32ArrayAttr (
4839+ {static_cast <int32_t >(lower.size ()),
4840+ static_cast <int32_t >(upper.size ()),
4841+ static_cast <int32_t >(steps.size ()),
4842+ static_cast <int32_t >(reduceOperands.size ())}));
4843+
4844+ // Parse attributes.
4845+ if (parser.parseOptionalAttrDict (result.attributes ))
4846+ return mlir::failure ();
4847+
4848+ return mlir::success ();
4849+ }
4850+
4851+ void fir::DoConcurrentLoopOp::print (mlir::OpAsmPrinter &p) {
4852+ p << " (" << getBody ()->getArguments () << " ) = (" << getLowerBound ()
4853+ << " ) to (" << getUpperBound () << " ) step (" << getStep () << " )" ;
4854+
4855+ if (!getReduceOperands ().empty ()) {
4856+ p << " reduce(" ;
4857+ auto attrs = getReduceAttrsAttr ();
4858+ auto operands = getReduceOperands ();
4859+ llvm::interleaveComma (llvm::zip (attrs, operands), p, [&](auto it) {
4860+ p << std::get<0 >(it) << " -> " << std::get<1 >(it) << " : "
4861+ << std::get<1 >(it).getType ();
4862+ });
4863+ p << ' )' ;
4864+ }
4865+
4866+ p << ' ' ;
4867+ p.printRegion (getRegion (), /* printEntryBlockArgs=*/ false );
4868+ p.printOptionalAttrDict (
4869+ (*this )->getAttrs (),
4870+ /* elidedAttrs=*/ {DoConcurrentLoopOp::getOperandSegmentSizeAttr (),
4871+ DoConcurrentLoopOp::getReduceAttrsAttrName ()});
4872+ }
4873+
4874+ llvm::SmallVector<mlir::Region *> fir::DoConcurrentLoopOp::getLoopRegions () {
4875+ return {&getRegion ()};
4876+ }
4877+
4878+ llvm::LogicalResult fir::DoConcurrentLoopOp::verify () {
4879+ mlir::Operation::operand_range lbValues = getLowerBound ();
4880+ mlir::Operation::operand_range ubValues = getUpperBound ();
4881+ mlir::Operation::operand_range stepValues = getStep ();
4882+
4883+ if (lbValues.empty ())
4884+ return emitOpError (
4885+ " needs at least one tuple element for lowerBound, upperBound and step" );
4886+
4887+ if (lbValues.size () != ubValues.size () ||
4888+ ubValues.size () != stepValues.size ())
4889+ return emitOpError (" different number of tuple elements for lowerBound, "
4890+ " upperBound or step" );
4891+
4892+ // Check that the body defines the same number of block arguments as the
4893+ // number of tuple elements in step.
4894+ mlir::Block *body = getBody ();
4895+ if (body->getNumArguments () != stepValues.size ())
4896+ return emitOpError () << " expects the same number of induction variables: "
4897+ << body->getNumArguments ()
4898+ << " as bound and step values: " << stepValues.size ();
4899+ for (auto arg : body->getArguments ())
4900+ if (!arg.getType ().isIndex ())
4901+ return emitOpError (
4902+ " expects arguments for the induction variable to be of index type" );
4903+
4904+ auto reduceAttrs = getReduceAttrsAttr ();
4905+ if (getNumReduceOperands () != (reduceAttrs ? reduceAttrs.size () : 0 ))
4906+ return emitOpError (
4907+ " mismatch in number of reduction variables and reduction attributes" );
4908+
4909+ return mlir::success ();
4910+ }
4911+
47514912// ===----------------------------------------------------------------------===//
47524913// FIROpsDialect
47534914// ===----------------------------------------------------------------------===//
0 commit comments