Skip to content

Commit 6ea923b

Browse files
committed
Append cache controls instead of overwriting.
1 parent 61de220 commit 6ea923b

File tree

2 files changed

+103
-1
lines changed

2 files changed

+103
-1
lines changed

llvm/lib/SYCLLowerIR/CompileTimePropertiesPass.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -950,7 +950,10 @@ bool CompileTimePropertiesPass::transformSYCLPropertiesAnnotation(
950950
if (!FPGAProp && llvm::isa<llvm::Instruction>(IntrInst->getArgOperand(0))) {
951951
// If there are no annotations other than cache controls we can apply the
952952
// controls to the pointer and remove the intrinsic.
953-
auto PtrInstr = cast<Instruction>(IntrInst->getArgOperand(0));
953+
Instruction *PtrInstr = cast<Instruction>(IntrInst->getArgOperand(0));
954+
if (MDNode *CurrentMD = PtrInstr->getMetadata(MDKindID))
955+
for (Metadata *Op : CurrentMD->operands())
956+
MDOpsCacheProp.push_back(Op);
954957
PtrInstr->setMetadata(MDKindID, MDTuple::get(Ctx, MDOpsCacheProp));
955958
// Replace all uses of IntrInst with first operand
956959
IntrInst->replaceAllUsesWith(PtrInstr);

sycl/test/check_device_code/extensions/properties/properties_cache_control.cpp

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,15 @@ using namespace sycl;
77
using namespace ext::oneapi::experimental;
88
using namespace ext::intel::experimental;
99

10+
using ST_L1 = cache_control<cache_mode::streaming, cache_level::L1>;
11+
using WB_L1 = cache_control<cache_mode::write_back, cache_level::L1>;
12+
using UC_L1 = cache_control<cache_mode::uncached, cache_level::L1>;
13+
using CA_L1 = cache_control<cache_mode::cached, cache_level::L1>;
14+
using UC_L2 = cache_control<cache_mode::uncached, cache_level::L2>;
15+
using CA_L2 = cache_control<cache_mode::cached, cache_level::L2>;
16+
using UC_L3 = cache_control<cache_mode::uncached, cache_level::L3>;
17+
using CA_L3 = cache_control<cache_mode::cached, cache_level::L3>;
18+
1019
using load_hint = annotated_ptr<
1120
float, decltype(properties(
1221
read_hint<cache_control<cache_mode::cached, cache_level::L1>,
@@ -33,6 +42,22 @@ using load_store_hint = annotated_ptr<
3342
write_hint<
3443
cache_control<cache_mode::write_through, cache_level::L4>>))>;
3544

45+
template <typename t>
46+
using ap_load_ca_uc_uc =
47+
annotated_ptr<t, decltype(properties(read_hint<CA_L1, UC_L2, UC_L3>))>;
48+
49+
template <typename t>
50+
using ap_load_st_ca_uc =
51+
annotated_ptr<t, decltype(properties(read_hint<ST_L1, CA_L2, CA_L3>))>;
52+
53+
template <typename T>
54+
using ap_store_uc_uc_uc =
55+
annotated_ptr<T, decltype(properties(write_hint<UC_L1, UC_L2, UC_L3>))>;
56+
57+
template <typename T>
58+
using ap_store_wb_uc_uc =
59+
annotated_ptr<T, decltype(properties(write_hint<WB_L1, UC_L2, UC_L3>))>;
60+
3661
void cache_control_read_hint_func() {
3762
queue q;
3863
constexpr int N = 10;
@@ -81,6 +106,65 @@ void cache_control_read_write_func() {
81106
});
82107
}
83108

109+
void cache_control_load_store_func() {
110+
queue q(gpu_selector_v);
111+
112+
constexpr int N = 512;
113+
double *x_buf = malloc_device<double>(N, q);
114+
double *y_buf = malloc_device<double>(N, q);
115+
double *d_buf = malloc_device<double>(1, q);
116+
double *d_buf_h = malloc_host<double>(1, q);
117+
118+
q.fill<double>(d_buf, 0.0, 1).wait();
119+
120+
constexpr int SG_SIZE = 16;
121+
122+
q.submit([&](handler &cgh) {
123+
const int nwg = N / SG_SIZE;
124+
auto x = x_buf;
125+
auto y = y_buf;
126+
auto d = d_buf;
127+
auto d_h = d_buf_h;
128+
129+
auto kernel =
130+
[=](nd_item<2> item) [[intel::reqd_sub_group_size(SG_SIZE)]] {
131+
const int global_tid = item.get_global_id(0);
132+
const int row_st = global_tid * SG_SIZE;
133+
134+
if (row_st > N)
135+
return;
136+
137+
const sub_group sgr = item.get_sub_group();
138+
const int sgr_tid = sgr.get_local_id();
139+
140+
ap_store_uc_uc_uc<double> x_s(x);
141+
ap_store_wb_uc_uc<double> y_s(y);
142+
143+
x_s[row_st + sgr_tid] = 1.0;
144+
y_s[row_st + sgr_tid] = 1.0;
145+
146+
group_barrier(sgr);
147+
148+
ap_load_ca_uc_uc<double> x_l(x);
149+
ap_load_st_ca_uc<double> y_l(y);
150+
151+
const double xVal = x_l[row_st + sgr_tid];
152+
const double yVal = y_l[row_st + sgr_tid];
153+
double T = xVal * yVal;
154+
T = reduce_over_group(sgr, T, 0.0, std::plus<>());
155+
156+
if (sgr.leader()) {
157+
atomic_ref<double, memory_order::relaxed, memory_scope::device,
158+
access::address_space::global_space>
159+
d_atomic(d[0]);
160+
d_atomic.fetch_add(T);
161+
}
162+
};
163+
cgh.parallel_for<class write_ker>(
164+
nd_range<2>(range<2>(nwg, SG_SIZE), range<2>(1, SG_SIZE)), kernel);
165+
}).wait();
166+
}
167+
84168
// Test that annotated pointer parameter functions don't crash.
85169
SYCL_EXTERNAL void annotated_ptr_func_param_test(float *p) {
86170
*(store_hint{p}) = 42.0f;
@@ -106,6 +190,11 @@ SYCL_EXTERNAL void annotated_ptr_func_param_test(float *p) {
106190
// CHECK-IR: {{.*}}addrspacecast ptr addrspace(1){{.*}}!spirv.Decorations [[RWHINT:.*]]
107191
// CHECK-IR: ret void
108192

193+
// CHECK-IR: spir_kernel{{.*}}cache_control_load_store_func
194+
// CHECK-IR: {{.*}}getelementptr{{.*}}addrspace(4){{.*}}!spirv.Decorations [[LDSTHINT_A:.*]]
195+
// CHECK-IR: {{.*}}getelementptr{{.*}}addrspace(4){{.*}}!spirv.Decorations [[LDSTHINT_B:.*]]
196+
// CHECK-IR: ret void
197+
109198
// CHECK-IR: [[WHINT]] = !{[[WHINT1:.*]], [[WHINT2:.*]], [[WHINT3:.*]], [[WHINT4:.*]]}
110199
// CHECK-IR: [[WHINT1]] = !{i32 6443, i32 3, i32 3}
111200
// CHECK-IR: [[WHINT2]] = !{i32 6443, i32 0, i32 1}
@@ -126,3 +215,13 @@ SYCL_EXTERNAL void annotated_ptr_func_param_test(float *p) {
126215
// CHECK-IR: [[RWHINT1]] = !{i32 6442, i32 2, i32 1}
127216
// CHECK-IR: [[RWHINT2]] = !{i32 6442, i32 3, i32 4}
128217
// CHECK-IR: [[RWHINT3]] = !{i32 6443, i32 3, i32 1}
218+
219+
// CHECK-IR: [[LDSTHINT_A]] = !{[[RHINT1]], [[RHINT2]], [[RHINT3]], [[LDSTHINT_A1:.*]], [[LDSTHINT_A2:.*]], [[LDSTHINT_A3:.*]]}
220+
// CHECK-IR: [[LDSTHINT_A1]] = !{i32 6443, i32 0, i32 0}
221+
// CHECK-IR: [[LDSTHINT_A2]] = !{i32 6443, i32 1, i32 0}
222+
// CHECK-IR: [[LDSTHINT_A3]] = !{i32 6443, i32 2, i32 0}
223+
224+
// CHECK-IR: [[LDSTHINT_B]] = !{[[LDSTHINT_B1:.*]], [[RWHINT1]], [[LDSTHINT_B2:.*]], [[LDSTHINT_A2]], [[LDSTHINT_A3]], [[LDSTHINT_B3:.*]]}
225+
// CHECK-IR: [[LDSTHINT_B1]] = !{i32 6442, i32 1, i32 1}
226+
// CHECK-IR: [[LDSTHINT_B2]] = !{i32 6442, i32 0, i32 2}
227+
// CHECK-IR: [[LDSTHINT_B3]] = !{i32 6443, i32 0, i32 2}

0 commit comments

Comments
 (0)