Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,11 @@ harness = false
[[bench]]
name = "bench_matrix_mul_gpu"
harness = false

[[bench]]
name = "bench_preimage_cpu"
harness = false

[[bench]]
name = "bench_preimage_gpu"
harness = false
37 changes: 37 additions & 0 deletions benches/bench_preimage_cpu.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use mxx::{
poly::dcrt::params::DCRTPolyParams,
sampler::{
DistType, PolyTrapdoorSampler, PolyUniformSampler, trapdoor::DCRTPolyTrapdoorSampler,
uniform::DCRTPolyUniformSampler,
},
};
use std::{hint::black_box, time::Instant};
use tracing::info;

const SIGMA: f64 = 4.578;
const TRAPDOOR_SIZE: usize = 1;
const TARGET_COLS: usize = 50;

fn bench_cpu_preimage() {
let _ = tracing_subscriber::fmt::try_init();

// Keep parameters aligned with the GPU benchmark for a fair comparison.
let params = DCRTPolyParams::new(16384, 10, 24, 12);
let trapdoor_sampler = DCRTPolyTrapdoorSampler::new(&params, SIGMA);
let uniform_sampler = DCRTPolyUniformSampler::new();

let (trapdoor, public_matrix) = trapdoor_sampler.trapdoor(&params, TRAPDOOR_SIZE);
let target =
uniform_sampler.sample_uniform(&params, TRAPDOOR_SIZE, TARGET_COLS, DistType::FinRingDist);

let start = Instant::now();
let preimage = trapdoor_sampler.preimage(&params, &trapdoor, &public_matrix, &target);
let elapsed = start.elapsed();
black_box(preimage);

info!("CPU DCRT preimage: {:?}", elapsed);
}

fn main() {
bench_cpu_preimage();
}
69 changes: 69 additions & 0 deletions benches/bench_preimage_gpu.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#[cfg(feature = "gpu")]
use std::{hint::black_box, time::Instant};
#[cfg(feature = "gpu")]
use tracing::info;

#[cfg(feature = "gpu")]
const SIGMA: f64 = 4.578;
#[cfg(feature = "gpu")]
const TRAPDOOR_SIZE: usize = 1;
#[cfg(feature = "gpu")]
const TARGET_COLS: usize = 50;

#[cfg(feature = "gpu")]
fn bench_gpu_preimage() {
use mxx::{
matrix::gpu_dcrt_poly::GpuDCRTPolyMatrix,
poly::{
PolyParams,
dcrt::{
gpu::{GpuDCRTPolyParams, gpu_device_sync},
params::DCRTPolyParams,
},
},
sampler::{
DistType, PolyTrapdoorSampler, PolyUniformSampler,
trapdoor::GpuDCRTPolyTrapdoorSampler, uniform::DCRTPolyUniformSampler,
},
};

gpu_device_sync();
let _ = tracing_subscriber::fmt::try_init();

// Keep parameters aligned with the CPU benchmark for a fair comparison.
let cpu_params = DCRTPolyParams::new(16384, 10, 24, 12);
let (moduli, _, _) = cpu_params.to_crt();
let params =
GpuDCRTPolyParams::new(cpu_params.ring_dimension(), moduli, cpu_params.base_bits());

let trapdoor_sampler = GpuDCRTPolyTrapdoorSampler::new(&params, SIGMA);
let uniform_sampler = DCRTPolyUniformSampler::new();

let (trapdoor, public_matrix) = trapdoor_sampler.trapdoor(&params, TRAPDOOR_SIZE);
let target_cpu = uniform_sampler.sample_uniform(
&cpu_params,
TRAPDOOR_SIZE,
TARGET_COLS,
DistType::FinRingDist,
);
let target = GpuDCRTPolyMatrix::from_cpu_matrix(&params, &target_cpu);

gpu_device_sync();
let start = Instant::now();
let preimage = trapdoor_sampler.preimage(&params, &trapdoor, &public_matrix, &target);
gpu_device_sync();
let elapsed = start.elapsed();
black_box(preimage);

info!("GPU DCRT preimage: {:?}", elapsed);
}

#[cfg(not(feature = "gpu"))]
fn main() {
println!("GPU benchmark skipped (enable with --features gpu).");
}

