Skip to content

Commit 40815be

Browse files
authored
[MLIR][NVVM] Add support for st.bulk Op (llvm#131727)
This change adds the `st.bulk` NVVM Op for the `st.bulk` instruction introduced in ptx8.6 for sm_100. PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st-bulk
1 parent 7a5ce55 commit 40815be

File tree

6 files changed

+77
-0
lines changed

6 files changed

+77
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2630,6 +2630,36 @@ def NVVM_MatchSyncOp : NVVM_Op<"match.sync">,
26302630
let hasVerifier = 1;
26312631
}
26322632

2633+
//===----------------------------------------------------------------------===//
2634+
// NVVM Bulk Store Op
2635+
//===----------------------------------------------------------------------===//
2636+
2637+
def NVVM_BulkStoreOp: NVVM_Op<"st.bulk"> {
2638+
let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr, I64:$size, DefaultValuedAttr<I64Attr, "0">:$initVal);
2639+
2640+
let summary = "Bulk Store Op";
2641+
let description = [{
2642+
Initializes a region of shared memory at the address given by `addr`.
2643+
The `size` operand specifies the number of bytes to initialize and must be
2644+
a multiple of 8.
2645+
The `initVal` operand specifies the value to initialize the memory to. The
2646+
only supported value is 0.
2647+
2648+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st-bulk)
2649+
}];
2650+
2651+
string llvmBuilder = [{
2652+
auto intId = getStBulkIntrinsicId(
2653+
llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType()));
2654+
createIntrinsicCall(builder, intId,
2655+
{$addr, $size, builder.getInt64($initVal)});
2656+
}];
2657+
2658+
let assemblyFormat = "$addr `,` `size` `=` $size (`,` `init` `=` $initVal^)? attr-dict `:` type($addr)";
2659+
2660+
let hasVerifier = 1;
2661+
}
2662+
26332663
def NVVM_Exit : NVVM_Op<"exit"> {
26342664
let summary = "Exit Op";
26352665
let description = [{

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,12 @@ LogicalResult CvtFloatToTF32Op::verify() {
160160
return success();
161161
}
162162

163+
LogicalResult BulkStoreOp::verify() {
164+
if (getInitVal() != 0)
165+
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
166+
return success();
167+
}
168+
163169
// Given the element type of an operand and whether or not it is an accumulator,
164170
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
165171
// operand's element type.

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,15 @@ static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
150150
}
151151
}
152152

153+
/// Return the intrinsic ID associated with st.bulk for the given address type.
154+
static llvm::Intrinsic::ID
155+
getStBulkIntrinsicId(LLVM::LLVMPointerType addrType) {
156+
bool isSharedMemory =
157+
addrType.getAddressSpace() == NVVM::NVVMMemorySpace::kSharedMemorySpace;
158+
return isSharedMemory ? llvm::Intrinsic::nvvm_st_bulk_shared_cta
159+
: llvm::Intrinsic::nvvm_st_bulk;
160+
}
161+
153162
static unsigned getUnidirectionalFenceProxyID(NVVM::ProxyKind fromProxy,
154163
NVVM::ProxyKind toProxy,
155164
NVVM::MemScopeKind scope,

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,15 @@ func.func @match_sync(%val32: i32, %val64: i64, %thread_mask: i32) {
563563
return
564564
}
565565

566+
// CHECK-LABEL: @st_bulk
567+
func.func @st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size: i64) {
568+
// CHECK: nvvm.st.bulk %{{.*}}, size = %{{.*}} : !llvm.ptr
569+
nvvm.st.bulk %addr_gen, size = %size, init = 0 : !llvm.ptr
570+
// CHECK: nvvm.st.bulk %{{.*}}, size = %{{.*}} : !llvm.ptr<3>
571+
nvvm.st.bulk %addr_shared, size = %size, init = 0 : !llvm.ptr<3>
572+
return
573+
}
574+
566575
// -----
567576

568577
// Just check these don't emit errors.

mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,14 @@ llvm.func @convert_float_to_tf32_no_rnd_mode(%src : f32) -> i32 {
125125

126126
// -----
127127

128+
llvm.func @nvvm_st_bulk_initval_nonzero(%addr : !llvm.ptr, %size : i64) {
129+
// expected-error @below {{only 0 is supported for initVal, got 1}}
130+
nvvm.st.bulk %addr, size = %size, init = 1 : !llvm.ptr
131+
llvm.return
132+
}
133+
134+
// -----
135+
128136
llvm.func @nvvm_tcgen05_cp_128x256b_mc(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
129137
// expected-error @below {{Invalid multicast type for tcgen05.cp Op}}
130138
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>, multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>}

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,7 @@ llvm.func @nvvm_redux_sync_f32(%value: f32, %offset: i32) {
811811
llvm.return
812812
}
813813

814+
// -----
814815
// CHECK-LABEL: @nvvm_match_sync
815816
llvm.func @nvvm_match_sync(%mask: i32, %val32: i32, %val64: i64) {
816817
// CHECK: call i32 @llvm.nvvm.match.any.sync.i32(i32 %{{.*}}, i32 %{{.*}})
@@ -823,3 +824,17 @@ llvm.func @nvvm_match_sync(%mask: i32, %val32: i32, %val64: i64) {
823824
%3 = nvvm.match.sync all %mask, %val64 : i64 -> !llvm.struct<(i32, i1)>
824825
llvm.return
825826
}
827+
828+
// -----
829+
// CHECK-LABEL: @nvvm_st_bulk
830+
llvm.func @nvvm_st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size: i64) {
831+
// CHECK: call void @llvm.nvvm.st.bulk(ptr %{{.*}}, i64 %{{.*}}, i64 0)
832+
nvvm.st.bulk %addr_gen, size = %size : !llvm.ptr
833+
// CHECK: call void @llvm.nvvm.st.bulk.shared.cta(ptr addrspace(3) %{{.*}}, i64 %{{.*}}, i64 0)
834+
nvvm.st.bulk %addr_shared, size = %size: !llvm.ptr<3>
835+
// CHECK: call void @llvm.nvvm.st.bulk(ptr %{{.*}}, i64 %{{.*}}, i64 0)
836+
nvvm.st.bulk %addr_gen, size = %size, init = 0 : !llvm.ptr
837+
// CHECK: call void @llvm.nvvm.st.bulk.shared.cta(ptr addrspace(3) %{{.*}}, i64 %{{.*}}, i64 0)
838+
nvvm.st.bulk %addr_shared, size = %size, init = 0: !llvm.ptr<3>
839+
llvm.return
840+
}

0 commit comments

Comments
 (0)