Skip to content

Commit 8437c77

Browse files
committed
create InputAddressIsCombinationOf
1 parent 50d6536 commit 8437c77

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3015,11 +3015,10 @@ def NVVM_GriddepcontrolLaunchDependentsOp
30153015
// NVVM Mapa Op
30163016
//===----------------------------------------------------------------------===//
30173017

3018-
def NVVM_MapaASCheck : PredOpTrait<"Valid address-space check(or mapping) for mapa Op",
3019-
Or<[InputMatchesTypes<["a", "res"], [LLVM_PointerShared, LLVM_PointerSharedCluster]>.predicate,
3020-
InputMatchesTypes<["a", "res"], [LLVM_PointerGeneric, LLVM_PointerGeneric]>.predicate]>>;
3021-
3022-
def NVVM_MapaOp: NVVM_Op<"mapa", [NVVM_MapaASCheck, NVVMRequiresSM<90>]> {
3018+
def NVVM_MapaOp: NVVM_Op<"mapa",
3019+
[InputAddressIsCombinationOf<["a", "res"],
3020+
[[LLVM_PointerShared, LLVM_PointerSharedCluster], [LLVM_PointerGeneric, LLVM_PointerGeneric]],
3021+
"Valid address-space check(or mapping) for mapa Op">, NVVMRequiresSM<90>]> {
30233022
let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerSharedCluster]>:$res);
30243023
let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b);
30253024

mlir/include/mlir/IR/OpBase.td

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,35 @@ class InputMatchesTypes<list<string> inputArgs, list<Type> allowedTypes> :
619619
list<Type> allowedTypeList = allowedTypes;
620620
}
621621

622+
// Checks that inputArgs match one of the allowed type combinations.
623+
// Each combination in allowedCombinations must have the same number of types
624+
// as there are inputArgs.
625+
class InputAddressIsCombinationOf<list<string> inputArgs,
626+
list<list<Type>> allowedCombinations,
627+
string description = ""> :
628+
PredOpTrait<!if(!empty(description),
629+
"operands {" # !interleave(inputArgs, ", ") # "} match one of the allowed type combinations",
630+
description),
631+
Or<!foreach(combination, allowedCombinations,
632+
!foldl(TruePred, !range(!size(inputArgs)), acc, i,
633+
And<[acc,
634+
SubstLeaves<"$_self", "$" # inputArgs[i] # ".getType()",
635+
combination[i].predicate>
636+
]>))>> {
637+
assert !gt(!size(allowedCombinations), 0),
638+
"allowedCombinations must not be empty";
639+
640+
// Validate that each combination has the same number of types as inputArgs
641+
defvar inputArgSize = !size(inputArgs);
642+
defvar validSizes = !foldl(1, allowedCombinations, acc, combination,
643+
!and(acc, !eq(inputArgSize, !size(combination))));
644+
assert validSizes,
645+
"each combination in allowedCombinations must have the same length as inputArgs";
646+
647+
list<string> inputArgList = inputArgs;
648+
list<list<Type>> allowedCombinationList = allowedCombinations;
649+
}
650+
622651
// Type Constraint operand `idx`'s Element type is `type`.
623652
class TCopVTEtIs<int idx, Type type> : And<[
624653
CPred<"$_op.getNumOperands() > " # idx>,

0 commit comments

Comments
 (0)