Skip to content

Commit 779153f

Browse files
authored
Merge branch 'main' into stf_c_api
2 parents dc78d7d + ab58dd0 commit 779153f

File tree

62 files changed

+10061
-245
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+10061
-245
lines changed

cudax/examples/stf/CMakeLists.txt

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,28 +21,42 @@ set(
2121
01-axpy-launch.cu
2222
01-axpy-parallel_for.cu
2323
binary_fhe.cu
24+
binary_fhe_stackable.cu
2425
09-dot-reduce.cu
2526
cfd.cu
2627
custom_data_interface.cu
2728
fdtd_mgpu.cu
29+
fdtd_while.cu
30+
fdtd_repeat_n.cu
2831
frozen_data_init.cu
2932
graph_algorithms/degree_centrality.cu
3033
graph_algorithms/jaccard.cu
3134
graph_algorithms/pagerank.cu
35+
graph_algorithms/pagerank_batched.cu
36+
graph_algorithms/pagerank_while.cu
3237
graph_algorithms/tricount.cu
38+
graph_scope.cu
3339
heat.cu
3440
heat_mgpu.cu
3541
jacobi.cu
3642
jacobi_pfor.cu
43+
jacobi_stackable.cu
44+
jacobi_stackable_raii.cu
45+
jacobi_update_cond.cu
3746
launch_histogram.cu
3847
launch_scan.cu
3948
launch_sum.cu
4049
launch_sum_cub.cu
50+
linear_algebra/burger.cu
51+
linear_algebra/burger_sensitivity.cu
52+
linear_algebra/cg_csr.cu
53+
linear_algebra/cg_csr_stackable.cu
4154
logical_gates_composition.cu
4255
mandelbrot.cu
4356
parallel_for_2D.cu
4457
pi.cu
4558
scan.cu
59+
sqrt_newton_stackable.cu
4660
standalone-launches.cu
4761
word_count.cu
4862
word_count_reduce.cu
@@ -52,9 +66,9 @@ set(
5266
set(
5367
stf_example_mathlib_sources
5468
linear_algebra/06-pdgemm.cu
69+
linear_algebra/06-pdgemm-stackable.cu
5570
linear_algebra/07-cholesky.cu
5671
linear_algebra/07-potri.cu
57-
linear_algebra/cg_csr.cu
5872
linear_algebra/cg_dense_2D.cu
5973
linear_algebra/strassen.cu
6074
)
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of CUDASTF in CUDA C++ Core Libraries,
4+
// under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
/**
12+
* @file
13+
* @brief A toy example to illustrate how we can compose logical operations over encrypted data
14+
*/
15+
16+
#include <cuda/experimental/stf.cuh>
17+
18+
using namespace cuda::experimental::stf;
19+
20+
#include <memory>
21+
22+
class ciphertext;
23+
24+
class plaintext
25+
{
26+
public:
27+
plaintext(const stackable_ctx& ctx)
28+
: ctx(ctx)
29+
{}
30+
31+
plaintext(stackable_ctx& ctx, ::std::vector<char> v)
32+
: values(mv(v))
33+
, ctx(ctx)
34+
, ld(ctx.logical_data(values.data(), values.size()))
35+
{}
36+
37+
auto& set_symbol(const std::string& s)
38+
{
39+
ld.set_symbol(s);
40+
symbol = s;
41+
42+
return *this;
43+
}
44+
45+
const std::string& get_symbol() const
46+
{
47+
return symbol;
48+
}
49+
50+
// This will asynchronously fill string s
51+
void convert_to_vector(std::vector<char>& v)
52+
{
53+
ctx.host_launch(ld.read()).set_symbol("to_vector")->*[&](auto dl) {
54+
v.resize(dl.size());
55+
for (size_t i = 0; i < dl.size(); i++)
56+
{
57+
v[i] = dl(i);
58+
}
59+
};
60+
}
61+
62+
ciphertext encrypt() const;
63+
64+
private:
65+
std::vector<char> values;
66+
mutable stackable_ctx ctx;
67+
::std::string symbol;
68+
69+
public:
70+
mutable stackable_logical_data<slice<char>> ld;
71+
};
72+
73+
class ciphertext
74+
{
75+
public:
76+
ciphertext() = default;
77+
78+
// We need a deep-copy semantic
79+
ciphertext(const ciphertext& other)
80+
: ctx(other.ctx)
81+
, symbol(other.symbol)
82+
{
83+
copy_content(ctx, other, *this);
84+
}
85+
86+
ciphertext(const stackable_ctx& ctx)
87+
: ctx(ctx)
88+
{}
89+
90+
ciphertext(ciphertext&&) = default;
91+
ciphertext& operator=(ciphertext&&) = default;
92+
93+
static void copy_content(stackable_ctx& ctx, const ciphertext& src, ciphertext& dst)
94+
{
95+
dst.ld = ctx.logical_data(src.ld.shape());
96+
ctx.parallel_for(src.ld.shape(), src.ld.read(), dst.ld.write()).set_symbol("copy")->*
97+
[] __device__(size_t i, auto src, auto dst) {
98+
dst(i) = src(i);
99+
};
100+
}
101+
102+
auto& set_symbol(std::string s)
103+
{
104+
ld.set_symbol(s);
105+
symbol = mv(s);
106+
107+
return *this;
108+
}
109+
110+
const std::string& get_symbol() const
111+
{
112+
return symbol;
113+
}
114+
115+
plaintext decrypt() const
116+
{
117+
plaintext p(ctx);
118+
p.ld = ctx.logical_data(shape_of<slice<char>>(ld.shape().size()));
119+
ctx.parallel_for(ld.shape(), ld.read(), p.ld.write()).set_symbol("decrypt")->*
120+
[] __device__(size_t i, auto cipher_data, auto plain_data) {
121+
plain_data(i) = static_cast<char>(cipher_data(i) >> 32);
122+
};
123+
return p;
124+
}
125+
126+
// Copy assignment operator
127+
// We need a deep-copy semantic
128+
ciphertext& operator=(const ciphertext& other)
129+
{
130+
if (this != &other)
131+
{
132+
ctx = other.ctx;
133+
symbol = other.symbol;
134+
copy_content(ctx, other, *this);
135+
}
136+
return *this;
137+
}
138+
139+
ciphertext operator|(const ciphertext& other) const
140+
{
141+
ciphertext result(ctx);
142+
result.ld = ctx.logical_data(ld.shape());
143+
144+
ctx.parallel_for(ld.shape(), ld.read(), other.ld.read(), result.ld.write()).set_symbol("OR")->*
145+
[] __device__(size_t i, auto d_c1, auto d_c2, auto d_res) {
146+
d_res(i) = d_c1(i) | d_c2(i);
147+
};
148+
149+
return result;
150+
}
151+
152+
ciphertext operator&(const ciphertext& other) const
153+
{
154+
ciphertext result(ctx);
155+
result.ld = ctx.logical_data(ld.shape());
156+
157+
ctx.parallel_for(ld.shape(), ld.read(), other.ld.read(), result.ld.write()).set_symbol("AND")->*
158+
[] __device__(size_t i, auto d_c1, auto d_c2, auto d_res) {
159+
d_res(i) = d_c1(i) & d_c2(i);
160+
};
161+
162+
return result;
163+
}
164+
165+
ciphertext operator~() const
166+
{
167+
ciphertext result(ctx);
168+
result.ld = ctx.logical_data(ld.shape());
169+
170+
ctx.parallel_for(ld.shape(), ld.read(), result.ld.write()).set_symbol("NOT")->*
171+
[] __device__(size_t i, auto d_c, auto d_res) {
172+
d_res(i) = ~d_c(i);
173+
};
174+
175+
return result;
176+
}
177+
178+
mutable stackable_logical_data<slice<uint64_t>> ld;
179+
180+
private:
181+
mutable stackable_ctx ctx;
182+
::std::string symbol;
183+
};
184+
185+
ciphertext plaintext::encrypt() const
186+
{
187+
ciphertext c(ctx);
188+
c.ld = ctx.logical_data(shape_of<slice<uint64_t>>(ld.shape().size()));
189+
190+
ctx.parallel_for(ld.shape(), ld.read(), c.ld.write()).set_symbol("encrypt")->*
191+
[] __device__(size_t i, auto dptxt, auto dctxt) {
192+
// A super safe encryption !
193+
dctxt(i) = ((uint64_t) (dptxt(i)) << 32 | 0x4);
194+
};
195+
196+
return c;
197+
}
198+
199+
template <typename T>
200+
T circuit(const T& a, const T& b)
201+
{
202+
return ~((a | ~b) & (~a | b));
203+
}
204+
205+
int main()
206+
{
207+
stackable_ctx ctx;
208+
209+
const std::vector<char> vA{3, 3, 2, 2, 17};
210+
plaintext pA(ctx, std::vector<char>(vA));
211+
pA.set_symbol("A");
212+
213+
const std::vector<char> vB{1, 7, 7, 7, 49};
214+
plaintext pB(ctx, std::vector<char>(vB));
215+
pB.set_symbol("B");
216+
217+
auto s_encrypt = ctx.dot_section("encrypt");
218+
auto eA = pA.encrypt().set_symbol("A");
219+
auto eB = pB.encrypt().set_symbol("B");
220+
s_encrypt.end();
221+
222+
ctx.push();
223+
224+
auto s_circuit = ctx.dot_section("circuit");
225+
auto out = circuit(eA, eB);
226+
s_circuit.end();
227+
228+
ctx.pop();
229+
230+
std::vector<char> v_out;
231+
out.decrypt().convert_to_vector(v_out);
232+
233+
ctx.finalize();
234+
235+
for (size_t i = 0; i < v_out.size(); i++)
236+
{
237+
char expected = circuit(vA[i], vB[i]);
238+
EXPECT(expected == v_out[i]);
239+
}
240+
}

0 commit comments

Comments
 (0)