#[cfg(feature = "gpu")]
fn main() {
bench_gpu_preimage();
}
135 changes: 135 additions & 0 deletions cuda/GpuChaCha.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
#pragma once

#include <cuda_runtime.h>

#include <stdint.h>

namespace gpu_chacha
{
struct DeviceChaChaRng
{
uint32_t state[16];
uint32_t block[16];
uint32_t block_idx;
};

__device__ __forceinline__ uint32_t rotl32(uint32_t x, uint32_t n)
{
return (x << n) | (x >> (32U - n));
}

__device__ __forceinline__ uint64_t splitmix64_next(uint64_t &state)
{
uint64_t z = (state += 0x9e3779b97f4a7c15ULL);
z = (z ^ (z >> 30U)) * 0xbf58476d1ce4e5b9ULL;
z = (z ^ (z >> 27U)) * 0x94d049bb133111ebULL;
return z ^ (z >> 31U);
}

__device__ __forceinline__ void quarter_round(
uint32_t &a,
uint32_t &b,
uint32_t &c,
uint32_t &d)
{
a += b;
d ^= a;
d = rotl32(d, 16U);

c += d;
b ^= c;
b = rotl32(b, 12U);

a += b;
d ^= a;
d = rotl32(d, 8U);

c += d;
b ^= c;
b = rotl32(b, 7U);
}

__device__ __forceinline__ void chacha20_block(
const uint32_t in_state[16],
uint32_t out_block[16])
{
uint32_t x[16];
for (uint32_t i = 0; i < 16; ++i)
{
x[i] = in_state[i];
}

for (uint32_t round = 0; round < 10; ++round)
{
quarter_round(x[0], x[4], x[8], x[12]);
quarter_round(x[1], x[5], x[9], x[13]);
quarter_round(x[2], x[6], x[10], x[14]);
quarter_round(x[3], x[7], x[11], x[15]);

quarter_round(x[0], x[5], x[10], x[15]);
quarter_round(x[1], x[6], x[11], x[12]);
quarter_round(x[2], x[7], x[8], x[13]);
quarter_round(x[3], x[4], x[9], x[14]);
}

for (uint32_t i = 0; i < 16; ++i)
{
out_block[i] = x[i] + in_state[i];
}
}

__device__ __forceinline__ void rng_init(
DeviceChaChaRng &rng,
uint64_t seed,
uint64_t stream0,
uint64_t stream1,
uint64_t stream2,
uint64_t domain_tag)
{
rng.state[0] = 0x61707865U;
rng.state[1] = 0x3320646eU;
rng.state[2] = 0x79622d32U;
rng.state[3] = 0x6b206574U;

uint64_t mix = seed ^ 0x243f6a8885a308d3ULL;
mix ^= (stream0 + 0x9e3779b97f4a7c15ULL);
mix ^= (stream1 + 0xbf58476d1ce4e5b9ULL);
mix ^= (stream2 + 0x94d049bb133111ebULL);
mix ^= (domain_tag + 0xd6e8feb86659fd93ULL);

for (uint32_t i = 0; i < 4; ++i)
{
const uint64_t v = splitmix64_next(mix);
rng.state[4 + 2 * i] = static_cast<uint32_t>(v);
rng.state[5 + 2 * i] = static_cast<uint32_t>(v >> 32U);
}

const uint64_t n0 = splitmix64_next(mix);
const uint64_t n1 = splitmix64_next(mix);
rng.state[12] = 0U;
rng.state[13] = static_cast<uint32_t>(n0);
rng.state[14] = static_cast<uint32_t>(n0 >> 32U);
rng.state[15] = static_cast<uint32_t>(n1);

rng.block_idx = 8U;
}

__device__ __forceinline__ void rng_refill(DeviceChaChaRng &rng)
{
chacha20_block(rng.state, rng.block);
rng.state[12] += 1U;
rng.block_idx = 0U;
}

__device__ __forceinline__ uint64_t rng_next_u64(DeviceChaChaRng &rng)
{
if (rng.block_idx >= 8U)
{
rng_refill(rng);
}
const uint32_t w0 = rng.block[2U * rng.block_idx];
const uint32_t w1 = rng.block[2U * rng.block_idx + 1U];
rng.block_idx += 1U;
return static_cast<uint64_t>(w0) | (static_cast<uint64_t>(w1) << 32U);
}
} // namespace gpu_chacha
Loading