Skip to content

Commit 38d854c

Browse files
authored
[MLIR][NVVM] Update MLIR mapa to reflect new address space (#146031)
The mapa.shared.cluster variant that takes in address-space 3 now should output address-space 7. This patch updates the NVVMOps.td file to reflect this.
1 parent 9a17451 commit 38d854c

File tree

5 files changed

+54
-8
lines changed

5 files changed

+54
-8
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3068,9 +3068,10 @@ def NVVM_GriddepcontrolLaunchDependentsOp
30683068
//===----------------------------------------------------------------------===//
30693069

30703070
def NVVM_MapaOp: NVVM_Op<"mapa",
3071-
[TypesMatchWith<"`res` and `a` should have the same type",
3072-
"a", "res", "$_self">, NVVMRequiresSM<90>]> {
3073-
let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$res);
3071+
[InputAddressIsCombinationOf<["a", "res"],
3072+
[[LLVM_PointerShared, LLVM_PointerSharedCluster], [LLVM_PointerGeneric, LLVM_PointerGeneric]],
3073+
"Valid address-space check(or mapping) for mapa Op">, NVVMRequiresSM<90>]> {
3074+
let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerSharedCluster]>:$res);
30743075
let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b);
30753076

30763077
string llvmBuilder = [{

mlir/include/mlir/IR/OpBase.td

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,51 @@ class RangedTypesMatchWith<string summary, string lhsArg, string rhsArg,
603603
string transform>
604604
: TypesMatchWith<summary, lhsArg, rhsArg, transform, "llvm::equal">;
605605

606+
// Checks that each inputArg has the same type as the corresponding entry
607+
// in allowedTypes
608+
class InputMatchesTypes<list<string> inputArgs, list<Type> allowedTypes> :
609+
PredOpTrait<"operands {" # !interleave(inputArgs, ", ") # "} match expected types",
610+
!foldl(TruePred, !range(!size(inputArgs)), acc, i,
611+
And<[acc,
612+
SubstLeaves<"$_self", "$" # inputArgs[i] # ".getType()",
613+
allowedTypes[i].predicate>
614+
]>)> {
615+
assert !eq(!size(inputArgs), !size(allowedTypes)),
616+
"inputArgs and allowedTypes lists must have the same length";
617+
618+
list<string> inputArgList = inputArgs;
619+
list<Type> allowedTypeList = allowedTypes;
620+
}
621+
622+
// Checks that inputArgs match one of the allowed type combinations.
623+
// Each combination in allowedCombinations must have the same number of types
624+
// as there are inputArgs.
625+
class InputAddressIsCombinationOf<list<string> inputArgs,
626+
list<list<Type>> allowedCombinations,
627+
string description = ""> :
628+
PredOpTrait<!if(!empty(description),
629+
"operands {" # !interleave(inputArgs, ", ") # "} match one of the allowed type combinations",
630+
description),
631+
Or<!foreach(combination, allowedCombinations,
632+
!foldl(TruePred, !range(!size(inputArgs)), acc, i,
633+
And<[acc,
634+
SubstLeaves<"$_self", "$" # inputArgs[i] # ".getType()",
635+
combination[i].predicate>
636+
]>))>> {
637+
assert !gt(!size(allowedCombinations), 0),
638+
"allowedCombinations must not be empty";
639+
640+
// Validate that each combination has the same number of types as inputArgs
641+
defvar inputArgSize = !size(inputArgs);
642+
defvar validSizes = !foldl(1, allowedCombinations, acc, combination,
643+
!and(acc, !eq(inputArgSize, !size(combination))));
644+
assert validSizes,
645+
"each combination in allowedCombinations must have the same length as inputArgs";
646+
647+
list<string> inputArgList = inputArgs;
648+
list<list<Type>> allowedCombinationList = allowedCombinations;
649+
}
650+
606651
// Type Constraint operand `idx`'s Element type is `type`.
607652
class TCopVTEtIs<int idx, Type type> : And<[
608653
CPred<"$_op.getNumOperands() > " # idx>,

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,8 +1275,8 @@ func.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
12751275
// -----
12761276

12771277
func.func @mapa(%a: !llvm.ptr, %b : i32) {
1278-
// expected-error @below {{`res` and `a` should have the same type}}
1279-
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<3>
1278+
// expected-error @below {{'nvvm.mapa' op failed to verify that Valid address-space check(or mapping) for mapa Op}}
1279+
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<7>
12801280
return
12811281
}
12821282

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,7 @@ func.func @mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
541541
// CHECK: nvvm.mapa %{{.*}}
542542
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr
543543
// CHECK: nvvm.mapa %{{.*}}
544-
%1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
544+
%1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<7>
545545
return
546546
}
547547

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -813,8 +813,8 @@ llvm.func @nvvm_griddepcontrol_launch_dependents() {
813813
llvm.func @nvvm_mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) {
814814
// CHECK-LLVM: call ptr @llvm.nvvm.mapa(ptr %{{.*}}, i32 %{{.*}})
815815
%0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr
816-
// CHECK-LLVM: call ptr addrspace(3) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
817-
%1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3>
816+
// CHECK-LLVM: call ptr addrspace(7) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
817+
%1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<7>
818818
llvm.return
819819
}
820820

0 commit comments

Comments
 (0)