@@ -18,6 +18,7 @@ include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td"
1818include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
1919include "mlir/IR/OpBase.td"
2020include "mlir/Interfaces/SideEffectInterfaces.td"
21+ include "mlir/Interfaces/CallInterfaces.td"
2122
2223//===----------------------------------------------------------------------===//
2324// LLVM dialect type constraints.
@@ -286,22 +287,26 @@ class LLVM_MemAccessOpBase<string mnemonic, list<Trait> traits = []> :
286287// intrinsic and "enumName" contains the name of the intrinsic as appears in
287288// `llvm::Intrinsic` enum; one usually wants these to be related. Additionally,
288289// the base class also defines the "mlirBuilder" field to support the inverse
289- // translation starting from an LLVM IR intrinsic. The "requiresAccessGroup",
290- // "requiresAliasAnalysis", and "requiresFastmath" flags specify which
291- // interfaces the intrinsic implements. If the corresponding flags are set, the
292- // "aliasAttrs" list contains the arguments required by the access group and
293- // alias analysis interfaces. Derived intrinsics should append the "aliasAttrs"
294- // to their argument list if they set one of the flags. LLVM `immargs` can be
295- // represented as MLIR attributes by providing both the `immArgPositions` and
296- // `immArgAttrNames` lists. These two lists should have equal length, with
297- // `immArgPositions` containing the argument positions on the LLVM IR attribute
298- // that are `immargs`, and `immArgAttrNames` mapping these to corresponding
299- // MLIR attributes.
290+ // translation starting from an LLVM IR intrinsic.
291+ //
292+ // The flags "requiresAccessGroup", "requiresAliasAnalysis",
293+ // "requiresFastmath", and "requiresArgAndResultAttrs" indicate which
294+ // interfaces the intrinsic implements. When a flag is set, the "baseArgs"
295+ // list includes the arguments required by the corresponding interface.
296+ // Derived intrinsics must append "baseArgs" to their argument list if they
297+ // enable any of these flags.
298+ //
299+ // LLVM `immargs` can be represented as MLIR attributes by providing both
300+ // the `immArgPositions` and `immArgAttrNames` lists. These two lists should
301+ // have equal length, with `immArgPositions` containing the argument
302+ // positions on the LLVM IR attribute that are `immargs`, and
303+ // `immArgAttrNames` mapping these to corresponding MLIR attributes.
300304class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
301305                      list<int> overloadedResults, list<int> overloadedOperands,
302306                      list<Trait> traits, int numResults,
303307                      bit requiresAccessGroup = 0, bit requiresAliasAnalysis = 0,
304-                       bit requiresFastmath = 0, bit requiresOpBundles = 0,
308+                       bit requiresFastmath = 0, bit requiresArgAndResultAttrs = 0,
309+                       bit requiresOpBundles = 0,
305310                      list<int> immArgPositions = [],
306311                      list<string> immArgAttrNames = []>
307312    : LLVM_OpBase<dialect, opName, !listconcat(
@@ -311,24 +316,30 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
311316            [DeclareOpInterfaceMethods<AliasAnalysisOpInterface>], []),
312317        !if(!gt(requiresFastmath, 0),
313318            [DeclareOpInterfaceMethods<FastmathFlagsInterface>], []),
319+         !if(!gt(requiresArgAndResultAttrs, 0),
320+             [DeclareOpInterfaceMethods<ArgAndResultAttrsOpInterface>], []),
314321        traits)>,
315322      LLVM_MemOpPatterns,
316323      Results<!if(!gt(numResults, 0), (outs LLVM_Type:$res), (outs))> {
317-   dag aliasAttrs  = !con(
324+   dag baseArgs  = !con(
318325        !if(!gt(requiresAccessGroup, 0),
319326            (ins OptionalAttr<LLVM_AccessGroupArrayAttr>:$access_groups),
320327            (ins )),
321328        !if(!gt(requiresAliasAnalysis, 0),
322329            (ins OptionalAttr<LLVM_AliasScopeArrayAttr>:$alias_scopes,
323330                 OptionalAttr<LLVM_AliasScopeArrayAttr>:$noalias_scopes,
324331                 OptionalAttr<LLVM_TBAATagArrayAttr>:$tbaa),
332+             (ins )),
333+         !if(!gt(requiresArgAndResultAttrs, 0),
334+             (ins OptionalAttr<DictArrayAttr>:$arg_attrs,
335+                  OptionalAttr<DictArrayAttr>:$res_attrs),
336+             (ins )),
337+         !if(!gt(requiresOpBundles, 0),
338+             (ins VariadicOfVariadic<LLVM_Type,
339+                   "op_bundle_sizes">:$op_bundle_operands,
340+                  DenseI32ArrayAttr:$op_bundle_sizes,
341+                  OptionalAttr<ArrayAttr>:$op_bundle_tags),
325342            (ins )));
326-   dag opBundleArgs = !if(!gt(requiresOpBundles, 0),
327-                          (ins VariadicOfVariadic<LLVM_Type,
328-                                 "op_bundle_sizes">:$op_bundle_operands,
329-                               DenseI32ArrayAttr:$op_bundle_sizes,
330-                               OptionalAttr<ArrayAttr>:$op_bundle_tags),
331-                          (ins ));
332343  string llvmEnumName = enumName;
333344  string overloadedResultsCpp =  "{" # !interleave(overloadedResults, ", ") # "}";
334345  string overloadedOperandsCpp =  "{" # !interleave(overloadedOperands, ", ") # "}";
@@ -342,33 +353,52 @@ class LLVM_IntrOpBase<Dialect dialect, string opName, string enumName,
342353        immArgPositionsCpp, immArgAttrNamesCpp], ",") # [{);
343354    (void) inst;
344355    }];
356+   string baseLlvmBuilderArgAndResultAttrs = [{
357+     if (failed(moduleTranslation.convertArgAndResultAttrs(
358+         op,
359+         inst,
360+         }] # immArgPositionsCpp # [{))) {
361+       return failure();
362+     }
363+   }];
345364  string baseLlvmBuilderCoda = !if(!gt(numResults, 0), "$res = inst;", "");
346-   let llvmBuilder =  baseLlvmBuilder # !if(!gt(requiresAccessGroup, 0), setAccessGroupsMetadataCode, "")
347-        # !if(!gt(requiresAliasAnalysis, 0), setAliasAnalysisMetadataCode, "")
348-        # baseLlvmBuilderCoda;
365+   let llvmBuilder = baseLlvmBuilder
366+       # !if(!gt(requiresAccessGroup, 0),
367+         setAccessGroupsMetadataCode, "")
368+       # !if(!gt(requiresAliasAnalysis, 0),
369+         setAliasAnalysisMetadataCode, "")
370+       # !if(!gt(requiresArgAndResultAttrs, 0),
371+         baseLlvmBuilderArgAndResultAttrs, "")
372+       # baseLlvmBuilderCoda;
349373
350374  string baseMlirBuilder = [{
351375    SmallVector<Value> mlirOperands;
352376    SmallVector<NamedAttribute> mlirAttrs;
353377    if (failed(moduleImport.convertIntrinsicArguments(
354-       llvmOperands,
355-       llvmOpBundles,
356-       }] # !if(!gt(requiresOpBundles, 0), "true", "false") # [{,
357-       }] # immArgPositionsCpp # [{,
358-       }] # immArgAttrNamesCpp # [{,
359-       mlirOperands,
360-       mlirAttrs))
361-     ) {
378+         llvmOperands,
379+         llvmOpBundles,
380+         }] # !if(!gt(requiresOpBundles, 0), "true", "false") # [{,
381+         }] # immArgPositionsCpp # [{,
382+         }] # immArgAttrNamesCpp # [{,
383+         mlirOperands,
384+         mlirAttrs))) {
362385      return failure();
363386    }
364387    SmallVector<Type> resultTypes =
365388    }] # !if(!gt(numResults, 0), "{$_resultType};", "{};") # [{
366389    auto op = $_qualCppClassName::create($_builder,
367390      $_location, resultTypes, mlirOperands, mlirAttrs);
368391    }];
392+   string baseMlirBuilderArgAndResultAttrs = [{
393+     moduleImport.convertArgAndResultAttrs(
394+       inst, op, }] # immArgPositionsCpp # [{);
395+     }];
369396  string baseMlirBuilderCoda = !if(!gt(numResults, 0), "$res = op;", "$_op = op;");
370-   let mlirBuilder = baseMlirBuilder # !if(!gt(requiresFastmath, 0),
397+   let mlirBuilder = baseMlirBuilder
398+     # !if(!gt(requiresFastmath, 0),
371399      "moduleImport.setFastmathFlagsAttr(inst, op);", "")
400+     # !if(!gt(requiresArgAndResultAttrs, 0),
401+       baseMlirBuilderArgAndResultAttrs, "")
372402    # baseMlirBuilderCoda;
373403
374404  // Code for handling a `range` attribute that holds the constant range of the
@@ -399,14 +429,14 @@ class LLVM_IntrOp<string mnem, list<int> overloadedResults,
399429                  list<int> overloadedOperands, list<Trait> traits,
400430                  int numResults, bit requiresAccessGroup = 0,
401431                  bit requiresAliasAnalysis = 0, bit requiresFastmath = 0,
402-                   bit requiresOpBundles = 0,
432+                   bit requiresArgAndResultAttrs = 0, bit  requiresOpBundles = 0,
403433                  list<int> immArgPositions = [],
404434                  list<string> immArgAttrNames = []>
405435    : LLVM_IntrOpBase<LLVM_Dialect, "intr." # mnem, !subst(".", "_", mnem),
406436                      overloadedResults, overloadedOperands, traits,
407437                      numResults, requiresAccessGroup, requiresAliasAnalysis,
408-                       requiresFastmath, requiresOpBundles, immArgPositions ,
409-                       immArgAttrNames>;
438+                       requiresFastmath, requiresArgAndResultAttrs ,
439+                       requiresOpBundles, immArgPositions,  immArgAttrNames>;
410440
411441// Base class for LLVM intrinsic operations returning no results. Places the
412442// intrinsic into the LLVM dialect and prefixes its name with "intr.".
@@ -426,13 +456,14 @@ class LLVM_ZeroResultIntrOp<string mnem, list<int> overloadedOperands = [],
426456                            list<Trait> traits = [],
427457                            bit requiresAccessGroup = 0,
428458                            bit requiresAliasAnalysis = 0,
459+                             bit requiresArgAndResultAttrs = 0,
429460                            bit requiresOpBundles = 0,
430461                            list<int> immArgPositions = [],
431462                            list<string> immArgAttrNames = []>
432463    : LLVM_IntrOp<mnem, [], overloadedOperands, traits, /*numResults=*/0,
433464                  requiresAccessGroup, requiresAliasAnalysis,
434-                   /*requiresFastMath=*/0, requiresOpBundles, immArgPositions ,
435-                   immArgAttrNames>;
465+                   /*requiresFastMath=*/0, requiresArgAndResultAttrs ,
466+                   requiresOpBundles, immArgPositions,  immArgAttrNames>;
436467
437468// Base class for LLVM intrinsic operations returning one result. Places the
438469// intrinsic into the LLVM dialect and prefixes its name with "intr.". This is
@@ -448,7 +479,8 @@ class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [],
448479                           list<string> immArgAttrNames = []>
449480    : LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1,
450481                  /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
451-                   requiresFastmath, /*requiresOpBundles=*/0, immArgPositions,
482+                   requiresFastmath, /*requiresArgAndResultAttrs=*/0,
483+                   /*requiresOpBundles=*/0, immArgPositions,
452484                  immArgAttrNames>;
453485
454486// Base class for LLVM intrinsic operations returning two results. Places the
@@ -465,7 +497,8 @@ class LLVM_TwoResultIntrOp<string mnem, list<int> overloadedResults = [],
465497                           list<string> immArgAttrNames = []>
466498    : LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 2,
467499                  /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
468-                   requiresFastmath, /*requiresOpBundles=*/0, immArgPositions,
500+                   requiresFastmath, /*requiresArgAndResultAttrs=*/0,
501+                   /*requiresOpBundles=*/0, immArgPositions,
469502                  immArgAttrNames>;
470503
471504def LLVM_OneResultOpBuilder :
0 commit comments