Skip to content

Commit 1a454b0

Browse files
committed
Introduce "stackable" resources for to improve how we nest contexts
1 parent f0acf73 commit 1a454b0

File tree

4 files changed

+512
-0
lines changed

4 files changed

+512
-0
lines changed

cudax/examples/stf/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ set(stf_example_sources
1111
08-cub-reduce.cu
1212
axpy-annotated.cu
1313
binary_fhe.cu
14+
binary_fhe_stackable.cu
1415
cfd.cu
1516
custom_data_interface.cu
1617
void_data_interface.cu
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of CUDASTF in CUDA C++ Core Libraries,
4+
// under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
// SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
/**
12+
* @file
13+
* @brief A toy example to illustrate how we can compose logical operations
14+
* over encrypted data
15+
*/
16+
17+
#include "cuda/experimental/__stf/utility/stackable_ctx.cuh"
18+
#include "cuda/experimental/stf.cuh"
19+
20+
using namespace cuda::experimental::stf;
21+
22+
class ciphertext;
23+
24+
class plaintext {
25+
public:
26+
plaintext(const stackable_ctx& ctx) : ctx(ctx) {}
27+
28+
plaintext(stackable_ctx& ctx, std::vector<char> v) : values(v), ctx(ctx) {
29+
l = ctx.logical_data(&values[0], values.size());
30+
}
31+
32+
void set_symbol(std::string s) {
33+
l.set_symbol(s);
34+
symbol = s;
35+
}
36+
37+
std::string get_symbol() const { return symbol; }
38+
39+
std::string symbol;
40+
41+
const stackable_logical_data<slice<char>>& data() const { return l; }
42+
43+
stackable_logical_data<slice<char>>& data() { return l; }
44+
45+
// This will asynchronously fill string s
46+
void convert_to_vector(std::vector<char>& v) {
47+
ctx.host_launch(l.read()).set_symbol("to_vector")->*[&](auto dl) {
48+
v.resize(dl.size());
49+
for (size_t i = 0; i < dl.size(); i++) {
50+
v[i] = dl(i);
51+
}
52+
};
53+
}
54+
55+
ciphertext encrypt() const;
56+
57+
stackable_logical_data<slice<char>> l;
58+
59+
template <typename... Pack>
60+
void push(Pack&&... pack) {
61+
l.push(::std::forward<Pack>(pack)...);
62+
}
63+
64+
void pop() { l.pop(); }
65+
66+
private:
67+
std::vector<char> values;
68+
mutable stackable_ctx ctx;
69+
};
70+
71+
class ciphertext {
72+
public:
73+
ciphertext() = default;
74+
75+
ciphertext(const stackable_ctx& ctx) : ctx(ctx) {}
76+
77+
plaintext decrypt() const {
78+
plaintext p(ctx);
79+
p.l = ctx.logical_data(shape_of<slice<char>>(l.shape().size()));
80+
// fprintf(stderr, "Decrypting...\n");
81+
ctx.parallel_for(l.shape(), l.read(), p.l.write()).set_symbol("decrypt")->*
82+
[] __device__ (size_t i, auto dctxt, auto dptxt) {
83+
dptxt(i) = char((dctxt(i) >> 32));
84+
// printf("DECRYPT %ld : %lx -> %x\n", i, dctxt(i), (int) dptxt(i));
85+
};
86+
return p;
87+
}
88+
89+
// Copy assignment operator
90+
ciphertext& operator=(const ciphertext& other) {
91+
if (this != &other) {
92+
fprintf(stderr, "COPY ASSIGNMENT OP... this->l.depth() %ld other.l.depth() %ld - ctx depth %ld other.ctx.depth %ld\n", l.depth(), other.l.depth(), ctx.depth(), other.ctx.depth());
93+
// l = ctx.logical_data(other.data().shape());
94+
assert(l.shape() == other.l.shape());
95+
other.ctx.parallel_for(l.shape(), other.l.read(), l.write()).set_symbol("copy")->*
96+
[] __device__ (size_t i, auto other, auto result) { result(i) = other(i); };
97+
}
98+
return *this;
99+
}
100+
101+
ciphertext operator|(const ciphertext& other) const {
102+
ciphertext result(ctx);
103+
result.l = ctx.logical_data(data().shape());
104+
105+
ctx.parallel_for(data().shape(), data().read(), other.data().read(), result.data().write()).set_symbol("OR")->*
106+
[] __device__(size_t i, auto d_c1, auto d_c2, auto d_res) { d_res(i) = d_c1(i) | d_c2(i); };
107+
108+
return result;
109+
}
110+
111+
ciphertext operator&(const ciphertext& other) const {
112+
ciphertext result(ctx);
113+
result.l = ctx.logical_data(data().shape());
114+
115+
ctx.parallel_for(data().shape(), data().read(), other.data().read(), result.data().write()).set_symbol("AND")->*
116+
[] __device__(size_t i, auto d_c1, auto d_c2, auto d_res) { d_res(i) = d_c1(i) & d_c2(i); };
117+
118+
return result;
119+
}
120+
121+
ciphertext operator~() const {
122+
ciphertext result(ctx);
123+
result.l = ctx.logical_data(data().shape());
124+
ctx.parallel_for(data().shape(), data().read(), result.data().write()).set_symbol("NOT")->*
125+
[] __device__(size_t i, auto d_c, auto d_res) { d_res(i) = ~d_c(i); };
126+
127+
return result;
128+
}
129+
130+
const stackable_logical_data<slice<uint64_t>>& data() const { return l; }
131+
132+
stackable_logical_data<slice<uint64_t>>& data() { return l; }
133+
134+
stackable_logical_data<slice<uint64_t>> l;
135+
136+
template <typename... Pack>
137+
void push(Pack&&... pack) {
138+
l.push(::std::forward<Pack>(pack)...);
139+
}
140+
141+
void pop() { l.pop(); }
142+
143+
private:
144+
mutable stackable_ctx ctx;
145+
};
146+
147+
ciphertext plaintext::encrypt() const {
148+
ciphertext c(ctx);
149+
c.l = ctx.logical_data(shape_of<slice<uint64_t>>(l.shape().size()));
150+
151+
ctx.parallel_for(l.shape(), l.read(), c.l.write()).set_symbol("encrypt")->*
152+
[] __device__(size_t i, auto dptxt, auto dctxt) {
153+
// A super safe encryption !
154+
dctxt(i) = ((uint64_t) (dptxt(i)) << 32 | 0x4);
155+
};
156+
157+
return c;
158+
}
159+
160+
template <typename T>
161+
T circuit(const T& a, const T& b) {
162+
return (~((a | ~b) & (~a | b)));
163+
}
164+
165+
int main() {
166+
stackable_ctx ctx;
167+
168+
std::vector<char> vA { 3, 3, 2, 2, 17 };
169+
plaintext pA(ctx, vA);
170+
pA.set_symbol("A");
171+
172+
std::vector<char> vB { 1, 7, 7, 7, 49 };
173+
plaintext pB(ctx, vB);
174+
pB.set_symbol("B");
175+
176+
auto eA = pA.encrypt();
177+
auto eB = pB.encrypt();
178+
179+
ctx.push_graph();
180+
181+
eA.push(access_mode::read);
182+
eB.push(access_mode::read);
183+
184+
// TODO find a way to get "out" outside of this scope to do decryption in the main ctx
185+
auto out = circuit(eA, eB);
186+
187+
std::vector<char> v_out;
188+
out.decrypt().convert_to_vector(v_out);
189+
190+
eA.pop();
191+
eB.pop();
192+
193+
ctx.pop();
194+
195+
ctx.finalize();
196+
197+
for (size_t i = 0; i < v_out.size(); i++) {
198+
char expected = circuit(vA[i], vB[i]);
199+
EXPECT(expected == v_out[i]);
200+
}
201+
}

0 commit comments

Comments
 (0)