Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2130,7 +2130,7 @@ def NVVM_CpAsyncBulkTensorReduceOp :
// NVVM Wgmma Ops
//===----------------------------------------------------------------------===//

def NVVM_WgmmaFenceAlignedOp : NVVM_PTXBuilder_Op<"wgmma.fence.aligned"> {
def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
let arguments = (ins);
let description = [{
Enforce an ordering of register accesses between warpgroup level matrix
Expand All @@ -2139,34 +2139,34 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_PTXBuilder_Op<"wgmma.fence.aligned"> {
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-fence)
}];
let assemblyFormat = "attr-dict";
let extraClassDefinition = [{
std::string $cppClass::getPtx() { return std::string("wgmma.fence.sync.aligned;"); }
string llvmBuilder = [{
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_fence_sync_aligned);
}];
}

def NVVM_WgmmaGroupSyncAlignedOp : NVVM_PTXBuilder_Op<"wgmma.commit.group.sync.aligned">,
def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned">,
Arguments<(ins )> {
let assemblyFormat = "attr-dict";
let description = [{
Commits all prior uncommitted warpgroup level matrix multiplication operations.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-commit-group)
}];
let extraClassDefinition = [{
std::string $cppClass::getPtx() { return std::string("wgmma.commit_group.sync.aligned;"); }
string llvmBuilder = [{
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_commit_group_sync_aligned);
}];
}

def NVVM_WgmmaWaitGroupSyncOp : NVVM_PTXBuilder_Op<"wgmma.wait.group.sync.aligned">{
let arguments = (ins I32Attr:$group);
def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned">{
let arguments = (ins I64Attr:$group);
let assemblyFormat = "attr-dict $group";
let description = [{
Signal the completion of a preceding warpgroup operation.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions-wgmma-wait-group)
}];
let extraClassDefinition = [{
std::string $cppClass::getPtx() { return std::string("wgmma.wait_group.sync.aligned %0;"); }
string llvmBuilder = [{
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_wgmma_wait_group_sync_aligned, builder.getInt64($group));
}];
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,7 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {

let arguments = (ins NVGPU_WarpgroupMatrixDescriptor:$descriptorA,
NVGPU_WarpgroupMatrixDescriptor:$descriptorB,
DefaultValuedOptionalAttr<I32Attr, "1">:$waitGroup,
DefaultValuedOptionalAttr<I64Attr, "1">:$waitGroup,
OptionalAttr<UnitAttr>:$transposeA,
OptionalAttr<UnitAttr>:$transposeB,
NVGPU_WarpgroupAccumulator:$matrixC);
Expand Down
14 changes: 6 additions & 8 deletions mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -266,19 +266,17 @@ func.func @wgmma_execute() {
nvvm.wgmma.fence.aligned
nvvm.wgmma.commit.group.sync.aligned
nvvm.wgmma.wait.group.sync.aligned 0
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;"
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;"
// CHECK: %[[S0:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %[[S0]] : (i32)
// CHECK: nvvm.wgmma.fence.aligned
// CHECK: nvvm.wgmma.commit.group.sync.aligned
// CHECK: nvvm.wgmma.wait.group.sync.aligned 0


nvvm.wgmma.fence.aligned
nvvm.wgmma.commit.group.sync.aligned
nvvm.wgmma.wait.group.sync.aligned 5
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.fence.sync.aligned;"
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.commit_group.sync.aligned;"
// CHECK: %[[S1:.+]] = llvm.mlir.constant(5 : i32) : i32
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "wgmma.wait_group.sync.aligned $0;", "n" %[[S1]] : (i32)
// CHECK: nvvm.wgmma.fence.aligned
// CHECK: nvvm.wgmma.commit.group.sync.aligned
// CHECK: nvvm.wgmma.wait.group.sync.aligned 5
return
}

Expand Down
26 changes: 26 additions & 0 deletions mlir/test/Target/LLVMIR/nvvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -714,3 +714,29 @@ llvm.func @nvvm_breakpoint() {
nvvm.breakpoint
llvm.return
}

// -----
// CHECK-LABEL: @nvvm_wgmma_fence_aligned
// Translation test: `nvvm.wgmma.fence.aligned` takes no operands and is
// expected to be emitted as a direct call to the
// `llvm.nvvm.wgmma.fence.sync.aligned` intrinsic.
llvm.func @nvvm_wgmma_fence_aligned() {
// CHECK: call void @llvm.nvvm.wgmma.fence.sync.aligned()
nvvm.wgmma.fence.aligned
llvm.return
}

// -----
// CHECK-LABEL: @nvvm_wgmma_commit_group_aligned
// Translation test: `nvvm.wgmma.commit.group.sync.aligned` takes no operands
// and is expected to be emitted as a direct call to the
// `llvm.nvvm.wgmma.commit_group.sync.aligned` intrinsic.
llvm.func @nvvm_wgmma_commit_group_aligned() {
// CHECK: call void @llvm.nvvm.wgmma.commit_group.sync.aligned()
nvvm.wgmma.commit.group.sync.aligned
llvm.return
}

// -----
// CHECK-LABEL: @nvvm_wgmma_wait_group_aligned
// Translation test: the integer `group` attribute of
// `nvvm.wgmma.wait.group.sync.aligned` is expected to be passed to the
// `llvm.nvvm.wgmma.wait_group.sync.aligned` intrinsic as a constant i64
// argument. Two different values (0 and 20) are exercised to show that the
// attribute value is forwarded rather than hard-coded.
llvm.func @nvvm_wgmma_wait_group_aligned() {
// CHECK: call void @llvm.nvvm.wgmma.wait_group.sync.aligned(i64 0)
nvvm.wgmma.wait.group.sync.aligned 0
// CHECK: call void @llvm.nvvm.wgmma.wait_group.sync.aligned(i64 20)
nvvm.wgmma.wait.group.sync.aligned 20
llvm.return
}
Loading