Skip to content

Commit f1032f0

Browse files
authored
[MLIR][NVVM][NVGPU] Combine prefetch and prefetch.tensormap (#153134)
This PR combines the `prefetch` and `prefetch.tensormap` NVVM Ops to one `prefetch` Op. The `tensormap` variant is lowered through the newly added intrinsics. The lowering of the NVGPU `tma.prefetch.descriptor` Op is changed from lowering to the `prefetch.tensormap` Op to `prefetch`. PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu
1 parent 10e5ec8 commit f1032f0

File tree

8 files changed

+245
-75
lines changed

8 files changed

+245
-75
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ include "mlir/Dialect/LLVMIR/LLVMTypes.td"
2525
def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
2626
def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
2727
def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>;
28+
def LLVM_PointerConst : LLVM_PointerInAddressSpace<4>;
2829
def LLVM_PointerLocal : LLVM_PointerInAddressSpace<5>;
2930
def LLVM_PointerTensor : LLVM_PointerInAddressSpace<6>;
3031
def LLVM_PointerSharedCluster : LLVM_PointerInAddressSpace<7>;
@@ -2570,15 +2571,25 @@ def PrefetchCacheLevelAttr : EnumAttr<NVVM_Dialect, PrefetchCacheLevel, "prefetc
25702571
let assemblyFormat = "$value";
25712572
}
25722573

2573-
def NVVM_PrefetchOp : NVVM_Op<"prefetch"> {
2574+
def NVVM_PrefetchOp : NVVM_Op<"prefetch",
2575+
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
25742576
let summary = "Brings the cache line containing an address into the specified cache level";
25752577
let description = [{
2576-
Operand `addr` can be a global, local or generic address pointer. No
2577-
operation is performed if `addr` maps to a `shared` memory location.
2578+
Prefetches the cache line containing the address given by `addr`. The
2579+
operand may be a global, local, or generic pointer. When `tensormap` is
2580+
specified, the operand may instead be a constant or generic pointer. If the
2581+
address maps to shared memory, the operation has no effect.
2582+
2583+
At most one of `cacheLevel` or `tensormap` may be present. The `cacheLevel`
2584+
attribute selects the target cache level. When combined with `uniform`, the
2585+
prefetch is performed to the uniform cache, in which case `addr` must be a
2586+
generic pointer.
2587+
2588+
When `tensormap` is used, the line containing `addr` is brought from the
2589+
constant or parameter state space for later use by `cp.async.bulk.tensor`.
2590+
If `in_param_space` is specified, the generic pointer is interpreted as
2591+
referring to the parameter state space.
25782592

2579-
The `cacheLevel` attribute specifies the cache level to which the cache line
2580-
containing the specified address is brought.
2581-
25822593
`uniform` can be specified after the `cacheLevel` to indicate that the
25832594
prefetch is performed to the specified uniform cache level. If `uniform` is
25842595
specified, `addr` must be a generic address pointer and no operation is
@@ -2589,33 +2600,41 @@ def NVVM_PrefetchOp : NVVM_Op<"prefetch"> {
25892600

25902601
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu)
25912602
}];
2592-
let arguments = (ins PrefetchCacheLevelAttr:$cacheLevel,
2593-
UnitAttr:$uniform,
2603+
let arguments = (ins OptionalAttr<PrefetchCacheLevelAttr>:$cacheLevel,
2604+
OptionalAttr<CacheEvictionPriorityAttr>:$evictPriority,
25942605
AnyTypeOf<[LLVM_PointerGlobal,
25952606
LLVM_PointerLocal,
2596-
LLVM_PointerGeneric]>:$addr,
2597-
OptionalAttr<CacheEvictionPriorityAttr>:$evictPriority);
2598-
let assemblyFormat = "`level` `=` $cacheLevel (`uniform` $uniform^)? `,` $addr (`,` `evict_priority` `=` $evictPriority^)? attr-dict `:` type($addr)";
2607+
LLVM_PointerGeneric,
2608+
LLVM_PointerConst]>:$addr,
2609+
PtxPredicate:$predicate,
2610+
UnitAttr:$tensormap,
2611+
UnitAttr:$uniform,
2612+
UnitAttr:$in_param_space);
2613+
let assemblyFormat = "(`level` `=` $cacheLevel^ (`uniform` $uniform^)? `,`)? (`tensormap` $tensormap^ (`in_param_space` $in_param_space^)? `,`)? (`evict_priority` `=` $evictPriority^ `,`)? $addr (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
25992614
let hasVerifier = 1;
26002615

