diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 8537c7030aa8f..9d93b4efe7a5b 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -2368,6 +2368,23 @@ def TMAStoreModeAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; } +// Num CTAs in a group participating in the TMA/MMA operations. +// This corresponds to the "cta_group::1", "cta_group::2" +// modifiers in the PTX instructions. +def CTAGroup_1 : I32EnumAttrCase<"CTA_1", 0, "cta_1">; +def CTAGroup_2 : I32EnumAttrCase<"CTA_2", 1, "cta_2">; + +def CTAGroupKind : I32EnumAttr<"CTAGroupKind", + "NVVM CTA group kind", + [CTAGroup_1, CTAGroup_2]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::NVVM"; +} +def CTAGroupKindAttr : + EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">, Arguments<(ins )> { let assemblyFormat = "attr-dict"; @@ -3333,23 +3350,6 @@ def NVVM_Breakpoint : NVVM_Op<"breakpoint"> { //===----------------------------------------------------------------------===// // NVVM TCGEN05 Ops //===----------------------------------------------------------------------===// -// Num CTAs in a group participating in the TCGEN05 operation. -// This corresponds to the "cta_group::1", "cta_group::2" -// modifiers in the PTX instructions. -def Tcgen05GroupCTA_1 : I32EnumAttrCase<"CTA_1", 0, "cta_1">; -def Tcgen05GroupCTA_2 : I32EnumAttrCase<"CTA_2", 1, "cta_2">; - -def Tcgen05GroupKind : I32EnumAttr<"Tcgen05GroupKind", - "NVVM Tcgen05 group kind", - [Tcgen05GroupCTA_1, Tcgen05GroupCTA_2]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::NVVM"; -} -def Tcgen05GroupKindAttr : - EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - def Tcgen05FenceBefore : I32EnumAttrCase<"BEFORE_THREAD_SYNC", 0, "before">; def Tcgen05FenceAfter : I32EnumAttrCase<"AFTER_THREAD_SYNC", 1, "after">; def Tcgen05FenceKind : I32EnumAttr<"Tcgen05FenceKind", "NVVM Tcgen05 fence kind", @@ -3387,7 +3387,7 @@ def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSMa<[100, 101]>] let arguments = (ins AnyTypeOf<[LLVM_AnyPointer, LLVM_PointerShared]>:$addr, I32:$nCols, - DefaultValuedAttr:$group); + DefaultValuedAttr:$group); let assemblyFormat = "$addr `,` $nCols attr-dict `:` type(operands)"; @@ -3415,7 +3415,7 @@ def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSMa<[100, 10 }]; let arguments = (ins LLVM_PointerTensor:$taddr, I32:$nCols, - DefaultValuedAttr:$group); + DefaultValuedAttr:$group); let assemblyFormat = "$taddr `,` $nCols attr-dict `:` type(operands)"; @@ -3443,12 +3443,12 @@ def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_perm }]; let arguments = (ins - DefaultValuedAttr:$group); + DefaultValuedAttr:$group); let assemblyFormat = "attr-dict"; string llvmBuilder = [{ - auto id = ($group == NVVM::Tcgen05GroupKind::CTA_1) ? + auto id = ($group == NVVM::CTAGroupKind::CTA_1) ? llvm::Intrinsic::nvvm_tcgen05_relinq_alloc_permit_cg1 : llvm::Intrinsic::nvvm_tcgen05_relinq_alloc_permit_cg2; createIntrinsicCall(builder, id); @@ -3516,7 +3516,7 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSMa<[100, 101] let arguments = (ins AnyTypeOf<[LLVM_AnyPointer, LLVM_PointerShared]>:$addr, Optional:$multicastMask, - DefaultValuedAttr:$group); + DefaultValuedAttr:$group); let assemblyFormat = [{ $addr (`,` `multicast_mask` `=` $multicastMask^)? @@ -3549,12 +3549,12 @@ def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift", [NVVMRequiresSMa<[100, 101, 1 }]; let arguments = (ins LLVM_PointerTensor:$taddr, - DefaultValuedAttr:$group); + DefaultValuedAttr:$group); let assemblyFormat = "$taddr attr-dict `:` type(operands)"; string llvmBuilder = [{ - auto id = ($group == NVVM::Tcgen05GroupKind::CTA_1) ? + auto id = ($group == NVVM::CTAGroupKind::CTA_1) ? llvm::Intrinsic::nvvm_tcgen05_shift_down_cg1 : llvm::Intrinsic::nvvm_tcgen05_shift_down_cg2; createIntrinsicCall(builder, id, {$taddr}); @@ -3626,7 +3626,7 @@ def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSMa<[100, 101]>]> { let arguments = (ins Tcgen05CpShapeAttr:$shape, - DefaultValuedAttr:$group, + DefaultValuedAttr:$group, DefaultValuedAttr:$multicast, OptionalAttr:$srcFormat, LLVM_PointerTensor:$taddr, diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 77ec1ebde3109..376e3c3e1fcbe 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -1797,7 +1797,7 @@ Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op, unsigned as = llvm::cast(curOp.getAddr().getType()) .getAddressSpace(); bool isShared = as == NVVMMemorySpace::kSharedMemorySpace; - bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2; + bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2; llvm::Intrinsic::ID id; if (isShared) { @@ -1819,7 +1819,7 @@ llvm::Intrinsic::ID Tcgen05DeallocOp::getIntrinsicIDAndArgs( Operation &op, LLVM::ModuleTranslation &mt, llvm::SmallVector &args) { auto curOp = cast(op); - auto id = (curOp.getGroup() == Tcgen05GroupKind::CTA_1) + auto id = (curOp.getGroup() == CTAGroupKind::CTA_1) ? llvm::Intrinsic::nvvm_tcgen05_dealloc_cg1 : llvm::Intrinsic::nvvm_tcgen05_dealloc_cg2; @@ -1847,7 +1847,7 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op, .getAddressSpace(); bool isShared = as == NVVMMemorySpace::kSharedMemorySpace; bool hasMulticast = static_cast(curOp.getMulticastMask()); - bool is2CTAMode = curOp.getGroup() == Tcgen05GroupKind::CTA_2; + bool is2CTAMode = curOp.getGroup() == CTAGroupKind::CTA_2; llvm::Intrinsic::ID id = is2CTAMode ? GET_TCGEN05_COMMIT_ID(cg2, isShared, hasMulticast) @@ -1879,7 +1879,7 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op, llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) { auto curOp = cast(op); - bool is2CTA = curOp.getGroup() == Tcgen05GroupKind::CTA_2; + bool is2CTA = curOp.getGroup() == CTAGroupKind::CTA_2; auto srcFmt = curOp.getSrcFormat(); auto mc = curOp.getMulticast(); diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-alloc.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-alloc.mlir index 6a7e4ac515b81..a8f80296f20ae 100644 --- a/mlir/test/Target/LLVMIR/nvvm/tcgen05-alloc.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-alloc.mlir @@ -6,7 +6,7 @@ llvm.func @llvm_nvvm_tcgen05_alloc(%addr : !llvm.ptr, %ncols : i32) { nvvm.tcgen05.alloc %addr, %ncols : !llvm.ptr, i32 // CHECK-LLVM: call void @llvm.nvvm.tcgen05.alloc.cg2(ptr %{{.*}}, i32 %{{.*}}) - nvvm.tcgen05.alloc %addr, %ncols {group = #nvvm.tcgen05_group} : !llvm.ptr, i32 + nvvm.tcgen05.alloc %addr, %ncols {group = #nvvm.cta_group} : !llvm.ptr, i32 llvm.return } @@ -16,7 +16,7 @@ llvm.func @llvm_nvvm_tcgen05_alloc_shared(%addr : !llvm.ptr<3>, %ncols : i32) { nvvm.tcgen05.alloc %addr, %ncols : !llvm.ptr<3>, i32 // CHECK-LLVM: call void @llvm.nvvm.tcgen05.alloc.shared.cg2(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) - nvvm.tcgen05.alloc %addr, %ncols {group = #nvvm.tcgen05_group} : !llvm.ptr<3>, i32 + nvvm.tcgen05.alloc %addr, %ncols {group = #nvvm.cta_group} : !llvm.ptr<3>, i32 llvm.return } @@ -26,7 +26,7 @@ llvm.func @llvm_nvvm_tcgen05_dealloc(%addr : !llvm.ptr<6>, %ncols : i32) { nvvm.tcgen05.dealloc %addr, %ncols : !llvm.ptr<6>, i32 // CHECK-LLVM: call void @llvm.nvvm.tcgen05.dealloc.cg2(ptr addrspace(6) %{{.*}}, i32 %{{.*}}) - nvvm.tcgen05.dealloc %addr, %ncols {group = #nvvm.tcgen05_group} : !llvm.ptr<6>, i32 + nvvm.tcgen05.dealloc %addr, %ncols {group = #nvvm.cta_group} : !llvm.ptr<6>, i32 llvm.return } @@ -36,6 +36,6 @@ llvm.func @llvm_nvvm_tcgen05_relinquish_alloc_permit() { nvvm.tcgen05.relinquish_alloc_permit // CHECK-LLVM: call void @llvm.nvvm.tcgen05.relinq.alloc.permit.cg2() - nvvm.tcgen05.relinquish_alloc_permit {group = #nvvm.tcgen05_group} + nvvm.tcgen05.relinquish_alloc_permit {group = #nvvm.cta_group} llvm.return } diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-commit.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-commit.mlir index 80cf29f3704c2..60475bf64ae7a 100644 --- a/mlir/test/Target/LLVMIR/nvvm/tcgen05-commit.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-commit.mlir @@ -6,13 +6,13 @@ llvm.func @llvm_nvvm_tcgen05_commit_generic(%barrier : !llvm.ptr, %cta_mask : i1 nvvm.tcgen05.commit %barrier : !llvm.ptr // CHECK-LLVM: call void @llvm.nvvm.tcgen05.commit.cg2(ptr %{{.*}}) - nvvm.tcgen05.commit %barrier {group = #nvvm.tcgen05_group} : !llvm.ptr + nvvm.tcgen05.commit %barrier {group = #nvvm.cta_group} : !llvm.ptr // CHECK-LLVM: call void @llvm.nvvm.tcgen05.commit.mc.cg1(ptr %{{.*}}, i16 %{{.*}}) nvvm.tcgen05.commit %barrier, multicast_mask = %cta_mask : !llvm.ptr, i16 // CHECK-LLVM: call void @llvm.nvvm.tcgen05.commit.mc.cg2(ptr %{{.*}}, i16 %{{.*}}) - nvvm.tcgen05.commit %barrier, multicast_mask = %cta_mask {group = #nvvm.tcgen05_group} : !llvm.ptr, i16 + nvvm.tcgen05.commit %barrier, multicast_mask = %cta_mask {group = #nvvm.cta_group} : !llvm.ptr, i16 llvm.return } @@ -22,12 +22,12 @@ llvm.func @llvm_nvvm_tcgen05_commit_shared(%barrier : !llvm.ptr<3>, %cta_mask : nvvm.tcgen05.commit %barrier : !llvm.ptr<3> // CHECK-LLVM: call void @llvm.nvvm.tcgen05.commit.shared.cg2(ptr addrspace(3) %{{.*}}) - nvvm.tcgen05.commit %barrier {group = #nvvm.tcgen05_group} : !llvm.ptr<3> + nvvm.tcgen05.commit %barrier {group = #nvvm.cta_group} : !llvm.ptr<3> // CHECK-LLVM: call void @llvm.nvvm.tcgen05.commit.mc.shared.cg1(ptr addrspace(3) %{{.*}}, i16 %{{.*}}) nvvm.tcgen05.commit %barrier, multicast_mask = %cta_mask : !llvm.ptr<3>, i16 // CHECK-LLVM: call void @llvm.nvvm.tcgen05.commit.mc.shared.cg2(ptr addrspace(3) %{{.*}}, i16 %{{.*}}) - nvvm.tcgen05.commit %barrier, multicast_mask = %cta_mask {group = #nvvm.tcgen05_group} : !llvm.ptr<3>, i16 + nvvm.tcgen05.commit %barrier, multicast_mask = %cta_mask {group = #nvvm.cta_group} : !llvm.ptr<3>, i16 llvm.return } diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-cp.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-cp.mlir index bf72714d16de7..237b15ba36739 100644 --- a/mlir/test/Target/LLVMIR/nvvm/tcgen05-cp.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-cp.mlir @@ -6,18 +6,18 @@ llvm.func @nvvm_tcgen05_cp_128x256b(%taddr : !llvm.ptr<6>, %smem_desc : i64) { nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape} // CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}}) - nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape, group = #nvvm.tcgen05_group} + nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape, group = #nvvm.cta_group} // CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}}) nvvm.tcgen05.cp %taddr, %smem_desc { shape = #nvvm.tcgen05_cp_shape, - group = #nvvm.tcgen05_group, + group = #nvvm.cta_group, srcFormat = #nvvm.tcgen05_cp_src_fmt } // CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}}) nvvm.tcgen05.cp %taddr, %smem_desc { shape = #nvvm.tcgen05_cp_shape, - group = #nvvm.tcgen05_group, + group = #nvvm.cta_group, srcFormat = #nvvm.tcgen05_cp_src_fmt } llvm.return @@ -29,18 +29,18 @@ llvm.func @nvvm_tcgen05_cp_4x256b(%taddr : !llvm.ptr<6>, %smem_desc : i64) { nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape} // CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}}) - nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape, group = #nvvm.tcgen05_group} + nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape, group = #nvvm.cta_group} // CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}}) nvvm.tcgen05.cp %taddr, %smem_desc { shape = #nvvm.tcgen05_cp_shape, - group = #nvvm.tcgen05_group, + group = #nvvm.cta_group, srcFormat = #nvvm.tcgen05_cp_src_fmt } // CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}}) nvvm.tcgen05.cp %taddr, %smem_desc { shape = #nvvm.tcgen05_cp_shape, - group = #nvvm.tcgen05_group, + group = #nvvm.cta_group, srcFormat = #nvvm.tcgen05_cp_src_fmt } llvm.return @@ -52,18 +52,18 @@ llvm.func @nvvm_tcgen05_cp_128x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) { nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape} // CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}}) - nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape, group = #nvvm.tcgen05_group} + nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape, group = #nvvm.cta_group} // CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}}) nvvm.tcgen05.cp %taddr, %smem_desc { shape = #nvvm.tcgen05_cp_shape, - group = #nvvm.tcgen05_group, + group = #nvvm.cta_group, srcFormat = #nvvm.tcgen05_cp_src_fmt } // CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}}) nvvm.tcgen05.cp %taddr, %smem_desc { shape = #nvvm.tcgen05_cp_shape, - group = #nvvm.tcgen05_group, + group = #nvvm.cta_group, srcFormat = #nvvm.tcgen05_cp_src_fmt } llvm.return @@ -80,21 +80,21 @@ llvm.func @nvvm_tcgen05_cp_64x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) { // CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}}) nvvm.tcgen05.cp %taddr, %smem_desc { shape = #nvvm.tcgen05_cp_shape, - group = #nvvm.tcgen05_group, + group = #nvvm.cta_group, multicast = #nvvm.tcgen05_cp_multicast } // CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.b4x16_p64.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}}) nvvm.tcgen05.cp %taddr, %smem_desc { shape = #nvvm.tcgen05_cp_shape, - group = #nvvm.tcgen05_group, + group = #nvvm.cta_group, multicast = #nvvm.tcgen05_cp_multicast, srcFormat = #nvvm.tcgen05_cp_src_fmt } // CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_01_23.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}}) nvvm.tcgen05.cp %taddr, %smem_desc { shape = #nvvm.tcgen05_cp_shape, - group = #nvvm.tcgen05_group, + group = #nvvm.cta_group, multicast = #nvvm.tcgen05_cp_multicast, srcFormat = #nvvm.tcgen05_cp_src_fmt } @@ -113,21 +113,21 @@ llvm.func @nvvm_tcgen05_cp_32x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) { // CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}}) nvvm.tcgen05.cp %taddr, %smem_desc { shape = #nvvm.tcgen05_cp_shape, - group = #nvvm.tcgen05_group, + group = #nvvm.cta_group, multicast = #nvvm.tcgen05_cp_multicast } // CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}}) nvvm.tcgen05.cp %taddr, %smem_desc { shape = #nvvm.tcgen05_cp_shape, - group = #nvvm.tcgen05_group, + group = #nvvm.cta_group, multicast = #nvvm.tcgen05_cp_multicast, srcFormat = #nvvm.tcgen05_cp_src_fmt } // CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.b6x16_p32.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}}) nvvm.tcgen05.cp %taddr, %smem_desc { shape = #nvvm.tcgen05_cp_shape, - group = #nvvm.tcgen05_group, + group = #nvvm.cta_group, multicast = #nvvm.tcgen05_cp_multicast, srcFormat = #nvvm.tcgen05_cp_src_fmt } diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-shift.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-shift.mlir index 78c50cf96cf90..b8c33516b9135 100644 --- a/mlir/test/Target/LLVMIR/nvvm/tcgen05-shift.mlir +++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-shift.mlir @@ -6,6 +6,6 @@ llvm.func @llvm_nvvm_tcgen05_shift(%taddr : !llvm.ptr<6>) { nvvm.tcgen05.shift %taddr : !llvm.ptr<6> // CHECK: call void @llvm.nvvm.tcgen05.shift.down.cg2(ptr addrspace(6) %{{.*}}) - nvvm.tcgen05.shift %taddr {group = #nvvm.tcgen05_group} : !llvm.ptr<6> + nvvm.tcgen05.shift %taddr {group = #nvvm.cta_group} : !llvm.ptr<6> llvm.return }