@@ -472,6 +472,99 @@ static void printOrderClause(OpAsmPrinter &p, Operation *op,
472472 p << stringifyClauseOrderKind (order.getValue ());
473473}
474474
475+ template <typename ClauseTypeAttr, typename ClauseType>
476+ static ParseResult
477+ parseGranularityClause (OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
478+ std::optional<OpAsmParser::UnresolvedOperand> &operand,
479+ Type &operandType,
480+ std::optional<ClauseType> (*symbolizeClause)(StringRef),
481+ StringRef clauseName) {
482+ StringRef enumStr;
483+ if (succeeded (parser.parseOptionalKeyword (&enumStr))) {
484+ if (std::optional<ClauseType> enumValue = symbolizeClause (enumStr)) {
485+ prescriptiveness = ClauseTypeAttr::get (parser.getContext (), *enumValue);
486+ if (parser.parseComma ())
487+ return failure ();
488+ } else {
489+ return parser.emitError (parser.getCurrentLocation ())
490+ << " invalid " << clauseName << " modifier : '" << enumStr << " '" ;
491+ ;
492+ }
493+ }
494+
495+ OpAsmParser::UnresolvedOperand var;
496+ if (succeeded (parser.parseOperand (var))) {
497+ operand = var;
498+ } else {
499+ return parser.emitError (parser.getCurrentLocation ())
500+ << " expected " << clauseName << " operand" ;
501+ }
502+
503+ if (operand.has_value ()) {
504+ if (parser.parseColonType (operandType))
505+ return failure ();
506+ }
507+
508+ return success ();
509+ }
510+
511+ template <typename ClauseTypeAttr, typename ClauseType>
512+ static void
513+ printGranularityClause (OpAsmPrinter &p, Operation *op,
514+ ClauseTypeAttr prescriptiveness, Value operand,
515+ mlir::Type operandType,
516+ StringRef (*stringifyClauseType)(ClauseType)) {
517+
518+ if (prescriptiveness)
519+ p << stringifyClauseType (prescriptiveness.getValue ()) << " , " ;
520+
521+ if (operand)
522+ p << operand << " : " << operandType;
523+ }
524+
525+ // ===----------------------------------------------------------------------===//
526+ // Parser and printer for grainsize Clause
527+ // ===----------------------------------------------------------------------===//
528+
529+ // grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
530+ static ParseResult
531+ parseGrainsizeClause (OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
532+ std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
533+ Type &grainsizeType) {
534+ return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
535+ parser, grainsizeMod, grainsize, grainsizeType,
536+ &symbolizeClauseGrainsizeType, " grainsize" );
537+ }
538+
539+ static void printGrainsizeClause (OpAsmPrinter &p, Operation *op,
540+ ClauseGrainsizeTypeAttr grainsizeMod,
541+ Value grainsize, mlir::Type grainsizeType) {
542+ printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
543+ p, op, grainsizeMod, grainsize, grainsizeType,
544+ &stringifyClauseGrainsizeType);
545+ }
546+
547+ // ===----------------------------------------------------------------------===//
548+ // Parser and printer for num_tasks Clause
549+ // ===----------------------------------------------------------------------===//
550+
551+ // numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
552+ static ParseResult
553+ parseNumTasksClause (OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
554+ std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
555+ Type &numTasksType) {
556+ return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
557+ parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
558+ " num_tasks" );
559+ }
560+
561+ static void printNumTasksClause (OpAsmPrinter &p, Operation *op,
562+ ClauseNumTasksTypeAttr numTasksMod,
563+ Value numTasks, mlir::Type numTasksType) {
564+ printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
565+ p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
566+ }
567+
475568// ===----------------------------------------------------------------------===//
476569// Parsers for operations including clauses that define entry block arguments.
477570// ===----------------------------------------------------------------------===//
@@ -2593,15 +2686,17 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
25932686 const TaskloopOperands &clauses) {
25942687 MLIRContext *ctx = builder.getContext ();
25952688 // TODO Store clauses in op: privateVars, privateSyms.
2596- TaskloopOp::build (
2597- builder, state, clauses.allocateVars , clauses.allocatorVars ,
2598- clauses.final , clauses.grainsize , clauses.ifExpr , clauses.inReductionVars ,
2599- makeDenseBoolArrayAttr (ctx, clauses.inReductionByref ),
2600- makeArrayAttr (ctx, clauses.inReductionSyms ), clauses.mergeable ,
2601- clauses.nogroup , clauses.numTasks , clauses.priority , /* private_vars=*/ {},
2602- /* private_syms=*/ nullptr , clauses.reductionMod , clauses.reductionVars ,
2603- makeDenseBoolArrayAttr (ctx, clauses.reductionByref ),
2604- makeArrayAttr (ctx, clauses.reductionSyms ), clauses.untied );
2689+ TaskloopOp::build (builder, state, clauses.allocateVars , clauses.allocatorVars ,
2690+ clauses.final , clauses.grainsizeMod , clauses.grainsize ,
2691+ clauses.ifExpr , clauses.inReductionVars ,
2692+ makeDenseBoolArrayAttr (ctx, clauses.inReductionByref ),
2693+ makeArrayAttr (ctx, clauses.inReductionSyms ),
2694+ clauses.mergeable , clauses.nogroup , clauses.numTasksMod ,
2695+ clauses.numTasks , clauses.priority , /* private_vars=*/ {},
2696+ /* private_syms=*/ nullptr , clauses.reductionMod ,
2697+ clauses.reductionVars ,
2698+ makeDenseBoolArrayAttr (ctx, clauses.reductionByref ),
2699+ makeArrayAttr (ctx, clauses.reductionSyms ), clauses.untied );
26052700}
26062701
26072702SmallVector<Value> TaskloopOp::getAllReductionVars () {
0 commit comments