26012616
let extraClassDeclaration = [{
2602-
static llvm::Intrinsic::ID getIntrinsicID(NVVM::PrefetchOp &op);
2603-
}];
2604-
let llvmBuilder = [{
2605-
auto intId = NVVM::PrefetchOp::getIntrinsicID(op);
2606-
createIntrinsicCall(builder, intId, $addr);
2617+
static NVVM::IDArgPair
2618+
getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,LLVM::ModuleTranslation &mt,
2619+
llvm::IRBuilderBase &builder);
2620+
bool hasIntrinsic() { return !getPredicate() || !getTensormap(); }
26072621
}];
2608-
}
2609-
2610-
def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap",
2611-
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
2612-
Arguments<(ins LLVM_AnyPointer:$tmaDescriptor, PtxPredicate:$predicate)> {
2613-
let assemblyFormat = "$tmaDescriptor (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
26142622
let extraClassDefinition = [{
2615-
std::string $cppClass::getPtx() {
2623+
std::string $cppClass::getPtx() {
2624+
// Inline PTX is only supported for prefetch tensormap
26162625
return std::string("prefetch.tensormap [%0];");
26172626
}
26182627
}];
2628+
let llvmBuilder = [{
2629+
auto [id, args] = NVVM::PrefetchOp::getIntrinsicIDAndArgs(op,
2630+
moduleTranslation, builder);
2631+
2632+
if(op.getTensormap())
2633+
// Overloaded intrinsic
2634+
createIntrinsicCall(builder, id, args, {args[0]->getType()});
2635+
else
2636+
createIntrinsicCall(builder, id, args);
2637+
}];
26192638
}
26202639

26212640
def NVVM_CpAsyncBulkPrefetchOp : NVVM_Op<"cp.async.bulk.prefetch"> {

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1695,8 +1695,10 @@ struct NVGPUTmaPrefetchOpLowering
16951695
LogicalResult
16961696
matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
16971697
ConversionPatternRewriter &rewriter) const override {
1698-
rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
1699-
op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
1698+
rewriter.replaceOpWithNewOp<NVVM::PrefetchOp>(
1699+
op, /* CacheLevel */ nullptr, /* Cache Eviction Priority */ nullptr,
1700+
adaptor.getTensorMapDescriptor(), adaptor.getPredicate(),
1701+
/* Tensormap UnitAttr */ mlir::UnitAttr::get(op.getContext()));
17001702
return success();
17011703
}
17021704
};

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 99 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "llvm/IR/IRBuilder.h"
3434
#include "llvm/Support/Casting.h"
3535
#include "llvm/Support/FormatVariadic.h"
36+
#include "llvm/Support/NVPTXAddrSpace.h"
3637
#include "llvm/Support/raw_ostream.h"
3738
#include <cassert>
3839
#include <optional>
@@ -1332,30 +1333,70 @@ LogicalResult NVVM::PrefetchOp::verify() {
13321333
unsigned addressSpace =
13331334
llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
13341335
std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
1336+
std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();
13351337

1336-
if (getUniform()) {
1337-
if (getCacheLevel() != CacheLevel::L1)
1338-
return emitOpError("unsupported cache level, the only supported uniform "
1339-
"cache level is L1");
1338+
if (getTensormap() && cacheLevel)
1339+
return emitOpError("cannot specify both tensormap and cache level");
13401340

1341-
if (addressSpace != MemSpace::kGenericMemorySpace)
1341+
if (getTensormap()) {
1342+
if (addressSpace != MemSpace::kGenericMemorySpace &&
1343+
addressSpace != MemSpace::kConstantMemorySpace) {
13421344
return emitOpError(
1343-
"prefetch to uniform cache requires a generic pointer");
1344-
}
1345+
"prefetch tensormap requires a generic or constant pointer");
1346+
}
13451347

1346-
if (evictPriority) {
1347-
if (getCacheLevel() != CacheLevel::L2)
1348+
if (evictPriority) {
13481349
return emitOpError(
1349-
"cache eviction priority supported only for cache level L2");
1350-
1351-
if (addressSpace != MemSpace::kGlobalMemorySpace)
1352-
return emitOpError("cache eviction priority requires a global pointer");
1350+
"prefetch tensormap does not support eviction priority");
1351+
}
13531352

1354-
if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
1355-
*evictPriority != NVVM::CacheEvictionPriority::EvictLast)
1353+
if (getInParamSpace() && addressSpace != MemSpace::kGenericMemorySpace) {
13561354
return emitOpError(
1357-
"unsupported cache eviction priority, only evict_last and "
1358-
"evict_normal are supported");
1355+
"in_param_space can only be specified for a generic pointer");
1356+
}
1357+
1358+
} else if (cacheLevel) {
1359+
if (addressSpace != MemSpace::kGenericMemorySpace &&
1360+
addressSpace != MemSpace::kGlobalMemorySpace &&
1361+
addressSpace != MemSpace::kLocalMemorySpace) {
1362+
return emitOpError("prefetch to cache level requires a generic, global, "
1363+
"or local pointer");
1364+
}
1365+
1366+
if (getUniform()) {
1367+
if (*cacheLevel != CacheLevel::L1) {
1368+
return emitOpError(
1369+
"unsupported cache level, the only supported uniform "
1370+
"cache level is L1");
1371+
}
1372+
1373+
if (addressSpace != MemSpace::kGenericMemorySpace) {
1374+
return emitOpError(
1375+
"prefetch to uniform cache requires a generic pointer");
1376+
}
1377+
}
1378+
1379+
if (evictPriority) {
1380+
if (*cacheLevel != CacheLevel::L2)
1381+
return emitOpError(
1382+
"cache eviction priority supported only for cache level L2");
1383+
1384+
if (addressSpace != MemSpace::kGlobalMemorySpace)
1385+
return emitOpError("cache eviction priority requires a global pointer");
1386+
1387+
if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
1388+
*evictPriority != NVVM::CacheEvictionPriority::EvictLast)
1389+
return emitOpError(
1390+
"unsupported cache eviction priority, only evict_last and "
1391+
"evict_normal are supported");
1392+
}
1393+
1394+
if (getPredicate())
1395+
return emitOpError("predicate supported only on prefetch tensormap");
1396+
1397+
} else {
1398+
return emitOpError(
1399+
"requires specification of either cache level or tensormap");
13591400
}
13601401

