@@ -38,13 +38,19 @@ LogicalResult addEntryToDictionaryAttr(PatternRewriter &rewriter,
3838 return success ();
3939}
4040
41- mlir::Attribute addElemToArrayAttr (mlir::PatternRewriter &rewriter,
42- mlir::Attribute attr,
43- mlir::Attribute element) {
44- assert (isa<ArrayAttr>(attr));
45- auto values = cast<ArrayAttr>(attr).getValue ().vec ();
46- values.push_back (element);
47- return rewriter.getArrayAttr (values);
41+ LogicalResult addElemToArrayAttr (PatternRewriter &rewriter,
42+ PDLResultList &results,
43+ ArrayRef<PDLValue> args) {
44+
45+ assert (args.size () == 2 &&
46+ " Expected two arguments, one ArrayAttr and one Attr" );
47+ auto arrayAttr = cast<ArrayAttr>(args[0 ].cast <Attribute>());
48+ auto attrElement = args[1 ].cast <Attribute>();
49+ llvm::SmallVector<Attribute> values (arrayAttr.getValue ());
50+ values.push_back (attrElement);
51+
52+ results.push_back (rewriter.getArrayAttr (values));
53+ return success ();
4854}
4955
5056template <UnaryOpKind T>
@@ -344,11 +350,15 @@ void registerBuiltins(PDLPatternModule &pdlPattern) {
344350 // See Parser::defineBuiltins()
345351 pdlPattern.registerRewriteFunction (
346352 " __builtin_addEntryToDictionaryAttr_rewrite" , addEntryToDictionaryAttr);
347- pdlPattern.registerRewriteFunction (" __builtin_addElemToArrayAttr" ,
348- addElemToArrayAttr);
349353 pdlPattern.registerConstraintFunction (
350354 " __builtin_addEntryToDictionaryAttr_constraint" ,
351355 addEntryToDictionaryAttr);
356+
357+ pdlPattern.registerRewriteFunction (" __builtin_addElemToArrayAttrRewriter" ,
358+ addElemToArrayAttr);
359+ pdlPattern.registerConstraintFunction (
360+ " __builtin_addElemToArrayAttrConstraint" , addElemToArrayAttr);
361+
352362 pdlPattern.registerRewriteFunction (" __builtin_mulRewrite" , mul);
353363 pdlPattern.registerRewriteFunction (" __builtin_divRewrite" , div);
354364 pdlPattern.registerRewriteFunction (" __builtin_modRewrite" , mod);
@@ -357,22 +367,14 @@ void registerBuiltins(PDLPatternModule &pdlPattern) {
357367 pdlPattern.registerRewriteFunction (" __builtin_log2Rewrite" , log2);
358368 pdlPattern.registerRewriteFunction (" __builtin_exp2Rewrite" , exp2);
359369 pdlPattern.registerRewriteFunction (" __builtin_absRewrite" , abs);
360- pdlPattern.registerConstraintFunction (" __builtin_mulConstraint" ,
361- mul);
362- pdlPattern.registerConstraintFunction (" __builtin_divConstraint" ,
363- div);
364- pdlPattern.registerConstraintFunction (" __builtin_modConstraint" ,
365- mod);
366- pdlPattern.registerConstraintFunction (" __builtin_addConstraint" ,
367- add);
368- pdlPattern.registerConstraintFunction (" __builtin_subConstraint" ,
369- sub);
370- pdlPattern.registerConstraintFunction (" __builtin_log2Constraint" ,
371- log2);
372- pdlPattern.registerConstraintFunction (" __builtin_exp2Constraint" ,
373- exp2);
374- pdlPattern.registerConstraintFunction (" __builtin_absConstraint" ,
375- abs);
370+ pdlPattern.registerConstraintFunction (" __builtin_mulConstraint" , mul);
371+ pdlPattern.registerConstraintFunction (" __builtin_divConstraint" , div);
372+ pdlPattern.registerConstraintFunction (" __builtin_modConstraint" , mod);
373+ pdlPattern.registerConstraintFunction (" __builtin_addConstraint" , add);
374+ pdlPattern.registerConstraintFunction (" __builtin_subConstraint" , sub);
375+ pdlPattern.registerConstraintFunction (" __builtin_log2Constraint" , log2);
376+ pdlPattern.registerConstraintFunction (" __builtin_exp2Constraint" , exp2);
377+ pdlPattern.registerConstraintFunction (" __builtin_absConstraint" , abs);
376378 pdlPattern.registerConstraintFunction (" __builtin_equals" , equals);
377379}
378380} // namespace mlir::pdl
0 commit comments