@@ -472,6 +472,108 @@ static void printOrderClause(OpAsmPrinter &p, Operation *op,
472472 p << stringifyClauseOrderKind (order.getValue ());
473473}
474474
475+ // ===----------------------------------------------------------------------===//
476+ // Parser and printer for grainsize Clause
477+ // ===----------------------------------------------------------------------===//
478+
479+ // grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
480+ static ParseResult
481+ parseGrainsizeClause (OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
482+ std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
483+ Type &grainsizeType) {
484+ SMLoc loc = parser.getCurrentLocation ();
485+ StringRef enumStr;
486+
487+ if (succeeded (parser.parseOptionalKeyword (&enumStr))) {
488+ if (std::optional<ClauseGrainsizeType> enumValue =
489+ symbolizeClauseGrainsizeType (enumStr)) {
490+ grainsizeMod =
491+ ClauseGrainsizeTypeAttr::get (parser.getContext (), *enumValue);
492+ if (parser.parseColon ())
493+ return failure ();
494+ } else {
495+ return parser.emitError (loc, " invalid grainsize modifier : '" )
496+ << enumStr << " '" ;
497+ }
498+ }
499+
500+ OpAsmParser::UnresolvedOperand operand;
501+ if (succeeded (parser.parseOperand (operand))) {
502+ grainsize = operand;
503+ } else {
504+ return parser.emitError (parser.getCurrentLocation ())
505+ << " expected grainsize operand" ;
506+ }
507+
508+ if (grainsize.has_value ()) {
509+ if (parser.parseColonType (grainsizeType))
510+ return failure ();
511+ }
512+
513+ return success ();
514+ }
515+
516+ static void printGrainsizeClause (OpAsmPrinter &p, Operation *op,
517+ ClauseGrainsizeTypeAttr grainsizeMod,
518+ Value grainsize, mlir::Type grainsizeType) {
519+ if (grainsizeMod)
520+ p << stringifyClauseGrainsizeType (grainsizeMod.getValue ()) << " : " ;
521+
522+ if (grainsize)
523+ p << grainsize << " : " << grainsizeType;
524+ }
525+
526+ // ===----------------------------------------------------------------------===//
527+ // Parser and printer for num_tasks Clause
528+ // ===----------------------------------------------------------------------===//
529+
530+ // numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
531+ static ParseResult
532+ parseNumTasksClause (OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
533+ std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
534+ Type &numTasksType) {
535+ SMLoc loc = parser.getCurrentLocation ();
536+ StringRef enumStr;
537+
538+ if (succeeded (parser.parseOptionalKeyword (&enumStr))) {
539+ if (std::optional<ClauseNumTasksType> enumValue =
540+ symbolizeClauseNumTasksType (enumStr)) {
541+ numTasksMod =
542+ ClauseNumTasksTypeAttr::get (parser.getContext (), *enumValue);
543+ if (parser.parseColon ())
544+ return failure ();
545+ } else {
546+ return parser.emitError (loc, " invalid numTasks modifier : '" )
547+ << enumStr << " '" ;
548+ }
549+ }
550+
551+ OpAsmParser::UnresolvedOperand operand;
552+ if (succeeded (parser.parseOperand (operand))) {
553+ numTasks = operand;
554+ } else {
555+ return parser.emitError (parser.getCurrentLocation ())
556+ << " expected num_tasks operand" ;
557+ }
558+
559+ if (numTasks.has_value ()) {
560+ if (parser.parseColonType (numTasksType))
561+ return failure ();
562+ }
563+
564+ return success ();
565+ }
566+
567+ static void printNumTasksClause (OpAsmPrinter &p, Operation *op,
568+ ClauseNumTasksTypeAttr numTasksMod,
569+ Value numTasks, mlir::Type numTasksType) {
570+ if (numTasksMod)
571+ p << stringifyClauseNumTasksType (numTasksMod.getValue ()) << " : " ;
572+
573+ if (numTasks)
574+ p << numTasks << " : " << numTasksType;
575+ }
576+
475577// ===----------------------------------------------------------------------===//
476578// Parsers for operations including clauses that define entry block arguments.
477579// ===----------------------------------------------------------------------===//
@@ -2593,15 +2695,17 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
25932695 const TaskloopOperands &clauses) {
25942696 MLIRContext *ctx = builder.getContext ();
25952697 // TODO Store clauses in op: privateVars, privateSyms.
2596- TaskloopOp::build (
2597- builder, state, clauses.allocateVars , clauses.allocatorVars ,
2598- clauses.final , clauses.grainsize , clauses.ifExpr , clauses.inReductionVars ,
2599- makeDenseBoolArrayAttr (ctx, clauses.inReductionByref ),
2600- makeArrayAttr (ctx, clauses.inReductionSyms ), clauses.mergeable ,
2601- clauses.nogroup , clauses.numTasks , clauses.priority , /* private_vars=*/ {},
2602- /* private_syms=*/ nullptr , clauses.reductionMod , clauses.reductionVars ,
2603- makeDenseBoolArrayAttr (ctx, clauses.reductionByref ),
2604- makeArrayAttr (ctx, clauses.reductionSyms ), clauses.untied );
2698+ TaskloopOp::build (builder, state, clauses.allocateVars , clauses.allocatorVars ,
2699+ clauses.final , clauses.grainsizeMod , clauses.grainsize ,
2700+ clauses.ifExpr , clauses.inReductionVars ,
2701+ makeDenseBoolArrayAttr (ctx, clauses.inReductionByref ),
2702+ makeArrayAttr (ctx, clauses.inReductionSyms ),
2703+ clauses.mergeable , clauses.nogroup , clauses.numTasksMod ,
2704+ clauses.numTasks , clauses.priority , /* private_vars=*/ {},
2705+ /* private_syms=*/ nullptr , clauses.reductionMod ,
2706+ clauses.reductionVars ,
2707+ makeDenseBoolArrayAttr (ctx, clauses.reductionByref ),
2708+ makeArrayAttr (ctx, clauses.reductionSyms ), clauses.untied );
26052709}
26062710
26072711SmallVector<Value> TaskloopOp::getAllReductionVars () {
0 commit comments