diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 11143151ddd85..23db9375fbffe 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -2539,6 +2539,28 @@ def NVVM_GriddepcontrolLaunchDependentsOp }]; } +//===----------------------------------------------------------------------===// +// NVVM Mapa Op +//===----------------------------------------------------------------------===// + +def NVVM_MapaOp: NVVM_Op<"mapa", + [TypesMatchWith<"`res` and `a` should have the same type", + "a", "res", "$_self">]> { + let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$res); + let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b); + + string llvmBuilder = [{ + int addrSpace = llvm::cast(op.getA().getType()).getAddressSpace(); + + bool isSharedMemory = addrSpace == NVVM::NVVMMemorySpace::kSharedMemorySpace; + + auto intId = isSharedMemory? llvm::Intrinsic::nvvm_mapa_shared_cluster : llvm::Intrinsic::nvvm_mapa; + $res = createIntrinsicCall(builder, intId, {$a, $b}); + }]; + + let assemblyFormat = "$a`,` $b attr-dict `:` type($a) `->` type($res)"; +} + def NVVM_Exit : NVVM_Op<"exit"> { let summary = "Exit Op"; let description = [{ diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 25806d9d0edd7..5c939318fe3ed 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1189,6 +1189,14 @@ func.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) { // ----- +func.func @mapa(%a: !llvm.ptr, %b : i32) { + // expected-error @below {{`res` and `a` should have the same type}} + %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<3> + return +} + +// ----- + func.func @gep_struct_variable(%arg0: !llvm.ptr, %arg1: i32, %arg2: i32) { // expected-error @below {{op expected index 1 indexing a struct to be constant}} llvm.getelementptr %arg0[%arg1, %arg1] : (!llvm.ptr, i32, i32) -> !llvm.ptr, !llvm.struct<(i32)> diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir index 7d1efdfa44150..dd54acd1e317e 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -522,6 +522,15 @@ func.func @griddepcontrol_launch_dependents() return } +// CHECK-LABEL: @mapa +func.func @mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) { + // CHECK: nvvm.mapa %{{.*}} + %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr + // CHECK: nvvm.mapa %{{.*}} + %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3> + return +} + // ----- // Just check these don't emit errors. diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index 99a71748b0a16..970cac707b058 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -773,3 +773,13 @@ llvm.func @nvvm_griddepcontrol_launch_dependents() { nvvm.griddepcontrol.launch.dependents llvm.return } + +// ----- +// CHECK-LABEL: @nvvm_mapa +llvm.func @nvvm_mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) { + // CHECK-LLVM: call ptr @llvm.nvvm.mapa(ptr %{{.*}}, i32 %{{.*}}) + %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr + // CHECK-LLVM: call ptr addrspace(3) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) + %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3> + llvm.return +}