13611402
return success();
@@ -1964,43 +2005,69 @@ NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
19642005
return {ids[type], args};
19652006
}
19662007

1967-
llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) {
2008+
static llvm::Value *getParamCastedAddr(llvm::Value *addr,
2009+
llvm::IRBuilderBase &builder) {
2010+
return builder.CreateAddrSpaceCast(
2011+
addr,
2012+
llvm::PointerType::get(builder.getContext(),
2013+
llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
2014+
}
2015+
2016+
NVVM::IDArgPair
2017+
PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
2018+
LLVM::ModuleTranslation &mt,
2019+
llvm::IRBuilderBase &builder) {
19682020
using MemSpace = NVVM::NVVMMemorySpace;
19692021
using CacheLevel = NVVM::PrefetchCacheLevel;
19702022

1971-
NVVM::PrefetchCacheLevel cacheLevel = op.getCacheLevel();
2023+
std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
19722024
std::optional<NVVM::CacheEvictionPriority> evictPriority =
19732025
op.getEvictPriority();
19742026
unsigned addressSpace =
19752027
llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
19762028
.getAddressSpace();
19772029

1978-
if (op.getUniform() && cacheLevel == CacheLevel::L1)
1979-
return llvm::Intrinsic::nvvm_prefetchu_L1;
2030+
llvm::SmallVector<llvm::Value *> args;
2031+
llvm::Value *addr = mt.lookupValue(op.getAddr());
2032+
args.push_back(op.getInParamSpace() ? getParamCastedAddr(addr, builder)
2033+
: addr);
2034+
2035+
if (op.getTensormap())
2036+
return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};
2037+
2038+
assert(cacheLevel && "expected cache level for non-tensormap prefetch");
2039+
2040+
if (op.getUniform() && *cacheLevel == CacheLevel::L1)
2041+
return {llvm::Intrinsic::nvvm_prefetchu_L1, args};
19802042

1981-
if (evictPriority && cacheLevel == CacheLevel::L2) {
2043+
if (evictPriority && *cacheLevel == CacheLevel::L2) {
19822044
switch (*evictPriority) {
19832045
case NVVM::CacheEvictionPriority::EvictLast:
1984-
return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last;
2046+
return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
19852047
case NVVM::CacheEvictionPriority::EvictNormal:
1986-
return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal;
2048+
return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
19872049
default:
19882050
llvm_unreachable("Invalid cache eviction priority");
19892051
}
19902052
}
19912053

19922054
switch (addressSpace) {
19932055
case MemSpace::kGenericMemorySpace:
1994-
return cacheLevel == CacheLevel::L1 ? llvm::Intrinsic::nvvm_prefetch_L1
1995-
: llvm::Intrinsic::nvvm_prefetch_L2;
2056+
return *cacheLevel == CacheLevel::L1
2057+
? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
2058+
: NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
19962059
case MemSpace::kGlobalMemorySpace:
1997-
return cacheLevel == CacheLevel::L1
1998-
? llvm::Intrinsic::nvvm_prefetch_global_L1
1999-
: llvm::Intrinsic::nvvm_prefetch_global_L2;
2060+
return *cacheLevel == CacheLevel::L1
2061+
? NVVM::IDArgPair(
2062+
{llvm::Intrinsic::nvvm_prefetch_global_L1, args})
2063+
: NVVM::IDArgPair(
2064+
{llvm::Intrinsic::nvvm_prefetch_global_L2, args});
20002065
case MemSpace::kLocalMemorySpace:
2001-
return cacheLevel == CacheLevel::L1
2002-
? llvm::Intrinsic::nvvm_prefetch_local_L1
2003-
: llvm::Intrinsic::nvvm_prefetch_local_L2;
2066+
return *cacheLevel == CacheLevel::L1
2067+
? NVVM::IDArgPair(
2068+
{llvm::Intrinsic::nvvm_prefetch_local_L1, args})
2069+
: NVVM::IDArgPair(
2070+
{llvm::Intrinsic::nvvm_prefetch_local_L2, args});
20042071
default:
20052072
llvm_unreachable("Invalid pointer address space");
20062073
}

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -817,9 +817,9 @@ func.func @create_tensor_map(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : m
817817
// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.tensormap.descriptor<tensor = memref<128xf32, 3>, swizzle = none, l2promo = none, oob = nan, interleave = none>, %[[arg1:[a-zA-Z0-9_]+]]: i1
818818
func.func @tma_prefetch(%tensorMap1d: !tensorMap1d, %p : i1) {
819819
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.tensormap.descriptor<tensor = memref<128xf32, 3>, swizzle = none, l2promo = none, oob = nan, interleave = none> to !llvm.ptr
820-
// CHECK: nvvm.prefetch.tensormap %[[S0]] : !llvm.ptr
820+
// CHECK: nvvm.prefetch tensormap, %[[S0]] : !llvm.ptr
821821
nvgpu.tma.prefetch.descriptor %tensorMap1d: !tensorMap1d
822-
// CHECK: nvvm.prefetch.tensormap %[[S0]], predicate = %[[arg1]] : !llvm.ptr, i1
822+
// CHECK: nvvm.prefetch tensormap, %[[S0]], predicate = %[[arg1]] : !llvm.ptr, i1
823823
nvgpu.tma.prefetch.descriptor %tensorMap1d, predicate = %p: !tensorMap1d
824824
func.return
825825
}

mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -572,10 +572,10 @@ func.func @elect_one_leader_sync() {
572572

573573
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
574574
llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
575-
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"
576-
nvvm.prefetch.tensormap %desc : !llvm.ptr
575+
//CHECK: nvvm.prefetch tensormap, %{{.*}}
576+
nvvm.prefetch tensormap, %desc : !llvm.ptr
577577
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$1 prefetch.tensormap [$0];", "l,b"
578-
nvvm.prefetch.tensormap %desc, predicate = %pred : !llvm.ptr, i1
578+
nvvm.prefetch tensormap, %desc, predicate = %pred : !llvm.ptr, i1
579579
llvm.return
580580
}
581581

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,7 @@ func.func @dot_accumulate_2way(%a_vec: vector<2xi16>, %b_vec: vector<4xi8>, %c:
586586
}
587587

588588
// CHECK-LABEL: @prefetch
589-
func.func @prefetch(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>) {
589+
func.func @prefetch(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>, %const_ptr: !llvm.ptr<4>) {
590590
// CHECK: nvvm.prefetch level = L1, %{{.*}}
591591
nvvm.prefetch level = L1, %gen_ptr : !llvm.ptr<0>
592592
// CHECK: nvvm.prefetch level = L1, %{{.*}}
@@ -599,12 +599,24 @@ func.func @prefetch(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr:
599599
nvvm.prefetch level = L2, %local_ptr : !llvm.ptr<5>
600600
// CHECK: nvvm.prefetch level = L2, %{{.*}}
601601
nvvm.prefetch level = L2, %global_ptr : !llvm.ptr<1>
602-
// CHECK: nvvm.prefetch level = L2, %{{.*}}
603-
nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_last : !llvm.ptr<1>
604-
// CHECK: nvvm.prefetch level = L2, %{{.*}}
605-
nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_normal : !llvm.ptr<1>
602+
// CHECK: nvvm.prefetch level = L2, evict_priority = evict_last, %{{.*}}
603+
nvvm.prefetch level = L2, evict_priority = evict_last, %global_ptr :
604+
!llvm.ptr<1>
605+
// CHECK: nvvm.prefetch level = L2, evict_priority = evict_normal, %{{.*}}
606+
nvvm.prefetch level = L2, evict_priority = evict_normal, %global_ptr : !llvm.ptr<1>
606607
// CHECK: nvvm.prefetch level = L1 uniform, %{{.*}}
607608
nvvm.prefetch level = L1 uniform, %gen_ptr : !llvm.ptr
609+
// CHECK: nvvm.prefetch tensormap, %{{.*}}
610+
nvvm.prefetch tensormap, %gen_ptr : !llvm.ptr
611+
// CHECK: nvvm.prefetch tensormap, %{{.*}}
612+
nvvm.prefetch tensormap, %const_ptr : !llvm.ptr<4>
613+
// CHECK: nvvm.prefetch tensormap in_param_space, %{{.*}}
614+
nvvm.prefetch tensormap in_param_space, %gen_ptr : !llvm.ptr
615+
return
616+
}
617+
618+
// CHECK-LABEL: @prefetch_tensormap
619+
func.func @prefetch_tensormap(%gen_ptr: !llvm.ptr, %const_ptr: !llvm.ptr<4>) {
608620
return
609621
}
610622

mlir/test/Target/LLVMIR/nvvm/prefetch.mlir

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ llvm.func @prefetch_L2_eviction_priority(%global_ptr: !llvm.ptr<1>) {
3232
// CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L2.evict.normal(ptr addrspace(1) %0)
3333
// CHECK-NEXT: ret void
3434
// CHECK-NEXT: }
35-
nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_last : !llvm.ptr<1>
36-
nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_normal : !llvm.ptr<1>
35+
nvvm.prefetch level = L2, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>
36+
nvvm.prefetch level = L2, evict_priority = evict_normal, %global_ptr : !llvm.ptr<1>
3737
llvm.return
3838
}
3939

@@ -45,3 +45,17 @@ llvm.func @prefetch_L1_uniform(%gen_ptr: !llvm.ptr) {
4545
nvvm.prefetch level = L1 uniform, %gen_ptr : !llvm.ptr
4646
llvm.return
4747
}
48+
49+
llvm.func @prefetch_tensormap(%gen_ptr: !llvm.ptr, %const_ptr: !llvm.ptr<4>) {
50+
// CHECK-LABEL: define void @prefetch_tensormap(ptr %0, ptr addrspace(4) %1) {
51+
// CHECK-NEXT: call void @llvm.nvvm.prefetch.tensormap.p0(ptr %0)
52+
// CHECK-NEXT: call void @llvm.nvvm.prefetch.tensormap.p4(ptr addrspace(4) %1)
53+
// CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(101)
54+
// CHECK-NEXT: call void @llvm.nvvm.prefetch.tensormap.p101(ptr addrspace(101) %3)
55+
// CHECK-NEXT: ret void
56+
// CHECK-NEXT: }
57+
nvvm.prefetch tensormap, %gen_ptr : !llvm.ptr
58+
nvvm.prefetch tensormap, %const_ptr: !llvm.ptr<4>
59+
nvvm.prefetch tensormap in_param_space, %gen_ptr : !llvm.ptr
60+
llvm.return
61+
}

0 commit comments

Comments
 (0)