Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2344,6 +2344,49 @@ def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap",
}];
}

def NVVM_CpAsyncBulkPrefetchOp : NVVM_Op<"cp.async.bulk.prefetch"> {
let summary = "Async bulk prefetch from global memory to L2 cache";
let description = [{
Initiates an asynchronous prefetch of data from the location
specified by `srcMem` to the L2 cache.

The `l2CacheHint` operand is optional, and it is used to specify cache
eviction policy that may be used during the memory access.

Example:
```mlir
nvvm.cp.async.bulk.prefetch %src, %size : !llvm.ptr<1>

// with l2_cache_hint
nvvm.cp.async.bulk.prefetch %src, %size l2_cache_hint = %ch : !llvm.ptr<1>
```

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-prefetch)
}];

let arguments = (ins
LLVM_PointerGlobal:$srcMem,
I32:$size,
Optional<I64>:$l2CacheHint);

let assemblyFormat = [{
$srcMem `,` $size (`l2_cache_hint` `=` $l2CacheHint^ )?
attr-dict `:` type($srcMem)
}];

let extraClassDeclaration = [{
static mlir::NVVM::IDArgPair
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
llvm::IRBuilderBase& builder);
}];

string llvmBuilder = [{
auto [id, args] = NVVM::CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
*op, moduleTranslation, builder);
createIntrinsicCall(builder, id, args);
}];
}

def NVVM_CpAsyncBulkTensorPrefetchOp :
NVVM_Op<"cp.async.bulk.tensor.prefetch", [AttrSizedOperandSegments]> {
let arguments = (ins
Expand Down
20 changes: 20 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1254,6 +1254,26 @@ CpAsyncOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
return id;
}

mlir::NVVM::IDArgPair CpAsyncBulkPrefetchOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::CpAsyncBulkPrefetchOp>(op);
llvm::SmallVector<llvm::Value *> args;
llvm::Intrinsic::ID id = llvm::Intrinsic::nvvm_cp_async_bulk_prefetch_L2;

// Fill the Intrinsic Args
args.push_back(mt.lookupValue(thisOp.getSrcMem()));
args.push_back(mt.lookupValue(thisOp.getSize()));

mlir::Value cacheHint = thisOp.getL2CacheHint();
const bool hasCacheHint = static_cast<bool>(cacheHint);
llvm::Value *i64Unused =
llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
args.push_back(builder.getInt1(hasCacheHint));

return {id, std::move(args)};
}

mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::CpAsyncBulkSharedCTAToGlobalOp>(op);
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Target/LLVMIR/nvvm/tma_prefetch.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

// CHECK-LABEL: @tma_bulk_prefetch
llvm.func @tma_bulk_prefetch(%src : !llvm.ptr<1>, %size : i32, %ch : i64) {
// CHECK: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
// CHECK: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
nvvm.cp.async.bulk.prefetch %src, %size : !llvm.ptr<1>
nvvm.cp.async.bulk.prefetch %src, %size l2_cache_hint = %ch : !llvm.ptr<1>
llvm.return
}

// CHECK-LABEL: @tma_prefetch_1d
llvm.func @tma_prefetch_1d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
// CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 0, i1 false)
Expand Down