diff --git a/Cargo.toml b/Cargo.toml
index 8c59c48..9f8710a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -65,3 +65,11 @@ harness = false
 [[bench]]
 name = "bench_matrix_mul_gpu"
 harness = false
+
+[[bench]]
+name = "bench_preimage_cpu"
+harness = false
+
+[[bench]]
+name = "bench_preimage_gpu"
+harness = false
diff --git a/benches/bench_preimage_cpu.rs b/benches/bench_preimage_cpu.rs
new file mode 100644
index 0000000..b7259b3
--- /dev/null
+++ b/benches/bench_preimage_cpu.rs
@@ -0,0 +1,37 @@
+use mxx::{
+    poly::dcrt::params::DCRTPolyParams,
+    sampler::{
+        DistType, PolyTrapdoorSampler, PolyUniformSampler, trapdoor::DCRTPolyTrapdoorSampler,
+        uniform::DCRTPolyUniformSampler,
+    },
+};
+use std::{hint::black_box, time::Instant};
+use tracing::info;
+
+const SIGMA: f64 = 4.578;
+const TRAPDOOR_SIZE: usize = 1;
+const TARGET_COLS: usize = 50;
+
+fn bench_cpu_preimage() {
+    let _ = tracing_subscriber::fmt::try_init();
+
+    // Keep parameters aligned with the GPU benchmark for a fair comparison.
+    let params = DCRTPolyParams::new(16384, 10, 24, 12);
+    let trapdoor_sampler = DCRTPolyTrapdoorSampler::new(&params, SIGMA);
+    let uniform_sampler = DCRTPolyUniformSampler::new();
+
+    let (trapdoor, public_matrix) = trapdoor_sampler.trapdoor(&params, TRAPDOOR_SIZE);
+    let target =
+        uniform_sampler.sample_uniform(&params, TRAPDOOR_SIZE, TARGET_COLS, DistType::FinRingDist);
+
+    let start = Instant::now();
+    let preimage = trapdoor_sampler.preimage(&params, &trapdoor, &public_matrix, &target);
+    let elapsed = start.elapsed();
+    black_box(preimage);
+
+    info!("CPU DCRT preimage: {:?}", elapsed);
+}
+
+fn main() {
+    bench_cpu_preimage();
+}
diff --git a/benches/bench_preimage_gpu.rs b/benches/bench_preimage_gpu.rs
new file mode 100644
index 0000000..a3cd551
--- /dev/null
+++ b/benches/bench_preimage_gpu.rs
@@ -0,0 +1,69 @@
+#[cfg(feature = "gpu")]
+use std::{hint::black_box, time::Instant};
+#[cfg(feature = "gpu")]
+use tracing::info;
+
+#[cfg(feature = "gpu")]
+const SIGMA: f64 = 4.578;
+#[cfg(feature = "gpu")]
+const TRAPDOOR_SIZE: usize = 1;
+#[cfg(feature = "gpu")]
+const TARGET_COLS: usize = 50;
+
+#[cfg(feature = "gpu")]
+fn bench_gpu_preimage() {
+    use mxx::{
+        matrix::gpu_dcrt_poly::GpuDCRTPolyMatrix,
+        poly::{
+            PolyParams,
+            dcrt::{
+                gpu::{GpuDCRTPolyParams, gpu_device_sync},
+                params::DCRTPolyParams,
+            },
+        },
+        sampler::{
+            DistType, PolyTrapdoorSampler, PolyUniformSampler,
+            trapdoor::GpuDCRTPolyTrapdoorSampler, uniform::DCRTPolyUniformSampler,
+        },
+    };
+
+    gpu_device_sync();
+    let _ = tracing_subscriber::fmt::try_init();
+
+    // Keep parameters aligned with the CPU benchmark for a fair comparison.
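+    // Note: both benches are `harness = false` binaries; this one can be run with
+    // `cargo bench --bench bench_preimage_gpu --features gpu`, and its CPU
+    // counterpart with `cargo bench --bench bench_preimage_cpu`. The GPU parameter
+    // set built below reuses the CPU params' ring dimension, CRT moduli, and base
+    // bits, so both benchmarks time the same preimage problem on different backends.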
+    let cpu_params = DCRTPolyParams::new(16384, 10, 24, 12);
+    let (moduli, _, _) = cpu_params.to_crt();
+    let params =
+        GpuDCRTPolyParams::new(cpu_params.ring_dimension(), moduli, cpu_params.base_bits());
+
+    let trapdoor_sampler = GpuDCRTPolyTrapdoorSampler::new(&params, SIGMA);
+    let uniform_sampler = DCRTPolyUniformSampler::new();
+
+    let (trapdoor, public_matrix) = trapdoor_sampler.trapdoor(&params, TRAPDOOR_SIZE);
+    let target_cpu = uniform_sampler.sample_uniform(
+        &cpu_params,
+        TRAPDOOR_SIZE,
+        TARGET_COLS,
+        DistType::FinRingDist,
+    );
+    let target = GpuDCRTPolyMatrix::from_cpu_matrix(&params, &target_cpu);
+
+    gpu_device_sync();
+    let start = Instant::now();
+    let preimage = trapdoor_sampler.preimage(&params, &trapdoor, &public_matrix, &target);
+    gpu_device_sync();
+    let elapsed = start.elapsed();
+    black_box(preimage);
+
+    info!("GPU DCRT preimage: {:?}", elapsed);
+}
+
+#[cfg(not(feature = "gpu"))]
+fn main() {
+    println!("GPU benchmark skipped (enable with --features gpu).");
+}
+
+#[cfg(feature = "gpu")]
+fn main() {
+    bench_gpu_preimage();
+}
diff --git a/cuda/GpuChaCha.cuh b/cuda/GpuChaCha.cuh
new file mode 100644
index 0000000..74da6b4
--- /dev/null
+++ b/cuda/GpuChaCha.cuh
@@ -0,0 +1,135 @@
+#pragma once
+
+#include <cstdint>
+
+#include <cuda_runtime.h>
+
+namespace gpu_chacha
+{
+    struct DeviceChaChaRng
+    {
+        uint32_t state[16];
+        uint32_t block[16];
+        uint32_t block_idx;
+    };
+
+    __device__ __forceinline__ uint32_t rotl32(uint32_t x, uint32_t n)
+    {
+        return (x << n) | (x >> (32U - n));
+    }
+
+    __device__ __forceinline__ uint64_t splitmix64_next(uint64_t &state)
+    {
+        uint64_t z = (state += 0x9e3779b97f4a7c15ULL);
+        z = (z ^ (z >> 30U)) * 0xbf58476d1ce4e5b9ULL;
+        z = (z ^ (z >> 27U)) * 0x94d049bb133111ebULL;
+        return z ^ (z >> 31U);
+    }
+
+    __device__ __forceinline__ void quarter_round(
+        uint32_t &a,
+        uint32_t &b,
+        uint32_t &c,
+        uint32_t &d)
+    {
+        a += b;
+        d ^= a;
+        d = rotl32(d, 16U);
+
+        c += d;
+        b ^= c;
+        b = rotl32(b, 12U);
+
+        a += b;
+        d ^= a;
+        d = rotl32(d, 8U);
+
+        c += d;
+        b ^= c;
+        b = rotl32(b, 7U);
+    }
+
+    __device__ __forceinline__ void chacha20_block(
+        const uint32_t in_state[16],
+        uint32_t out_block[16])
+    {
+        uint32_t x[16];
+        for (uint32_t i = 0; i < 16; ++i)
+        {
+            x[i] = in_state[i];
+        }
+
+        for (uint32_t round = 0; round < 10; ++round)
+        {
+            quarter_round(x[0], x[4], x[8], x[12]);
+            quarter_round(x[1], x[5], x[9], x[13]);
+            quarter_round(x[2], x[6], x[10], x[14]);
+            quarter_round(x[3], x[7], x[11], x[15]);
+
+            quarter_round(x[0], x[5], x[10], x[15]);
+            quarter_round(x[1], x[6], x[11], x[12]);
+            quarter_round(x[2], x[7], x[8], x[13]);
+            quarter_round(x[3], x[4], x[9], x[14]);
+        }
+
+        for (uint32_t i = 0; i < 16; ++i)
+        {
+            out_block[i] = x[i] + in_state[i];
+        }
+    }
+
+    __device__ __forceinline__ void rng_init(
+        DeviceChaChaRng &rng,
+        uint64_t seed,
+        uint64_t stream0,
+        uint64_t stream1,
+        uint64_t stream2,
+        uint64_t domain_tag)
+    {
+        rng.state[0] = 0x61707865U;
+        rng.state[1] = 0x3320646eU;
+        rng.state[2] = 0x79622d32U;
+        rng.state[3] = 0x6b206574U;
+
+        uint64_t mix = seed ^ 0x243f6a8885a308d3ULL;
+        mix ^= (stream0 + 0x9e3779b97f4a7c15ULL);
+        mix ^= (stream1 + 0xbf58476d1ce4e5b9ULL);
+        mix ^= (stream2 + 0x94d049bb133111ebULL);
+        mix ^= (domain_tag + 0xd6e8feb86659fd93ULL);
+
+        for (uint32_t i = 0; i < 4; ++i)
+        {
+            const uint64_t v = splitmix64_next(mix);
+            rng.state[4 + 2 * i] = static_cast<uint32_t>(v);
+            rng.state[5 + 2 * i] = static_cast<uint32_t>(v >> 32U);
+        }
+
+        const uint64_t n0 = splitmix64_next(mix);
+        const uint64_t n1 = splitmix64_next(mix);
+        rng.state[12] = 0U;
+        rng.state[13] =
static_cast(n0); + rng.state[14] = static_cast(n0 >> 32U); + rng.state[15] = static_cast(n1); + + rng.block_idx = 8U; + } + + __device__ __forceinline__ void rng_refill(DeviceChaChaRng &rng) + { + chacha20_block(rng.state, rng.block); + rng.state[12] += 1U; + rng.block_idx = 0U; + } + + __device__ __forceinline__ uint64_t rng_next_u64(DeviceChaChaRng &rng) + { + if (rng.block_idx >= 8U) + { + rng_refill(rng); + } + const uint32_t w0 = rng.block[2U * rng.block_idx]; + const uint32_t w1 = rng.block[2U * rng.block_idx + 1U]; + rng.block_idx += 1U; + return static_cast(w0) | (static_cast(w1) << 32U); + } +} // namespace gpu_chacha diff --git a/cuda/GpuMatrix.cu b/cuda/GpuMatrix.cu index 7c10fe7..cf55da1 100644 --- a/cuda/GpuMatrix.cu +++ b/cuda/GpuMatrix.cu @@ -1,7 +1,9 @@ #include "GpuMatrix.h" +#include "GpuChaCha.cuh" #include "GpuPolyInternal.h" #include +#include #include #include #include @@ -19,6 +21,8 @@ namespace constexpr int kMatmulTileM = 16; constexpr int kMatmulTileN = 16; constexpr int kMatmulTileK = 8; + constexpr uint32_t kGaussMaxDigits = 64; + constexpr double kTwoPi = 6.283185307179586476925286766559; int set_error(const char *msg) { @@ -50,6 +54,86 @@ namespace } } + int sync_poly_limb_streams(const GpuPoly *poly, const char *context) + { + if (!poly || !poly->ctx || !poly->poly) + { + return set_error(context); + } + const int level = poly->level; + if (level < 0) + { + return set_error(context); + } + auto &limb_map = poly->ctx->ctx->limbGPUid; + if (limb_map.size() < static_cast(level + 1)) + { + return set_error(context); + } + for (int limb = 0; limb <= level; ++limb) + { + const dim3 limb_id = limb_map[static_cast(limb)]; + if (limb_id.x >= poly->poly->GPU.size()) + { + return set_error(context); + } + const auto &partition = poly->poly->GPU[limb_id.x]; + if (limb_id.y >= partition.limb.size()) + { + return set_error(context); + } + + cudaError_t err = cudaSetDevice(partition.device); + if (err != cudaSuccess) + { + return set_error(err); + } + + const auto &limb_impl = partition.limb[limb_id.y]; + cudaStream_t stream = nullptr; + if (limb_impl.index() == FIDESlib::U64) + { + stream = std::get(limb_impl).stream.ptr; + } + else if (limb_impl.index() == FIDESlib::U32) + { + stream = std::get(limb_impl).stream.ptr; + } + else + { + return set_error(context); + } + err = cudaStreamSynchronize(stream); + if (err != cudaSuccess) + { + return set_error(err); + } + } + return 0; + } + + int sync_poly_partition_streams(const GpuPoly *poly, const char *context) + { + if (!poly || !poly->poly) + { + return set_error(context); + } + for (const auto &partition : poly->poly->GPU) + { + cudaError_t err = cudaSetDevice(partition.device); + if (err != cudaSuccess) + { + return set_error(err); + } + err = cudaStreamSynchronize(partition.s.ptr); + if (err != cudaSuccess) + { + return set_error(err); + } + } + return 0; + } + uint32_t bit_width_u64(uint64_t v) { if (v == 0) @@ -59,7 +143,7 @@ namespace return static_cast(64 - __builtin_clzll(v)); } - size_t matrix_index(size_t row, size_t col, size_t cols) + __host__ __device__ __forceinline__ size_t matrix_index(size_t row, size_t col, size_t cols) { return row * cols + col; } @@ -150,7 +234,8 @@ namespace size_t poly_count, size_t n, uint32_t shift, - T mask) + T mask, + T out_modulus) { size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; size_t total = poly_count * n; @@ -162,1069 +247,2210 @@ namespace size_t coeff_idx = idx - poly_idx * n; T residue = src[poly_idx][coeff_idx]; T digit = shift >= static_cast(sizeof(T) * 
8) ? 0 : ((residue >> shift) & mask); + if (out_modulus != 0 && digit >= out_modulus) + { + digit %= out_modulus; + } dst[poly_idx][coeff_idx] = digit; } - template - __global__ void block_mul_kernel( - const T **lhs, - const T **rhs, - T **out, - size_t poly_count, - size_t n, - T modulus) + using gpu_chacha::DeviceChaChaRng; + using gpu_chacha::rng_init; + using gpu_chacha::rng_next_u64; + + __device__ __forceinline__ double uniform_open01(DeviceChaChaRng &rng) { - size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; - size_t total = poly_count * n; - if (idx >= total) - { - return; - } - size_t poly_idx = idx / n; - size_t coeff_idx = idx - poly_idx * n; - T a = lhs[poly_idx][coeff_idx]; - T b = rhs[poly_idx][coeff_idx]; - if constexpr (std::is_same_v) + constexpr double kScale = 1.0 / 9007199254740992.0; // 2^53 + double u = static_cast(rng_next_u64(rng) >> 11U) * kScale; + if (u <= 0.0) { - out[poly_idx][coeff_idx] = mul_mod_u64(a, b, modulus); + u = kScale; } - else + else if (u >= 1.0) { - out[poly_idx][coeff_idx] = mul_mod_u32(a, b, modulus); + u = 1.0 - kScale; } + return u; } - template - __global__ void block_matmul_kernel( - const T **lhs, - const T **rhs, - T **out, - size_t rows, - size_t inner, - size_t cols, - size_t n, - T modulus) + __device__ __forceinline__ double sample_standard_normal(DeviceChaChaRng &rng) { - __shared__ T lhs_tile[kMatmulTileM][kMatmulTileK]; - __shared__ T rhs_tile[kMatmulTileK][kMatmulTileN]; + double u1 = uniform_open01(rng); + double u2 = uniform_open01(rng); + double r = sqrt(-2.0 * log(u1)); + double theta = kTwoPi * u2; + return r * cos(theta); + } - const size_t row_base = static_cast(blockIdx.y) * kMatmulTileM; - const size_t col_base = static_cast(blockIdx.x) * kMatmulTileN; - const size_t row = row_base + threadIdx.y; - const size_t col = col_base + threadIdx.x; - const size_t coeff_idx = static_cast(blockIdx.z); - if (coeff_idx >= n) + __device__ __forceinline__ bool karney_algorithm_h(DeviceChaChaRng &rng) + { + double h_a = uniform_open01(rng); + if (!(h_a < 0.5)) { - return; + return true; } - - const int tid = static_cast(threadIdx.y) * blockDim.x + threadIdx.x; - const int threads = blockDim.x * blockDim.y; - - T acc = 0; - for (size_t k0 = 0; k0 < inner; k0 += kMatmulTileK) + for (;;) { - for (int i = tid; i < kMatmulTileM * kMatmulTileK; i += threads) + double h_b = uniform_open01(rng); + if (!(h_b < h_a)) { - const int r = i / kMatmulTileK; - const int k = i - r * kMatmulTileK; - const size_t lhs_row = row_base + static_cast(r); - const size_t lhs_k = k0 + static_cast(k); - T val = 0; - if (lhs_row < rows && lhs_k < inner) - { - const T *lhs_poly = lhs[lhs_row * inner + lhs_k]; - val = lhs_poly[coeff_idx]; - } - lhs_tile[r][k] = val; + return false; } - for (int i = tid; i < kMatmulTileK * kMatmulTileN; i += threads) + h_a = uniform_open01(rng); + if (!(h_a < h_b)) { - const int k = i / kMatmulTileN; - const int c = i - k * kMatmulTileN; - const size_t rhs_k = k0 + static_cast(k); - const size_t rhs_col = col_base + static_cast(c); - T val = 0; - if (rhs_k < inner && rhs_col < cols) - { - const T *rhs_poly = rhs[rhs_k * cols + rhs_col]; - val = rhs_poly[coeff_idx]; - } - rhs_tile[k][c] = val; + return true; } - __syncthreads(); + } + } - if (row < rows && col < cols) + __device__ __forceinline__ int32_t karney_algorithm_g(DeviceChaChaRng &rng) + { + int32_t n = 0; + while (karney_algorithm_h(rng)) + { + ++n; + if (n > 1024) { - for (int kk = 0; kk < kMatmulTileK; ++kk) - { - T prod; - if constexpr (std::is_same_v) - { 
- prod = mul_mod_u64(lhs_tile[threadIdx.y][kk], rhs_tile[kk][threadIdx.x], modulus); - acc = add_mod_u64(acc, prod, modulus); - } - else - { - prod = mul_mod_u32(lhs_tile[threadIdx.y][kk], rhs_tile[kk][threadIdx.x], modulus); - acc = add_mod_u32(acc, prod, modulus); - } - } + break; } - __syncthreads(); } + return n; + } - if (row < rows && col < cols) + __device__ __forceinline__ bool karney_algorithm_p(DeviceChaChaRng &rng, int32_t n) + { + while (n-- && karney_algorithm_h(rng)) { - out[row * cols + col][coeff_idx] = acc; } + return n < 0; } - template - int launch_block_kernel( - const std::vector &out_ptrs, - const std::vector &lhs_ptrs, - const std::vector &rhs_ptrs, - size_t n, - T modulus, - BlockOp op, - cudaStream_t stream) + __device__ __forceinline__ bool karney_algorithm_b(DeviceChaChaRng &rng, int32_t k, double x) { - const size_t count = out_ptrs.size(); - if (count == 0 || n == 0) + double y = x; + int32_t n = 0; + double m = static_cast(2 * k + 2); + for (;; ++n) { - return 0; + double z = uniform_open01(rng); + if (!(z < y)) + { + break; + } + double r = uniform_open01(rng); + if (!(r < (2.0 * static_cast(k) + x) / m)) + { + break; + } + y = z; + if (n > 4096) + { + break; + } } + return (n % 2) == 0; + } - T **d_out = nullptr; - const T **d_lhs = nullptr; - const T **d_rhs = nullptr; - const size_t bytes = count * sizeof(T *); - - cudaError_t err = cudaMalloc(&d_out, bytes); - if (err != cudaSuccess) + __device__ __forceinline__ int64_t sample_integer_karney(DeviceChaChaRng &rng, double mean, double stddev) + { + if (!(stddev > 0.0) || !isfinite(mean) || !isfinite(stddev)) { - return set_error(err); + return static_cast(llround(mean)); } - err = cudaMalloc(&d_lhs, bytes); - if (err != cudaSuccess) + + int64_t ceil_std = static_cast(ceil(stddev)); + if (ceil_std <= 0) { - cudaFree(d_out); - return set_error(err); + return static_cast(llround(mean)); } - err = cudaMalloc(&d_rhs, bytes); - if (err != cudaSuccess) + + for (int iter = 0; iter < 1 << 16; ++iter) { - cudaFree(d_out); - cudaFree(d_lhs); - return set_error(err); + int32_t k = karney_algorithm_g(rng); + if (!karney_algorithm_p(rng, k * (k - 1))) + { + continue; + } + + int64_t s = (rng_next_u64(rng) & 1ULL) ? 1 : -1; + double di0 = stddev * static_cast(k) + static_cast(s) * mean; + int64_t i0 = static_cast(ceil(di0)); + double x0 = (static_cast(i0) - di0) / stddev; + int64_t j = static_cast(rng_next_u64(rng) % static_cast(ceil_std)); + double x = x0 + static_cast(j) / stddev; + + if (!(x < 1.0) || (x == 0.0 && s < 0 && k == 0)) + { + continue; + } + + int32_t h = k + 1; + while (h-- > 0 && karney_algorithm_b(rng, k, x)) + { + } + if (h >= 0) + { + continue; + } + + return s * (i0 + j); } - err = cudaMemcpyAsync(d_out, out_ptrs.data(), bytes, cudaMemcpyHostToDevice, stream); - if (err != cudaSuccess) + // Fallback in case the rejection loop takes too long. 
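+        // (Note: this fallback rounds a continuous Box-Muller sample rather than
+        // producing an exact discrete Gaussian; it is only reached if all 2^16
+        // Karney iterations above reject, which is vanishingly unlikely for
+        // reasonable sigma values.)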
+ return static_cast(llround(mean + stddev * sample_standard_normal(rng))); + } + + __device__ __forceinline__ void get_base_digits_u64( + uint64_t value, + uint64_t base, + uint32_t digits, + int64_t *out_digits) + { + for (uint32_t i = 0; i < digits; ++i) { - cudaFree(d_out); - cudaFree(d_lhs); - cudaFree(d_rhs); - return set_error(err); + out_digits[i] = static_cast(value % base); + value /= base; } - err = cudaMemcpyAsync(d_lhs, lhs_ptrs.data(), bytes, cudaMemcpyHostToDevice, stream); - if (err != cudaSuccess) + } + + __device__ __forceinline__ uint64_t signed_mod_i64(int64_t value, uint64_t modulus) + { + if (modulus == 0) { - cudaFree(d_out); - cudaFree(d_lhs); - cudaFree(d_rhs); - return set_error(err); + return 0; } - err = cudaMemcpyAsync(d_rhs, rhs_ptrs.data(), bytes, cudaMemcpyHostToDevice, stream); - if (err != cudaSuccess) + if (value >= 0) { - cudaFree(d_out); - cudaFree(d_lhs); - cudaFree(d_rhs); - return set_error(err); + return static_cast(value) % modulus; } + uint64_t magnitude = static_cast(-(value + 1)) + 1; + uint64_t rem = magnitude % modulus; + return rem == 0 ? 0 : (modulus - rem); + } - const int threads = 256; - const size_t total = count * n; - const int blocks = static_cast((total + threads - 1) / threads); + __device__ __forceinline__ uint64_t sample_uniform_mod(DeviceChaChaRng &rng, uint64_t modulus) + { + if (modulus == 0) + { + return 0; + } + constexpr uint64_t kU64Max = ~uint64_t{0}; + const uint64_t threshold = kU64Max - (kU64Max % modulus); + for (;;) + { + uint64_t x = rng_next_u64(rng); + if (x < threshold) + { + return x % modulus; + } + } + } - switch (op) + __device__ __forceinline__ int64_t centered_residue_i64(uint64_t value, uint64_t modulus) + { + if (modulus == 0) { - case BlockOp::Add: - block_add_kernel<<>>(d_lhs, d_rhs, d_out, count, n, modulus); - break; - case BlockOp::Sub: - block_sub_kernel<<>>(d_lhs, d_rhs, d_out, count, n, modulus); - break; - case BlockOp::Mul: - block_mul_kernel<<>>(d_lhs, d_rhs, d_out, count, n, modulus); - break; + return 0; + } + uint64_t reduced = value % modulus; + uint64_t half = modulus >> 1; + if (reduced <= half) + { + return static_cast(reduced); } + uint64_t neg = modulus - reduced; + return -static_cast(neg); + } - err = cudaGetLastError(); - if (err != cudaSuccess) + __global__ void matrix_sample_distribution_kernel( + uint64_t **dst, + size_t poly_count, + size_t n, + uint64_t modulus, + int dist_type, + double sigma, + uint32_t limb_idx, + uint64_t seed) + { + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + size_t total = poly_count * n; + if (idx >= total) { - cudaFree(d_out); - cudaFree(d_lhs); - cudaFree(d_rhs); - return set_error(err); + return; } + size_t poly_idx = idx / n; + size_t coeff_idx = idx - poly_idx * n; - err = cudaStreamSynchronize(stream); - if (err != cudaSuccess) + uint64_t sample = 0; + if (dist_type == GPU_MATRIX_DIST_UNIFORM) { - cudaFree(d_out); - cudaFree(d_lhs); - cudaFree(d_rhs); - return set_error(err); + DeviceChaChaRng rng; + rng_init( + rng, + seed, + static_cast(poly_idx + 1), + static_cast(coeff_idx + 1), + static_cast(limb_idx + 1), + 0x6f70656e66686531ULL); + sample = sample_uniform_mod(rng, modulus); + } + else if (dist_type == GPU_MATRIX_DIST_GAUSS) + { + DeviceChaChaRng rng; + rng_init( + rng, + seed, + static_cast(poly_idx + 1), + static_cast(coeff_idx + 1), + 0, + 0x6f70656e66686532ULL); + int64_t z = sample_integer_karney(rng, 0.0, sigma); + sample = signed_mod_i64(z, modulus); + } + else if (dist_type == GPU_MATRIX_DIST_BIT) + { + 
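+        // Each distribution branch seeds ChaCha with its own domain tag, so the
+        // uniform / Gaussian / bit / ternary streams are domain-separated even for
+        // the same (seed, poly_idx, coeff_idx) triple.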
DeviceChaChaRng rng; + rng_init( + rng, + seed, + static_cast(poly_idx + 1), + static_cast(coeff_idx + 1), + 0, + 0x6f70656e66686533ULL); + sample = (rng_next_u64(rng) & 1ULL) % modulus; + } + else if (dist_type == GPU_MATRIX_DIST_TERNARY) + { + DeviceChaChaRng rng; + rng_init( + rng, + seed, + static_cast(poly_idx + 1), + static_cast(coeff_idx + 1), + 0, + 0x6f70656e66686534ULL); + uint64_t pick = rng_next_u64(rng) % 3ULL; + int64_t z = pick == 0 ? 0 : (pick == 1 ? 1 : -1); + sample = signed_mod_i64(z, modulus); } - cudaFree(d_out); - cudaFree(d_lhs); - cudaFree(d_rhs); - return 0; + dst[poly_idx][coeff_idx] = sample; } - template - int launch_block_matmul_kernel( - const std::vector &out_ptrs, - const std::vector &lhs_ptrs, - const std::vector &rhs_ptrs, - size_t rows, - size_t inner, - size_t cols, - size_t n, - T modulus, - cudaStream_t stream, - double *out_kernel_ms) + __device__ __forceinline__ uint64_t pow_mod_u64(uint64_t base, uint32_t exp, uint64_t modulus) { - const size_t out_count = rows * cols; - const size_t lhs_count = rows * inner; - const size_t rhs_count = inner * cols; - if (out_count == 0 || n == 0) + if (modulus == 0) { return 0; } - if (out_ptrs.size() != out_count || lhs_ptrs.size() != lhs_count || rhs_ptrs.size() != rhs_count) + uint64_t result = 1ULL % modulus; + uint64_t cur = base % modulus; + uint32_t e = exp; + while (e > 0) { - return set_error("unexpected pointer counts in gpu_block_mul"); + if (e & 1U) + { + result = static_cast((static_cast(result) * cur) % modulus); + } + e >>= 1U; + if (e > 0) + { + cur = static_cast((static_cast(cur) * cur) % modulus); + } } + return result; + } - T **d_out = nullptr; - const T **d_lhs = nullptr; - const T **d_rhs = nullptr; - const size_t out_bytes = out_count * sizeof(T *); - const size_t lhs_bytes = lhs_count * sizeof(T *); - const size_t rhs_bytes = rhs_count * sizeof(T *); - - cudaError_t err = cudaMalloc(&d_out, out_bytes); - if (err != cudaSuccess) - { - return set_error(err); - } - err = cudaMalloc(&d_lhs, lhs_bytes); - if (err != cudaSuccess) + __global__ void matrix_fill_gadget_kernel( + uint64_t **dst, + size_t poly_count, + size_t n, + uint64_t modulus, + size_t rows, + size_t cols, + size_t log_base_q, + uint32_t digits_per_tower, + uint32_t limb_idx, + uint32_t base_bits) + { + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + size_t total = poly_count * n; + if (idx >= total) { - cudaFree(d_out); - return set_error(err); + return; } - err = cudaMalloc(&d_rhs, rhs_bytes); - if (err != cudaSuccess) + size_t poly_idx = idx / n; + size_t coeff_idx = idx - poly_idx * n; + + uint64_t value = 0; + if (coeff_idx == 0 && rows > 0 && cols > 0 && log_base_q > 0) { - cudaFree(d_out); - cudaFree(d_lhs); - return set_error(err); + size_t row = poly_idx / cols; + size_t col = poly_idx - row * cols; + size_t block_start = row * log_base_q; + if (col >= block_start && col < block_start + log_base_q) + { + size_t local = col - block_start; + uint32_t tower = static_cast(local / static_cast(digits_per_tower)); + uint32_t digit = static_cast(local % static_cast(digits_per_tower)); + if (tower == limb_idx) + { + uint64_t base = uint64_t{1} << base_bits; + value = pow_mod_u64(base, digit, modulus); + } + } } + dst[poly_idx][coeff_idx] = value; + } - err = cudaMemcpyAsync(d_out, out_ptrs.data(), out_bytes, cudaMemcpyHostToDevice, stream); - if (err != cudaSuccess) + __global__ void matrix_sample_p1_full_kernel( + const uint64_t **a_entries, + const uint64_t **b_entries, + const uint64_t **d_entries, + const 
uint64_t **tp2_entries, + uint64_t **out_entries, + size_t d, + size_t cols, + size_t n, + size_t sample_start, + size_t sample_count, + double *cov_workspace, + double *mean_workspace, + double *col_workspace, + int64_t *sampled_workspace, + uint64_t modulus, + double sigma, + double s, + double dgg_stddev, + uint32_t limb_idx, + uint64_t seed) + { + const size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= sample_count) { - cudaFree(d_out); - cudaFree(d_lhs); - cudaFree(d_rhs); - return set_error(err); + return; } - err = cudaMemcpyAsync(d_lhs, lhs_ptrs.data(), lhs_bytes, cudaMemcpyHostToDevice, stream); - if (err != cudaSuccess) + + const size_t m = d * 2; + if (m == 0) { - cudaFree(d_out); - cudaFree(d_lhs); - cudaFree(d_rhs); - return set_error(err); + return; } - err = cudaMemcpyAsync(d_rhs, rhs_ptrs.data(), rhs_bytes, cudaMemcpyHostToDevice, stream); - if (err != cudaSuccess) + + const size_t sample_idx = sample_start + idx; + const size_t col_idx = sample_idx / n; + const size_t coeff_idx = sample_idx - col_idx * n; + const size_t cov_stride = m * m; + const size_t vec_stride = m; + double *cov = cov_workspace + idx * cov_stride; + double *mean = mean_workspace + idx * vec_stride; + double *col_buf = col_workspace + idx * vec_stride; + int64_t *sampled = sampled_workspace + idx * vec_stride; + + DeviceChaChaRng rng; + rng_init( + rng, + seed, + static_cast(col_idx + 1), + static_cast(coeff_idx + 1), + static_cast(limb_idx + 1), + 0x7065727475726231ULL); + + const double sigma2 = sigma * sigma; + const double s2 = s * s; + const double denom = s2 - sigma2; + if (!(denom > 0.0)) { - cudaFree(d_out); - cudaFree(d_lhs); - cudaFree(d_rhs); - return set_error(err); + return; } + const double c_scale = -sigma2 / denom; + const double fallback_var = dgg_stddev * dgg_stddev; + const double eps = 1e-9; - const dim3 threads(kMatmulTileN, kMatmulTileM); - const dim3 blocks( - static_cast((cols + kMatmulTileN - 1) / kMatmulTileN), - static_cast((rows + kMatmulTileM - 1) / kMatmulTileM), - static_cast(n)); - - cudaEvent_t start = nullptr; - cudaEvent_t stop = nullptr; - if (out_kernel_ms) + for (size_t i = 0; i < d; ++i) { - err = cudaEventCreate(&start); - if (err != cudaSuccess) - { - cudaFree(d_out); - cudaFree(d_lhs); - cudaFree(d_rhs); - return set_error(err); - } - err = cudaEventCreate(&stop); - if (err != cudaSuccess) - { - cudaEventDestroy(start); - cudaFree(d_out); - cudaFree(d_lhs); - cudaFree(d_rhs); - return set_error(err); - } - err = cudaEventRecord(start, stream); - if (err != cudaSuccess) + for (size_t j = 0; j < d; ++j) { - cudaEventDestroy(start); - cudaEventDestroy(stop); - cudaFree(d_out); - cudaFree(d_lhs); - cudaFree(d_rhs); - return set_error(err); + const size_t ij = matrix_index(i, j, d); + const size_t ji = matrix_index(j, i, d); + const double a_ij = static_cast(centered_residue_i64(a_entries[ij][coeff_idx], modulus)); + const double d_ij = static_cast(centered_residue_i64(d_entries[ij][coeff_idx], modulus)); + const double b_ij = static_cast(centered_residue_i64(b_entries[ij][coeff_idx], modulus)); + const double b_ji = static_cast(centered_residue_i64(b_entries[ji][coeff_idx], modulus)); + + const double af = -sigma2 * a_ij + (i == j ? s2 : 0.0); + const double df = -sigma2 * d_ij + (i == j ? 
s2 : 0.0); + const double bf = -sigma2 * b_ij; + const double bt = -sigma2 * b_ji; + + cov[matrix_index(i, j, m)] = af; + cov[matrix_index(i + d, j + d, m)] = df; + cov[matrix_index(i, j + d, m)] = bf; + cov[matrix_index(i + d, j, m)] = bt; } } - block_matmul_kernel<<>>(d_lhs, d_rhs, d_out, rows, inner, cols, n, modulus); + for (size_t row = 0; row < m; ++row) + { + const size_t tp_idx = matrix_index(row, col_idx, cols); + const double c_centered = static_cast(centered_residue_i64(tp2_entries[tp_idx][coeff_idx], modulus)); + mean[row] = c_scale * c_centered; + } - err = cudaGetLastError(); - if (err != cudaSuccess) + for (int t = static_cast(m) - 1; t >= 0; --t) { - if (start) + const size_t tt = static_cast(t); + double var = cov[matrix_index(tt, tt, m)]; + if (!(var > eps)) { - cudaEventDestroy(start); - cudaEventDestroy(stop); + var = fallback_var; } - cudaFree(d_out); - cudaFree(d_lhs); - cudaFree(d_rhs); - return set_error(err); - } + const double mu = mean[tt]; + const int64_t z = sample_integer_karney(rng, mu, sqrt(var)); + sampled[tt] = z; - if (out_kernel_ms) - { - err = cudaEventRecord(stop, stream); - if (err != cudaSuccess) + if (t == 0) { - cudaEventDestroy(start); - cudaEventDestroy(stop); - cudaFree(d_out); - cudaFree(d_lhs); - cudaFree(d_rhs); - return set_error(err); + break; } - err = cudaEventSynchronize(stop); - if (err != cudaSuccess) + + const double delta = static_cast(z) - mu; + for (int i = 0; i < t; ++i) { - cudaEventDestroy(start); - cudaEventDestroy(stop); - cudaFree(d_out); - cudaFree(d_lhs); - cudaFree(d_rhs); - return set_error(err); + col_buf[static_cast(i)] = + cov[matrix_index(static_cast(i), tt, m)]; } - float kernel_ms = 0.0f; - err = cudaEventElapsedTime(&kernel_ms, start, stop); - if (err != cudaSuccess) + + for (int i = 0; i < t; ++i) { - cudaEventDestroy(start); - cudaEventDestroy(stop); - cudaFree(d_out); - cudaFree(d_lhs); - cudaFree(d_rhs); - return set_error(err); + mean[static_cast(i)] += + (col_buf[static_cast(i)] / var) * delta; } - *out_kernel_ms += static_cast(kernel_ms); - cudaEventDestroy(start); - cudaEventDestroy(stop); - } - else - { - err = cudaStreamSynchronize(stream); - if (err != cudaSuccess) + + for (int i = 0; i < t; ++i) { - cudaFree(d_out); - cudaFree(d_lhs); - cudaFree(d_rhs); - return set_error(err); + for (int j = 0; j <= i; ++j) + { + double updated = cov[matrix_index(static_cast(i), static_cast(j), m)] - + (col_buf[static_cast(i)] * col_buf[static_cast(j)] / var); + cov[matrix_index(static_cast(i), static_cast(j), m)] = updated; + cov[matrix_index(static_cast(j), static_cast(i), m)] = updated; + } } } - cudaFree(d_out); - cudaFree(d_lhs); - cudaFree(d_rhs); - return 0; + for (size_t row = 0; row < m; ++row) + { + const size_t out_idx = matrix_index(row, col_idx, cols); + out_entries[out_idx][coeff_idx] = signed_mod_i64(sampled[row], modulus); + } } - template - int launch_decompose_kernel( - const std::vector &src_ptrs, - const std::vector &dst_ptrs, + __global__ void matrix_gauss_samp_gq_arb_base_kernel( + const uint64_t **src, + uint64_t **dst, + size_t poly_count, size_t n, - uint32_t shift, - T mask, - cudaStream_t stream) + uint64_t tower_modulus, + uint32_t base_bits, + uint32_t digits_per_tower, + uint32_t digit_idx, + double c, + uint32_t tower_idx, + uint64_t seed, + uint64_t out_modulus) { - const size_t count = src_ptrs.size(); - if (count == 0 || n == 0) + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + size_t total = poly_count * n; + if (idx >= total) { - return 0; + return; } - if 
(dst_ptrs.size() != count) + + if (digits_per_tower == 0 || digits_per_tower > kGaussMaxDigits || base_bits == 0 || base_bits >= 63) { - return set_error("unexpected pointer counts in matrix_decompose_kernel"); + return; } - const T **d_src = nullptr; - T **d_dst = nullptr; - const size_t bytes = count * sizeof(T *); + size_t poly_idx = idx / n; + size_t coeff_idx = idx - poly_idx * n; + uint64_t value = src[poly_idx][coeff_idx]; + if (tower_modulus != 0) + { + value %= tower_modulus; + } - cudaError_t err = cudaMalloc(&d_src, bytes); - if (err != cudaSuccess) + uint64_t base = uint64_t{1} << base_bits; + double base_f = static_cast(base); + double sigma = c / (base_f + 1.0); + + int64_t m_digits[kGaussMaxDigits]; + int64_t v_digits[kGaussMaxDigits]; + double l[kGaussMaxDigits]; + double h[kGaussMaxDigits]; + double c_vec[kGaussMaxDigits]; + double p[kGaussMaxDigits]; + double a[kGaussMaxDigits]; + double zf[kGaussMaxDigits]; + int64_t z[kGaussMaxDigits]; + + get_base_digits_u64(tower_modulus, base, digits_per_tower, m_digits); + get_base_digits_u64(value, base, digits_per_tower, v_digits); + + const double kf = static_cast(digits_per_tower); + l[0] = sqrt(base_f * (1.0 + 1.0 / kf) + 1.0); + for (uint32_t i = 1; i < digits_per_tower; ++i) { - return set_error(err); + l[i] = sqrt(base_f * (1.0 + 1.0 / (kf - static_cast(i)))); } - err = cudaMalloc(&d_dst, bytes); - if (err != cudaSuccess) + + h[0] = 0.0; + for (uint32_t i = 1; i < digits_per_tower; ++i) { - cudaFree(d_src); - return set_error(err); + h[i] = sqrt(base_f * (1.0 - 1.0 / (kf - static_cast(i - 1)))); } - err = cudaMemcpyAsync(d_src, src_ptrs.data(), bytes, cudaMemcpyHostToDevice, stream); - if (err != cudaSuccess) + c_vec[0] = static_cast(m_digits[0]) / base_f; + for (uint32_t i = 1; i < digits_per_tower; ++i) { - cudaFree(d_src); - cudaFree(d_dst); - return set_error(err); + c_vec[i] = (c_vec[i - 1] + static_cast(m_digits[i])) / base_f; } - err = cudaMemcpyAsync(d_dst, dst_ptrs.data(), bytes, cudaMemcpyHostToDevice, stream); - if (err != cudaSuccess) + + DeviceChaChaRng rng; + rng_init( + rng, + seed, + static_cast(tower_idx + 1), + static_cast(poly_idx + 1), + static_cast(coeff_idx + 1), + 0x6761646765746731ULL); + + for (uint32_t i = 0; i < digits_per_tower; ++i) { - cudaFree(d_src); - cudaFree(d_dst); - return set_error(err); + zf[i] = sigma * sample_standard_normal(rng); } + for (uint32_t i = 0; i + 1 < digits_per_tower; ++i) + { + p[i] = l[i] * zf[i] + h[i + 1] * zf[i + 1]; + } + p[digits_per_tower - 1] = h[digits_per_tower - 1] * zf[digits_per_tower - 1]; - const int threads = 256; - const size_t total = count * n; - const int blocks = static_cast((total + threads - 1) / threads); + a[0] = (static_cast(v_digits[0]) - p[0]) / base_f; + for (uint32_t t = 1; t < digits_per_tower; ++t) + { + a[t] = (a[t - 1] + static_cast(v_digits[t]) - p[t]) / base_f; + } - matrix_decompose_kernel<<>>(d_src, d_dst, count, n, shift, mask); - err = cudaGetLastError(); - if (err != cudaSuccess) + const uint32_t last = digits_per_tower - 1; + z[last] = sample_integer_karney(rng, -a[last] / c_vec[last], sigma / c_vec[last]); + for (uint32_t i = 0; i < digits_per_tower; ++i) { - cudaFree(d_src); - cudaFree(d_dst); - return set_error(err); + a[i] += static_cast(z[last]) * c_vec[i]; + } + for (uint32_t i = 0; i < last; ++i) + { + z[i] = sample_integer_karney(rng, -a[i], sigma); } - err = cudaStreamSynchronize(stream); - if (err != cudaSuccess) + int64_t out_digit = 0; + if (digits_per_tower == 1) { - cudaFree(d_src); - cudaFree(d_dst); - return 
set_error(err); + out_digit = static_cast(base) * z[0] + m_digits[0] * z[0] + v_digits[0]; + } + else if (digit_idx == 0) + { + out_digit = static_cast(base) * z[0] + m_digits[0] * z[last] + v_digits[0]; + } + else if (digit_idx < last) + { + out_digit = static_cast(base) * z[digit_idx] - z[digit_idx - 1] + + m_digits[digit_idx] * z[last] + v_digits[digit_idx]; + } + else + { + out_digit = m_digits[last] * z[last] - z[last - 1] + v_digits[last]; } - cudaFree(d_src); - cudaFree(d_dst); - return 0; + dst[poly_idx][coeff_idx] = signed_mod_i64(out_digit, out_modulus); } template - int launch_for_limb( - GpuPoly *const *out, - const GpuPoly *const *lhs, - const GpuPoly *const *rhs, - size_t count, + __global__ void block_mul_kernel( + const T **lhs, + const T **rhs, + T **out, + size_t poly_count, size_t n, - int limb, - const dim3 &limb_id, - BlockOp op) + T modulus) { - std::vector out_ptrs; - std::vector lhs_ptrs; - std::vector rhs_ptrs; - out_ptrs.reserve(count); - lhs_ptrs.reserve(count); - rhs_ptrs.reserve(count); - - cudaStream_t stream = nullptr; - for (size_t i = 0; i < count; ++i) + size_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + size_t total = poly_count * n; + if (idx >= total) { - auto &lhs_partition = lhs[i]->poly->GPU[limb_id.x]; - auto &rhs_partition = rhs[i]->poly->GPU[limb_id.x]; - auto &out_partition = out[i]->poly->GPU[limb_id.x]; - if (limb_id.y >= lhs_partition.limb.size() || limb_id.y >= rhs_partition.limb.size() || - limb_id.y >= out_partition.limb.size()) - { - return set_error("unexpected limb index in gpu_block_op"); - } - auto &lhs_limb = lhs_partition.limb[limb_id.y]; - auto &rhs_limb = rhs_partition.limb[limb_id.y]; - auto &out_limb = out_partition.limb[limb_id.y]; - if (lhs_limb.index() != rhs_limb.index() || lhs_limb.index() != out_limb.index()) - { - return set_error("mixed limb types in gpu_block_op"); - } - - if constexpr (std::is_same_v) - { - auto &lhs_u = std::get(lhs_limb); - auto &rhs_u = std::get(rhs_limb); - auto &out_u = std::get(out_limb); - if (!stream) - { - stream = lhs_u.stream.ptr; - } - lhs_ptrs.push_back(lhs_u.v.data); - rhs_ptrs.push_back(rhs_u.v.data); - out_ptrs.push_back(out_u.v.data); - } - else - { - auto &lhs_u = std::get(lhs_limb); - auto &rhs_u = std::get(rhs_limb); - auto &out_u = std::get(out_limb); - if (!stream) - { - stream = lhs_u.stream.ptr; - } - lhs_ptrs.push_back(lhs_u.v.data); - rhs_ptrs.push_back(rhs_u.v.data); - out_ptrs.push_back(out_u.v.data); - } + return; + } + size_t poly_idx = idx / n; + size_t coeff_idx = idx - poly_idx * n; + T a = lhs[poly_idx][coeff_idx]; + T b = rhs[poly_idx][coeff_idx]; + if constexpr (std::is_same_v) + { + out[poly_idx][coeff_idx] = mul_mod_u64(a, b, modulus); + } + else + { + out[poly_idx][coeff_idx] = mul_mod_u32(a, b, modulus); } - - const uint64_t modulus64 = lhs[0]->ctx->moduli[static_cast(limb)]; - const T modulus = static_cast(modulus64); - return launch_block_kernel(out_ptrs, lhs_ptrs, rhs_ptrs, n, modulus, op, stream); } template - int launch_for_limb_matmul( - GpuPoly *const *out, - const GpuPoly *const *lhs, - const GpuPoly *const *rhs, + __global__ void block_matmul_kernel( + const T **lhs, + const T **rhs, + T **out, size_t rows, size_t inner, size_t cols, size_t n, - int limb, - const dim3 &limb_id, - double *out_kernel_ms) + T modulus) { - const size_t out_count = rows * cols; - const size_t lhs_count = rows * inner; - const size_t rhs_count = inner * cols; + __shared__ T lhs_tile[kMatmulTileM][kMatmulTileK]; + __shared__ T 
rhs_tile[kMatmulTileK][kMatmulTileN]; - std::vector out_ptrs; - std::vector lhs_ptrs; - std::vector rhs_ptrs; - out_ptrs.reserve(out_count); - lhs_ptrs.reserve(lhs_count); - rhs_ptrs.reserve(rhs_count); + const size_t row_base = static_cast(blockIdx.y) * kMatmulTileM; + const size_t col_base = static_cast(blockIdx.x) * kMatmulTileN; + const size_t row = row_base + threadIdx.y; + const size_t col = col_base + threadIdx.x; + const size_t coeff_idx = static_cast(blockIdx.z); + if (coeff_idx >= n) + { + return; + } - cudaStream_t stream = nullptr; - int limb_index = -1; + const int tid = static_cast(threadIdx.y) * blockDim.x + threadIdx.x; + const int threads = blockDim.x * blockDim.y; - for (size_t i = 0; i < lhs_count; ++i) + T acc = 0; + for (size_t k0 = 0; k0 < inner; k0 += kMatmulTileK) { - auto &lhs_partition = lhs[i]->poly->GPU[limb_id.x]; - if (limb_id.y >= lhs_partition.limb.size()) - { - return set_error("unexpected limb index in gpu_block_mul"); - } - auto &lhs_limb = lhs_partition.limb[limb_id.y]; - if constexpr (std::is_same_v) + for (int i = tid; i < kMatmulTileM * kMatmulTileK; i += threads) { - if (lhs_limb.index() != FIDESlib::U64) - { - return set_error("mixed limb types in gpu_block_mul"); - } - auto &lhs_u = std::get(lhs_limb); - if (!stream) + const int r = i / kMatmulTileK; + const int k = i - r * kMatmulTileK; + const size_t lhs_row = row_base + static_cast(r); + const size_t lhs_k = k0 + static_cast(k); + T val = 0; + if (lhs_row < rows && lhs_k < inner) { - stream = lhs_u.stream.ptr; + const T *lhs_poly = lhs[lhs_row * inner + lhs_k]; + val = lhs_poly[coeff_idx]; } - lhs_ptrs.push_back(lhs_u.v.data); + lhs_tile[r][k] = val; } - else + for (int i = tid; i < kMatmulTileK * kMatmulTileN; i += threads) { - if (lhs_limb.index() != FIDESlib::U32) - { - return set_error("mixed limb types in gpu_block_mul"); - } - auto &lhs_u = std::get(lhs_limb); - if (!stream) + const int k = i / kMatmulTileN; + const int c = i - k * kMatmulTileN; + const size_t rhs_k = k0 + static_cast(k); + const size_t rhs_col = col_base + static_cast(c); + T val = 0; + if (rhs_k < inner && rhs_col < cols) { - stream = lhs_u.stream.ptr; + const T *rhs_poly = rhs[rhs_k * cols + rhs_col]; + val = rhs_poly[coeff_idx]; } - lhs_ptrs.push_back(lhs_u.v.data); + rhs_tile[k][c] = val; } - if (limb_index == -1) + __syncthreads(); + + if (row < rows && col < cols) { - limb_index = lhs_limb.index(); + for (int kk = 0; kk < kMatmulTileK; ++kk) + { + T prod; + if constexpr (std::is_same_v) + { + prod = mul_mod_u64(lhs_tile[threadIdx.y][kk], rhs_tile[kk][threadIdx.x], modulus); + acc = add_mod_u64(acc, prod, modulus); + } + else + { + prod = mul_mod_u32(lhs_tile[threadIdx.y][kk], rhs_tile[kk][threadIdx.x], modulus); + acc = add_mod_u32(acc, prod, modulus); + } + } } + __syncthreads(); } - for (size_t i = 0; i < rhs_count; ++i) + if (row < rows && col < cols) { - auto &rhs_partition = rhs[i]->poly->GPU[limb_id.x]; - if (limb_id.y >= rhs_partition.limb.size()) - { - return set_error("unexpected limb index in gpu_block_mul"); - } - auto &rhs_limb = rhs_partition.limb[limb_id.y]; - if (rhs_limb.index() != limb_index) - { - return set_error("mixed limb types in gpu_block_mul"); - } - - if constexpr (std::is_same_v) - { - auto &rhs_u = std::get(rhs_limb); - rhs_ptrs.push_back(rhs_u.v.data); - } - else - { - auto &rhs_u = std::get(rhs_limb); - rhs_ptrs.push_back(rhs_u.v.data); - } + out[row * cols + col][coeff_idx] = acc; } + } - for (size_t i = 0; i < out_count; ++i) + template + int launch_block_kernel( + const std::vector 
&out_ptrs, + const std::vector &lhs_ptrs, + const std::vector &rhs_ptrs, + size_t n, + T modulus, + BlockOp op, + cudaStream_t stream) + { + const size_t count = out_ptrs.size(); + if (count == 0 || n == 0) { - auto &out_partition = out[i]->poly->GPU[limb_id.x]; - if (limb_id.y >= out_partition.limb.size()) - { - return set_error("unexpected limb index in gpu_block_mul"); - } - auto &out_limb = out_partition.limb[limb_id.y]; - if (out_limb.index() != limb_index) - { - return set_error("mixed limb types in gpu_block_mul"); - } - - if constexpr (std::is_same_v) - { - auto &out_u = std::get(out_limb); - out_ptrs.push_back(out_u.v.data); - } - else - { - auto &out_u = std::get(out_limb); - out_ptrs.push_back(out_u.v.data); - } + return 0; } - const uint64_t modulus64 = lhs[0]->ctx->moduli[static_cast(limb)]; - const T modulus = static_cast(modulus64); - return launch_block_matmul_kernel(out_ptrs, lhs_ptrs, rhs_ptrs, rows, inner, cols, n, modulus, stream, out_kernel_ms); - } + T **d_out = nullptr; + const T **d_lhs = nullptr; + const T **d_rhs = nullptr; + const size_t bytes = count * sizeof(T *); - int gpu_block_op(GpuPoly *const *out, const GpuPoly *const *lhs, const GpuPoly *const *rhs, size_t count, BlockOp op) - { - if (!out || !lhs || !rhs) + cudaError_t err = cudaMalloc(&d_out, bytes); + if (err != cudaSuccess) { - return set_error("invalid gpu_block_op arguments"); + return set_error(err); } - if (count == 0) + err = cudaMalloc(&d_lhs, bytes); + if (err != cudaSuccess) { - return 0; + cudaFree(d_out); + return set_error(err); } - - const GpuPoly *lhs0 = lhs[0]; - if (!lhs0 || !lhs0->ctx) + err = cudaMalloc(&d_rhs, bytes); + if (err != cudaSuccess) { - return set_error("null context in gpu_block_op"); + cudaFree(d_out); + cudaFree(d_lhs); + return set_error(err); } - const GpuContext *ctx = lhs0->ctx; - const int level = lhs0->level; - const PolyFormat format = lhs0->format; - if (op == BlockOp::Mul && format != PolyFormat::Eval) + + err = cudaMemcpyAsync(d_out, out_ptrs.data(), bytes, cudaMemcpyHostToDevice, stream); + if (err != cudaSuccess) { - return set_error("gpu_block_entrywise_mul requires Eval format"); + cudaFree(d_out); + cudaFree(d_lhs); + cudaFree(d_rhs); + return set_error(err); } - - // for (size_t i = 0; i < count; ++i) - // { - // if (!out[i] || !lhs[i] || !rhs[i]) - // { - // return set_error("null polynomial in gpu_block_op"); - // } - // if (lhs[i]->ctx != ctx || rhs[i]->ctx != ctx || out[i]->ctx != ctx) - // { - // return set_error("mismatched contexts in gpu_block_op"); - // } - // if (lhs[i]->level != level || rhs[i]->level != level || out[i]->level != level) - // { - // return set_error("mismatched levels in gpu_block_op"); - // } - // if (lhs[i]->format != format || rhs[i]->format != format) - // { - // return set_error("mismatched formats in gpu_block_op"); - // } - // } - - if (level < 0) + err = cudaMemcpyAsync(d_lhs, lhs_ptrs.data(), bytes, cudaMemcpyHostToDevice, stream); + if (err != cudaSuccess) { - return set_error("invalid level in gpu_block_op"); + cudaFree(d_out); + cudaFree(d_lhs); + cudaFree(d_rhs); + return set_error(err); } - - const int N = ctx->N; - if (N <= 0) + err = cudaMemcpyAsync(d_rhs, rhs_ptrs.data(), bytes, cudaMemcpyHostToDevice, stream); + if (err != cudaSuccess) { - return 0; + cudaFree(d_out); + cudaFree(d_lhs); + cudaFree(d_rhs); + return set_error(err); } - auto &limb_map = ctx->ctx->limbGPUid; - if (limb_map.size() < static_cast(level + 1)) + const int threads = 256; + const size_t total = count * n; + const int blocks = 
static_cast((total + threads - 1) / threads); + + switch (op) { - return set_error("unexpected limb mapping size in gpu_block_op"); + case BlockOp::Add: + block_add_kernel<<>>(d_lhs, d_rhs, d_out, count, n, modulus); + break; + case BlockOp::Sub: + block_sub_kernel<<>>(d_lhs, d_rhs, d_out, count, n, modulus); + break; + case BlockOp::Mul: + block_mul_kernel<<>>(d_lhs, d_rhs, d_out, count, n, modulus); + break; } - for (int limb = 0; limb <= level; ++limb) + err = cudaGetLastError(); + if (err != cudaSuccess) { - const dim3 limb_id = limb_map[static_cast(limb)]; - if (limb_id.x >= lhs0->poly->GPU.size()) - { - return set_error("unexpected limb GPU partition in gpu_block_op"); - } - const auto &partition = lhs0->poly->GPU[limb_id.x]; - if (limb_id.y >= partition.limb.size()) - { - return set_error("unexpected limb index in gpu_block_op"); - } - - cudaError_t err = cudaSetDevice(partition.device); - if (err != cudaSuccess) - { - return set_error(err); - } - - const auto &limb_impl = partition.limb[limb_id.y]; - if (limb_impl.index() == FIDESlib::U64) - { - int status = launch_for_limb(out, lhs, rhs, count, static_cast(N), limb, limb_id, op); - if (status != 0) - { - return status; - } - } - else if (limb_impl.index() == FIDESlib::U32) - { - int status = launch_for_limb(out, lhs, rhs, count, static_cast(N), limb, limb_id, op); - if (status != 0) - { - return status; - } - } - else - { - return set_error("unsupported limb type in gpu_block_op"); - } + cudaFree(d_out); + cudaFree(d_lhs); + cudaFree(d_rhs); + return set_error(err); } - for (size_t i = 0; i < count; ++i) + err = cudaStreamSynchronize(stream); + if (err != cudaSuccess) { - out[i]->level = level; - out[i]->format = format; + cudaFree(d_out); + cudaFree(d_lhs); + cudaFree(d_rhs); + return set_error(err); } + + cudaFree(d_out); + cudaFree(d_lhs); + cudaFree(d_rhs); return 0; } - int gpu_block_matmul( - GpuPoly *const *out, - const GpuPoly *const *lhs, - const GpuPoly *const *rhs, + template + int launch_block_matmul_kernel( + const std::vector &out_ptrs, + const std::vector &lhs_ptrs, + const std::vector &rhs_ptrs, size_t rows, size_t inner, size_t cols, - double *out_kernel_ms = nullptr) + size_t n, + T modulus, + cudaStream_t stream, + double *out_kernel_ms) { - if (!out || !lhs || !rhs) - { - return set_error("invalid gpu_block_mul arguments"); - } - if (out_kernel_ms) - { - *out_kernel_ms = 0.0; - } const size_t out_count = rows * cols; - if (rows == 0 || inner == 0 || cols == 0) + const size_t lhs_count = rows * inner; + const size_t rhs_count = inner * cols; + if (out_count == 0 || n == 0) { return 0; } - if (out_count == 0) + if (out_ptrs.size() != out_count || lhs_ptrs.size() != lhs_count || rhs_ptrs.size() != rhs_count) { - return 0; + return set_error("unexpected pointer counts in gpu_block_mul"); } - const GpuPoly *lhs0 = lhs[0]; - if (!lhs0 || !lhs0->ctx) + T **d_out = nullptr; + const T **d_lhs = nullptr; + const T **d_rhs = nullptr; + const size_t out_bytes = out_count * sizeof(T *); + const size_t lhs_bytes = lhs_count * sizeof(T *); + const size_t rhs_bytes = rhs_count * sizeof(T *); + + cudaError_t err = cudaMalloc(&d_out, out_bytes); + if (err != cudaSuccess) { - return set_error("null context in gpu_block_mul"); + return set_error(err); } - const GpuContext *ctx = lhs0->ctx; - const int level = lhs0->level; - const PolyFormat format = lhs0->format; - if (format != PolyFormat::Eval) + err = cudaMalloc(&d_lhs, lhs_bytes); + if (err != cudaSuccess) { - return set_error("gpu_block_mul requires Eval format"); - } - - if 
(level < 0) + cudaFree(d_out); + return set_error(err); + } + err = cudaMalloc(&d_rhs, rhs_bytes); + if (err != cudaSuccess) { - return set_error("invalid level in gpu_block_mul"); + cudaFree(d_out); + cudaFree(d_lhs); + return set_error(err); } - const int N = ctx->N; - if (N <= 0) + err = cudaMemcpyAsync(d_out, out_ptrs.data(), out_bytes, cudaMemcpyHostToDevice, stream); + if (err != cudaSuccess) { - return 0; + cudaFree(d_out); + cudaFree(d_lhs); + cudaFree(d_rhs); + return set_error(err); } - - auto &limb_map = ctx->ctx->limbGPUid; - if (limb_map.size() < static_cast(level + 1)) + err = cudaMemcpyAsync(d_lhs, lhs_ptrs.data(), lhs_bytes, cudaMemcpyHostToDevice, stream); + if (err != cudaSuccess) { - return set_error("unexpected limb mapping size in gpu_block_mul"); + cudaFree(d_out); + cudaFree(d_lhs); + cudaFree(d_rhs); + return set_error(err); + } + err = cudaMemcpyAsync(d_rhs, rhs_ptrs.data(), rhs_bytes, cudaMemcpyHostToDevice, stream); + if (err != cudaSuccess) + { + cudaFree(d_out); + cudaFree(d_lhs); + cudaFree(d_rhs); + return set_error(err); } - for (int limb = 0; limb <= level; ++limb) + const dim3 threads(kMatmulTileN, kMatmulTileM); + const dim3 blocks( + static_cast((cols + kMatmulTileN - 1) / kMatmulTileN), + static_cast((rows + kMatmulTileM - 1) / kMatmulTileM), + static_cast(n)); + + cudaEvent_t start = nullptr; + cudaEvent_t stop = nullptr; + if (out_kernel_ms) { - const dim3 limb_id = limb_map[static_cast(limb)]; - if (limb_id.x >= lhs0->poly->GPU.size()) + err = cudaEventCreate(&start); + if (err != cudaSuccess) { - return set_error("unexpected limb GPU partition in gpu_block_mul"); + cudaFree(d_out); + cudaFree(d_lhs); + cudaFree(d_rhs); + return set_error(err); } - const auto &partition = lhs0->poly->GPU[limb_id.x]; - if (limb_id.y >= partition.limb.size()) + err = cudaEventCreate(&stop); + if (err != cudaSuccess) { - return set_error("unexpected limb index in gpu_block_mul"); + cudaEventDestroy(start); + cudaFree(d_out); + cudaFree(d_lhs); + cudaFree(d_rhs); + return set_error(err); } - - cudaError_t err = cudaSetDevice(partition.device); + err = cudaEventRecord(start, stream); if (err != cudaSuccess) { + cudaEventDestroy(start); + cudaEventDestroy(stop); + cudaFree(d_out); + cudaFree(d_lhs); + cudaFree(d_rhs); return set_error(err); } + } - const auto &limb_impl = partition.limb[limb_id.y]; - if (limb_impl.index() == FIDESlib::U64) + block_matmul_kernel<<>>(d_lhs, d_rhs, d_out, rows, inner, cols, n, modulus); + + err = cudaGetLastError(); + if (err != cudaSuccess) + { + if (start) { - int status = - launch_for_limb_matmul( - out, - lhs, - rhs, - rows, - inner, - cols, - static_cast(N), - limb, - limb_id, - out_kernel_ms); - if (status != 0) - { - return status; - } + cudaEventDestroy(start); + cudaEventDestroy(stop); } - else if (limb_impl.index() == FIDESlib::U32) + cudaFree(d_out); + cudaFree(d_lhs); + cudaFree(d_rhs); + return set_error(err); + } + + if (out_kernel_ms) + { + err = cudaEventRecord(stop, stream); + if (err != cudaSuccess) { - int status = - launch_for_limb_matmul( - out, - lhs, - rhs, - rows, - inner, - cols, - static_cast(N), - limb, - limb_id, - out_kernel_ms); - if (status != 0) - { - return status; - } + cudaEventDestroy(start); + cudaEventDestroy(stop); + cudaFree(d_out); + cudaFree(d_lhs); + cudaFree(d_rhs); + return set_error(err); } - else + err = cudaEventSynchronize(stop); + if (err != cudaSuccess) { - return set_error("unsupported limb type in gpu_block_mul"); + cudaEventDestroy(start); + cudaEventDestroy(stop); + cudaFree(d_out); + 
cudaFree(d_lhs); + cudaFree(d_rhs); + return set_error(err); + } + float kernel_ms = 0.0f; + err = cudaEventElapsedTime(&kernel_ms, start, stop); + if (err != cudaSuccess) + { + cudaEventDestroy(start); + cudaEventDestroy(stop); + cudaFree(d_out); + cudaFree(d_lhs); + cudaFree(d_rhs); + return set_error(err); } + *out_kernel_ms += static_cast(kernel_ms); + cudaEventDestroy(start); + cudaEventDestroy(stop); } - - for (size_t i = 0; i < out_count; ++i) + else { - out[i]->level = level; - out[i]->format = format; + err = cudaStreamSynchronize(stream); + if (err != cudaSuccess) + { + cudaFree(d_out); + cudaFree(d_lhs); + cudaFree(d_rhs); + return set_error(err); + } } + + cudaFree(d_out); + cudaFree(d_lhs); + cudaFree(d_rhs); return 0; } -} // namespace -extern "C" int gpu_block_add(GpuPoly *const *out, const GpuPoly *const *lhs, const GpuPoly *const *rhs, size_t count) -{ - return gpu_block_op(out, lhs, rhs, count, BlockOp::Add); -} + template + int launch_decompose_kernel( + const std::vector &src_ptrs, + const std::vector &dst_ptrs, + size_t n, + uint32_t shift, + T mask, + T out_modulus, + cudaStream_t stream) + { + const size_t count = src_ptrs.size(); + if (count == 0 || n == 0) + { + return 0; + } + if (dst_ptrs.size() != count) + { + return set_error("unexpected pointer counts in matrix_decompose_kernel"); + } -extern "C" int gpu_block_sub(GpuPoly *const *out, const GpuPoly *const *lhs, const GpuPoly *const *rhs, size_t count) -{ - return gpu_block_op(out, lhs, rhs, count, BlockOp::Sub); -} + const T **d_src = nullptr; + T **d_dst = nullptr; + const size_t bytes = count * sizeof(T *); -extern "C" int gpu_block_entrywise_mul( - GpuPoly *const *out, - const GpuPoly *const *lhs, - const GpuPoly *const *rhs, - size_t count) -{ - return gpu_block_op(out, lhs, rhs, count, BlockOp::Mul); -} + cudaError_t err = cudaMalloc(&d_src, bytes); + if (err != cudaSuccess) + { + return set_error(err); + } + err = cudaMalloc(&d_dst, bytes); + if (err != cudaSuccess) + { + cudaFree(d_src); + return set_error(err); + } -extern "C" int gpu_block_mul( - GpuPoly *const *out, - const GpuPoly *const *lhs, - const GpuPoly *const *rhs, - size_t rows, - size_t inner, - size_t cols) -{ - return gpu_block_matmul(out, lhs, rhs, rows, inner, cols); -} + err = cudaMemcpyAsync(d_src, src_ptrs.data(), bytes, cudaMemcpyHostToDevice, stream); + if (err != cudaSuccess) + { + cudaFree(d_src); + cudaFree(d_dst); + return set_error(err); + } + err = cudaMemcpyAsync(d_dst, dst_ptrs.data(), bytes, cudaMemcpyHostToDevice, stream); + if (err != cudaSuccess) + { + cudaFree(d_src); + cudaFree(d_dst); + return set_error(err); + } -extern "C" int gpu_block_mul_timed( - GpuPoly *const *out, - const GpuPoly *const *lhs, - const GpuPoly *const *rhs, - size_t rows, - size_t inner, - size_t cols, - double *out_kernel_ms) -{ - if (!out_kernel_ms) - { - return set_error("null out_kernel_ms in gpu_block_mul_timed"); - } - return gpu_block_matmul(out, lhs, rhs, rows, inner, cols, out_kernel_ms); -} + const int threads = 256; + const size_t total = count * n; + const int blocks = static_cast((total + threads - 1) / threads); -extern "C" int gpu_matrix_create( - GpuContext *ctx, - int level, - size_t rows, - size_t cols, - int format, - GpuMatrix **out) -{ - if (!ctx || !out) - { - return set_error("invalid gpu_matrix_create arguments"); - } - PolyFormat fmt; - if (!parse_format(format, fmt)) - { - return set_error("invalid format in gpu_matrix_create"); - } + matrix_decompose_kernel<<>>( + d_src, + d_dst, + count, + n, + shift, + mask, + 
out_modulus); + err = cudaGetLastError(); + if (err != cudaSuccess) + { + cudaFree(d_src); + cudaFree(d_dst); + return set_error(err); + } - auto *mat = new GpuMatrix{ctx, rows, cols, level, fmt, {}}; - const size_t count = rows * cols; - mat->polys.reserve(count); - for (size_t i = 0; i < count; ++i) - { - GpuPoly *poly = nullptr; - int status = gpu_poly_create(ctx, level, &poly); - if (status != 0) + err = cudaStreamSynchronize(stream); + if (err != cudaSuccess) { - for (auto *p : mat->polys) - { - gpu_poly_destroy(p); - } - delete mat; - return status; + cudaFree(d_src); + cudaFree(d_dst); + return set_error(err); } - poly->format = fmt; - mat->polys.push_back(poly); - } - *out = mat; - return 0; -} -extern "C" void gpu_matrix_destroy(GpuMatrix *mat) -{ - if (!mat) - { - return; - } - for (auto *poly : mat->polys) - { - gpu_poly_destroy(poly); + cudaFree(d_src); + cudaFree(d_dst); + return 0; } - delete mat; -} -extern "C" int gpu_matrix_copy(GpuMatrix *dst, const GpuMatrix *src) -{ - if (!dst || !src) - { - return set_error("invalid gpu_matrix_copy arguments"); - } - if (dst->rows != src->rows || dst->cols != src->cols) - { - return set_error("size mismatch in gpu_matrix_copy"); - } - if (dst->level != src->level || dst->ctx != src->ctx) + int launch_gauss_samp_gq_arb_base_kernel( + const std::vector &src_ptrs, + const std::vector &dst_ptrs, + size_t n, + uint64_t tower_modulus, + uint32_t base_bits, + uint32_t digits_per_tower, + uint32_t digit_idx, + double c, + uint32_t tower_idx, + uint64_t seed, + uint64_t out_modulus, + cudaStream_t stream) { - return set_error("context mismatch in gpu_matrix_copy"); + const size_t count = src_ptrs.size(); + if (count == 0 || n == 0) + { + return 0; + } + if (dst_ptrs.size() != count) + { + return set_error("unexpected pointer counts in matrix_gauss_samp_gq_arb_base_kernel"); + } + + const uint64_t **d_src = nullptr; + uint64_t **d_dst = nullptr; + const size_t bytes = count * sizeof(uint64_t *); + + cudaError_t err = cudaMalloc(&d_src, bytes); + if (err != cudaSuccess) + { + return set_error(err); + } + err = cudaMalloc(&d_dst, bytes); + if (err != cudaSuccess) + { + cudaFree(d_src); + return set_error(err); + } + + err = cudaMemcpyAsync(d_src, src_ptrs.data(), bytes, cudaMemcpyHostToDevice, stream); + if (err != cudaSuccess) + { + cudaFree(d_src); + cudaFree(d_dst); + return set_error(err); + } + err = cudaMemcpyAsync(d_dst, dst_ptrs.data(), bytes, cudaMemcpyHostToDevice, stream); + if (err != cudaSuccess) + { + cudaFree(d_src); + cudaFree(d_dst); + return set_error(err); + } + + const int threads = 256; + const size_t total = count * n; + const int blocks = static_cast((total + threads - 1) / threads); + + matrix_gauss_samp_gq_arb_base_kernel<<>>( + d_src, + d_dst, + count, + n, + tower_modulus, + base_bits, + digits_per_tower, + digit_idx, + c, + tower_idx, + seed, + out_modulus); + err = cudaGetLastError(); + if (err != cudaSuccess) + { + cudaFree(d_src); + cudaFree(d_dst); + return set_error(err); + } + + err = cudaStreamSynchronize(stream); + if (err != cudaSuccess) + { + cudaFree(d_src); + cudaFree(d_dst); + return set_error(err); + } + + cudaFree(d_src); + cudaFree(d_dst); + return 0; } - const size_t count = src->rows * src->cols; - for (size_t i = 0; i < count; ++i) + + int launch_sample_distribution_kernel( + const std::vector &dst_ptrs, + size_t n, + uint64_t modulus, + int dist_type, + double sigma, + uint32_t limb_idx, + uint64_t seed, + cudaStream_t stream) { - int status = gpu_poly_copy(dst->polys[i], src->polys[i]); - if (status 
!= 0) + const size_t count = dst_ptrs.size(); + if (count == 0 || n == 0) { - return status; + return 0; + } + + uint64_t **d_dst = nullptr; + const size_t bytes = count * sizeof(uint64_t *); + cudaError_t err = cudaMalloc(&d_dst, bytes); + if (err != cudaSuccess) + { + return set_error(err); + } + + err = cudaMemcpyAsync(d_dst, dst_ptrs.data(), bytes, cudaMemcpyHostToDevice, stream); + if (err != cudaSuccess) + { + cudaFree(d_dst); + return set_error(err); + } + + const int threads = 256; + const size_t total = count * n; + const int blocks = static_cast((total + threads - 1) / threads); + matrix_sample_distribution_kernel<<>>( + d_dst, + count, + n, + modulus, + dist_type, + sigma, + limb_idx, + seed); + err = cudaGetLastError(); + if (err != cudaSuccess) + { + cudaFree(d_dst); + return set_error(err); + } + + err = cudaStreamSynchronize(stream); + if (err != cudaSuccess) + { + cudaFree(d_dst); + return set_error(err); } + + cudaFree(d_dst); + return 0; } - dst->format = src->format; - return 0; -} -extern "C" int gpu_matrix_entry_clone( - const GpuMatrix *mat, - size_t row, - size_t col, - GpuPoly **out_poly) -{ - if (!mat || !out_poly) + int launch_fill_gadget_kernel( + const std::vector &dst_ptrs, + size_t n, + uint64_t modulus, + size_t rows, + size_t cols, + size_t log_base_q, + uint32_t digits_per_tower, + uint32_t limb_idx, + uint32_t base_bits, + cudaStream_t stream) { - return set_error("invalid gpu_matrix_entry_clone arguments"); + const size_t count = dst_ptrs.size(); + if (count == 0 || n == 0) + { + return 0; + } + + uint64_t **d_dst = nullptr; + const size_t bytes = count * sizeof(uint64_t *); + cudaError_t err = cudaMalloc(&d_dst, bytes); + if (err != cudaSuccess) + { + return set_error(err); + } + + err = cudaMemcpyAsync(d_dst, dst_ptrs.data(), bytes, cudaMemcpyHostToDevice, stream); + if (err != cudaSuccess) + { + cudaFree(d_dst); + return set_error(err); + } + + const int threads = 256; + const size_t total = count * n; + const int blocks = static_cast((total + threads - 1) / threads); + matrix_fill_gadget_kernel<<>>( + d_dst, + count, + n, + modulus, + rows, + cols, + log_base_q, + digits_per_tower, + limb_idx, + base_bits); + err = cudaGetLastError(); + if (err != cudaSuccess) + { + cudaFree(d_dst); + return set_error(err); + } + + err = cudaStreamSynchronize(stream); + if (err != cudaSuccess) + { + cudaFree(d_dst); + return set_error(err); + } + + cudaFree(d_dst); + return 0; } - if (row >= mat->rows || col >= mat->cols) + + int launch_sample_p1_full_kernel( + const std::vector &a_entries, + const std::vector &b_entries, + const std::vector &d_entries, + const std::vector &tp2_entries, + const std::vector &out_entries, + size_t d, + size_t cols, + size_t n, + uint64_t modulus, + double sigma, + double s, + double dgg_stddev, + uint32_t limb_idx, + uint64_t seed, + cudaStream_t stream, + int device_id) { - return set_error("index out of bounds in gpu_matrix_entry_clone"); - } - const size_t idx = matrix_index(row, col, mat->cols); - return gpu_poly_clone(mat->polys[idx], out_poly); -} + if (d == 0 || cols == 0 || n == 0) + { + return 0; + } + const size_t mat_entries = d * d; + const size_t vec_entries = 2 * d * cols; + if (a_entries.size() != mat_entries || b_entries.size() != mat_entries || + d_entries.size() != mat_entries || tp2_entries.size() != vec_entries || + out_entries.size() != vec_entries) + { + return set_error("unexpected pointer counts in matrix_sample_p1_full_kernel"); + } + + if (device_id < 0) + { + return set_error("invalid device in 
matrix_sample_p1_full_kernel"); + } + cudaError_t err = cudaSetDevice(device_id); + if (err != cudaSuccess) + { + return set_error(err); + } + + const size_t m = 2 * d; + if (m == 0) + { + return set_error("invalid dimension in matrix_sample_p1_full_kernel"); + } + if (m > std::numeric_limits::max() / m) + { + return set_error("workspace overflow in matrix_sample_p1_full_kernel"); + } + const size_t cov_elems_per_sample = m * m; + if (cov_elems_per_sample > std::numeric_limits::max() / sizeof(double)) + { + return set_error("workspace overflow in matrix_sample_p1_full_kernel"); + } + const size_t cov_bytes_per_sample = cov_elems_per_sample * sizeof(double); + if (m > std::numeric_limits::max() / sizeof(double)) + { + return set_error("workspace overflow in matrix_sample_p1_full_kernel"); + } + const size_t vec_bytes_per_sample = m * sizeof(double); + if (m > std::numeric_limits::max() / sizeof(int64_t)) + { + return set_error("workspace overflow in matrix_sample_p1_full_kernel"); + } + const size_t sampled_bytes_per_sample = m * sizeof(int64_t); + if (cov_bytes_per_sample > std::numeric_limits::max() - vec_bytes_per_sample) + { + return set_error("workspace overflow in matrix_sample_p1_full_kernel"); + } + size_t bytes_per_sample_total = cov_bytes_per_sample + vec_bytes_per_sample; + if (bytes_per_sample_total > std::numeric_limits::max() - vec_bytes_per_sample) + { + return set_error("workspace overflow in matrix_sample_p1_full_kernel"); + } + bytes_per_sample_total += vec_bytes_per_sample; + if (bytes_per_sample_total > std::numeric_limits::max() - sampled_bytes_per_sample) + { + return set_error("workspace overflow in matrix_sample_p1_full_kernel"); + } + bytes_per_sample_total += sampled_bytes_per_sample; + + const size_t total_samples = cols * n; + constexpr size_t kWorkspaceBudgetBytes = 128ULL * 1024ULL * 1024ULL; + size_t chunk_samples = kWorkspaceBudgetBytes / bytes_per_sample_total; + if (chunk_samples == 0) + { + chunk_samples = 1; + } + chunk_samples = std::min(chunk_samples, total_samples); + + const uint64_t **d_a_entries = nullptr; + const uint64_t **d_b_entries = nullptr; + const uint64_t **d_d_entries = nullptr; + const uint64_t **d_tp2_entries = nullptr; + uint64_t **d_out_entries = nullptr; + const size_t mat_bytes = mat_entries * sizeof(uint64_t *); + const size_t vec_bytes = vec_entries * sizeof(uint64_t *); + + auto free_all = [&]() { + if (d_a_entries) + cudaFree(d_a_entries); + if (d_b_entries) + cudaFree(d_b_entries); + if (d_d_entries) + cudaFree(d_d_entries); + if (d_tp2_entries) + cudaFree(d_tp2_entries); + if (d_out_entries) + cudaFree(d_out_entries); + }; + + err = cudaMalloc(&d_a_entries, mat_bytes); + if (err != cudaSuccess) + { + free_all(); + return set_error(err); + } + err = cudaMalloc(&d_b_entries, mat_bytes); + if (err != cudaSuccess) + { + free_all(); + return set_error(err); + } + err = cudaMalloc(&d_d_entries, mat_bytes); + if (err != cudaSuccess) + { + free_all(); + return set_error(err); + } + err = cudaMalloc(&d_tp2_entries, vec_bytes); + if (err != cudaSuccess) + { + free_all(); + return set_error(err); + } + err = cudaMalloc(&d_out_entries, vec_bytes); + if (err != cudaSuccess) + { + free_all(); + return set_error(err); + } + + err = cudaMemcpyAsync(d_a_entries, a_entries.data(), mat_bytes, cudaMemcpyHostToDevice, stream); + if (err != cudaSuccess) + { + free_all(); + return set_error(err); + } + err = cudaMemcpyAsync(d_b_entries, b_entries.data(), mat_bytes, cudaMemcpyHostToDevice, stream); + if (err != cudaSuccess) + { + free_all(); + return 
set_error(err); + } + err = cudaMemcpyAsync(d_d_entries, d_entries.data(), mat_bytes, cudaMemcpyHostToDevice, stream); + if (err != cudaSuccess) + { + free_all(); + return set_error(err); + } + err = cudaMemcpyAsync(d_tp2_entries, tp2_entries.data(), vec_bytes, cudaMemcpyHostToDevice, stream); + if (err != cudaSuccess) + { + free_all(); + return set_error(err); + } + err = cudaMemcpyAsync(d_out_entries, out_entries.data(), vec_bytes, cudaMemcpyHostToDevice, stream); + if (err != cudaSuccess) + { + free_all(); + return set_error(err); + } + + double *cov_workspace = nullptr; + double *mean_workspace = nullptr; + double *col_workspace = nullptr; + int64_t *sampled_workspace = nullptr; + auto free_workspace = [&]() { + if (cov_workspace) + cudaFree(cov_workspace); + if (mean_workspace) + cudaFree(mean_workspace); + if (col_workspace) + cudaFree(col_workspace); + if (sampled_workspace) + cudaFree(sampled_workspace); + cov_workspace = nullptr; + mean_workspace = nullptr; + col_workspace = nullptr; + sampled_workspace = nullptr; + }; + + auto alloc_workspace = [&](size_t samples) -> bool { + if (samples == 0) + { + return false; + } + if (samples > std::numeric_limits::max() / cov_bytes_per_sample || + samples > std::numeric_limits::max() / vec_bytes_per_sample || + samples > std::numeric_limits::max() / sampled_bytes_per_sample) + { + return false; + } + cudaError_t local_err = cudaMalloc(&cov_workspace, samples * cov_bytes_per_sample); + if (local_err != cudaSuccess) + { + free_workspace(); + return false; + } + local_err = cudaMalloc(&mean_workspace, samples * vec_bytes_per_sample); + if (local_err != cudaSuccess) + { + free_workspace(); + return false; + } + local_err = cudaMalloc(&col_workspace, samples * vec_bytes_per_sample); + if (local_err != cudaSuccess) + { + free_workspace(); + return false; + } + local_err = cudaMalloc(&sampled_workspace, samples * sampled_bytes_per_sample); + if (local_err != cudaSuccess) + { + free_workspace(); + return false; + } + return true; + }; + + while (!alloc_workspace(chunk_samples)) + { + if (chunk_samples <= 1) + { + free_all(); + return set_error("failed to allocate workspace in matrix_sample_p1_full_kernel"); + } + chunk_samples = (chunk_samples + 1) / 2; + } + + const int threads = 256; + for (size_t sample_start = 0; sample_start < total_samples; sample_start += chunk_samples) + { + size_t sample_count = std::min(chunk_samples, total_samples - sample_start); + const int blocks = static_cast((sample_count + threads - 1) / threads); + matrix_sample_p1_full_kernel<<>>( + d_a_entries, + d_b_entries, + d_d_entries, + d_tp2_entries, + d_out_entries, + d, + cols, + n, + sample_start, + sample_count, + cov_workspace, + mean_workspace, + col_workspace, + sampled_workspace, + modulus, + sigma, + s, + dgg_stddev, + limb_idx, + seed); + err = cudaGetLastError(); + if (err != cudaSuccess) + { + free_workspace(); + free_all(); + return set_error(err); + } + } + + err = cudaStreamSynchronize(stream); + if (err != cudaSuccess) + { + free_workspace(); + free_all(); + return set_error(err); + } + + free_workspace(); + free_all(); + return 0; + } + + template + int launch_for_limb( + GpuPoly *const *out, + const GpuPoly *const *lhs, + const GpuPoly *const *rhs, + size_t count, + size_t n, + int limb, + const dim3 &limb_id, + BlockOp op) + { + std::vector out_ptrs; + std::vector lhs_ptrs; + std::vector rhs_ptrs; + out_ptrs.reserve(count); + lhs_ptrs.reserve(count); + rhs_ptrs.reserve(count); + + cudaStream_t stream = nullptr; + for (size_t i = 0; i < count; ++i) + { + 
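// Gather the raw device pointers for this RNS limb from each operand's
// FIDESlib GPU partition so one batched kernel launch can cover every matrix
// entry on that limb's stream; mixed U32/U64 limb storage is rejected because
// a single kernel instantiation handles only one word size.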
auto &lhs_partition = lhs[i]->poly->GPU[limb_id.x]; + auto &rhs_partition = rhs[i]->poly->GPU[limb_id.x]; + auto &out_partition = out[i]->poly->GPU[limb_id.x]; + if (limb_id.y >= lhs_partition.limb.size() || limb_id.y >= rhs_partition.limb.size() || + limb_id.y >= out_partition.limb.size()) + { + return set_error("unexpected limb index in gpu_block_op"); + } + auto &lhs_limb = lhs_partition.limb[limb_id.y]; + auto &rhs_limb = rhs_partition.limb[limb_id.y]; + auto &out_limb = out_partition.limb[limb_id.y]; + if (lhs_limb.index() != rhs_limb.index() || lhs_limb.index() != out_limb.index()) + { + return set_error("mixed limb types in gpu_block_op"); + } + + if constexpr (std::is_same_v) + { + auto &lhs_u = std::get(lhs_limb); + auto &rhs_u = std::get(rhs_limb); + auto &out_u = std::get(out_limb); + if (!stream) + { + stream = lhs_u.stream.ptr; + } + lhs_ptrs.push_back(lhs_u.v.data); + rhs_ptrs.push_back(rhs_u.v.data); + out_ptrs.push_back(out_u.v.data); + } + else + { + auto &lhs_u = std::get(lhs_limb); + auto &rhs_u = std::get(rhs_limb); + auto &out_u = std::get(out_limb); + if (!stream) + { + stream = lhs_u.stream.ptr; + } + lhs_ptrs.push_back(lhs_u.v.data); + rhs_ptrs.push_back(rhs_u.v.data); + out_ptrs.push_back(out_u.v.data); + } + } + + const uint64_t modulus64 = lhs[0]->ctx->moduli[static_cast(limb)]; + const T modulus = static_cast(modulus64); + return launch_block_kernel(out_ptrs, lhs_ptrs, rhs_ptrs, n, modulus, op, stream); + } + + template + int launch_for_limb_matmul( + GpuPoly *const *out, + const GpuPoly *const *lhs, + const GpuPoly *const *rhs, + size_t rows, + size_t inner, + size_t cols, + size_t n, + int limb, + const dim3 &limb_id, + double *out_kernel_ms) + { + const size_t out_count = rows * cols; + const size_t lhs_count = rows * inner; + const size_t rhs_count = inner * cols; + + std::vector out_ptrs; + std::vector lhs_ptrs; + std::vector rhs_ptrs; + out_ptrs.reserve(out_count); + lhs_ptrs.reserve(lhs_count); + rhs_ptrs.reserve(rhs_count); + + cudaStream_t stream = nullptr; + int limb_index = -1; + + for (size_t i = 0; i < lhs_count; ++i) + { + auto &lhs_partition = lhs[i]->poly->GPU[limb_id.x]; + if (limb_id.y >= lhs_partition.limb.size()) + { + return set_error("unexpected limb index in gpu_block_mul"); + } + auto &lhs_limb = lhs_partition.limb[limb_id.y]; + if constexpr (std::is_same_v) + { + if (lhs_limb.index() != FIDESlib::U64) + { + return set_error("mixed limb types in gpu_block_mul"); + } + auto &lhs_u = std::get(lhs_limb); + if (!stream) + { + stream = lhs_u.stream.ptr; + } + lhs_ptrs.push_back(lhs_u.v.data); + } + else + { + if (lhs_limb.index() != FIDESlib::U32) + { + return set_error("mixed limb types in gpu_block_mul"); + } + auto &lhs_u = std::get(lhs_limb); + if (!stream) + { + stream = lhs_u.stream.ptr; + } + lhs_ptrs.push_back(lhs_u.v.data); + } + if (limb_index == -1) + { + limb_index = lhs_limb.index(); + } + } + + for (size_t i = 0; i < rhs_count; ++i) + { + auto &rhs_partition = rhs[i]->poly->GPU[limb_id.x]; + if (limb_id.y >= rhs_partition.limb.size()) + { + return set_error("unexpected limb index in gpu_block_mul"); + } + auto &rhs_limb = rhs_partition.limb[limb_id.y]; + if (rhs_limb.index() != limb_index) + { + return set_error("mixed limb types in gpu_block_mul"); + } + + if constexpr (std::is_same_v) + { + auto &rhs_u = std::get(rhs_limb); + rhs_ptrs.push_back(rhs_u.v.data); + } + else + { + auto &rhs_u = std::get(rhs_limb); + rhs_ptrs.push_back(rhs_u.v.data); + } + } + + for (size_t i = 0; i < out_count; ++i) + { + auto &out_partition = 
out[i]->poly->GPU[limb_id.x]; + if (limb_id.y >= out_partition.limb.size()) + { + return set_error("unexpected limb index in gpu_block_mul"); + } + auto &out_limb = out_partition.limb[limb_id.y]; + if (out_limb.index() != limb_index) + { + return set_error("mixed limb types in gpu_block_mul"); + } + + if constexpr (std::is_same_v) + { + auto &out_u = std::get(out_limb); + out_ptrs.push_back(out_u.v.data); + } + else + { + auto &out_u = std::get(out_limb); + out_ptrs.push_back(out_u.v.data); + } + } + + const uint64_t modulus64 = lhs[0]->ctx->moduli[static_cast(limb)]; + const T modulus = static_cast(modulus64); + return launch_block_matmul_kernel(out_ptrs, lhs_ptrs, rhs_ptrs, rows, inner, cols, n, modulus, stream, out_kernel_ms); + } + + int gpu_block_op(GpuPoly *const *out, const GpuPoly *const *lhs, const GpuPoly *const *rhs, size_t count, BlockOp op) + { + if (!out || !lhs || !rhs) + { + return set_error("invalid gpu_block_op arguments"); + } + if (count == 0) + { + return 0; + } + + const GpuPoly *lhs0 = lhs[0]; + if (!lhs0 || !lhs0->ctx) + { + return set_error("null context in gpu_block_op"); + } + const GpuContext *ctx = lhs0->ctx; + const int level = lhs0->level; + const PolyFormat format = lhs0->format; + if (op == BlockOp::Mul && format != PolyFormat::Eval) + { + return set_error("gpu_block_entrywise_mul requires Eval format"); + } + + // for (size_t i = 0; i < count; ++i) + // { + // if (!out[i] || !lhs[i] || !rhs[i]) + // { + // return set_error("null polynomial in gpu_block_op"); + // } + // if (lhs[i]->ctx != ctx || rhs[i]->ctx != ctx || out[i]->ctx != ctx) + // { + // return set_error("mismatched contexts in gpu_block_op"); + // } + // if (lhs[i]->level != level || rhs[i]->level != level || out[i]->level != level) + // { + // return set_error("mismatched levels in gpu_block_op"); + // } + // if (lhs[i]->format != format || rhs[i]->format != format) + // { + // return set_error("mismatched formats in gpu_block_op"); + // } + // } + + if (level < 0) + { + return set_error("invalid level in gpu_block_op"); + } + + const int N = ctx->N; + if (N <= 0) + { + return 0; + } + + auto &limb_map = ctx->ctx->limbGPUid; + if (limb_map.size() < static_cast(level + 1)) + { + return set_error("unexpected limb mapping size in gpu_block_op"); + } + + for (int limb = 0; limb <= level; ++limb) + { + const dim3 limb_id = limb_map[static_cast(limb)]; + if (limb_id.x >= lhs0->poly->GPU.size()) + { + return set_error("unexpected limb GPU partition in gpu_block_op"); + } + const auto &partition = lhs0->poly->GPU[limb_id.x]; + if (limb_id.y >= partition.limb.size()) + { + return set_error("unexpected limb index in gpu_block_op"); + } + + cudaError_t err = cudaSetDevice(partition.device); + if (err != cudaSuccess) + { + return set_error(err); + } + + const auto &limb_impl = partition.limb[limb_id.y]; + if (limb_impl.index() == FIDESlib::U64) + { + int status = launch_for_limb(out, lhs, rhs, count, static_cast(N), limb, limb_id, op); + if (status != 0) + { + return status; + } + } + else if (limb_impl.index() == FIDESlib::U32) + { + int status = launch_for_limb(out, lhs, rhs, count, static_cast(N), limb, limb_id, op); + if (status != 0) + { + return status; + } + } + else + { + return set_error("unsupported limb type in gpu_block_op"); + } + } + + for (size_t i = 0; i < count; ++i) + { + out[i]->level = level; + out[i]->format = format; + } + return 0; + } + + int gpu_block_matmul( + GpuPoly *const *out, + const GpuPoly *const *lhs, + const GpuPoly *const *rhs, + size_t rows, + size_t inner, + size_t 
cols, + double *out_kernel_ms = nullptr) + { + if (!out || !lhs || !rhs) + { + return set_error("invalid gpu_block_mul arguments"); + } + if (out_kernel_ms) + { + *out_kernel_ms = 0.0; + } + const size_t out_count = rows * cols; + if (rows == 0 || inner == 0 || cols == 0) + { + return 0; + } + if (out_count == 0) + { + return 0; + } + + const GpuPoly *lhs0 = lhs[0]; + if (!lhs0 || !lhs0->ctx) + { + return set_error("null context in gpu_block_mul"); + } + const GpuContext *ctx = lhs0->ctx; + const int level = lhs0->level; + const PolyFormat format = lhs0->format; + if (format != PolyFormat::Eval) + { + return set_error("gpu_block_mul requires Eval format"); + } + + if (level < 0) + { + return set_error("invalid level in gpu_block_mul"); + } + + const int N = ctx->N; + if (N <= 0) + { + return 0; + } + + auto &limb_map = ctx->ctx->limbGPUid; + if (limb_map.size() < static_cast(level + 1)) + { + return set_error("unexpected limb mapping size in gpu_block_mul"); + } + + for (int limb = 0; limb <= level; ++limb) + { + const dim3 limb_id = limb_map[static_cast(limb)]; + if (limb_id.x >= lhs0->poly->GPU.size()) + { + return set_error("unexpected limb GPU partition in gpu_block_mul"); + } + const auto &partition = lhs0->poly->GPU[limb_id.x]; + if (limb_id.y >= partition.limb.size()) + { + return set_error("unexpected limb index in gpu_block_mul"); + } + + cudaError_t err = cudaSetDevice(partition.device); + if (err != cudaSuccess) + { + return set_error(err); + } + + const auto &limb_impl = partition.limb[limb_id.y]; + if (limb_impl.index() == FIDESlib::U64) + { + int status = + launch_for_limb_matmul( + out, + lhs, + rhs, + rows, + inner, + cols, + static_cast(N), + limb, + limb_id, + out_kernel_ms); + if (status != 0) + { + return status; + } + } + else if (limb_impl.index() == FIDESlib::U32) + { + int status = + launch_for_limb_matmul( + out, + lhs, + rhs, + rows, + inner, + cols, + static_cast(N), + limb, + limb_id, + out_kernel_ms); + if (status != 0) + { + return status; + } + } + else + { + return set_error("unsupported limb type in gpu_block_mul"); + } + } + + for (size_t i = 0; i < out_count; ++i) + { + out[i]->level = level; + out[i]->format = format; + } + return 0; + } +} // namespace + +extern "C" int gpu_block_add(GpuPoly *const *out, const GpuPoly *const *lhs, const GpuPoly *const *rhs, size_t count) +{ + return gpu_block_op(out, lhs, rhs, count, BlockOp::Add); +} + +extern "C" int gpu_block_sub(GpuPoly *const *out, const GpuPoly *const *lhs, const GpuPoly *const *rhs, size_t count) +{ + return gpu_block_op(out, lhs, rhs, count, BlockOp::Sub); +} + +extern "C" int gpu_block_entrywise_mul( + GpuPoly *const *out, + const GpuPoly *const *lhs, + const GpuPoly *const *rhs, + size_t count) +{ + return gpu_block_op(out, lhs, rhs, count, BlockOp::Mul); +} + +extern "C" int gpu_block_mul( + GpuPoly *const *out, + const GpuPoly *const *lhs, + const GpuPoly *const *rhs, + size_t rows, + size_t inner, + size_t cols) +{ + return gpu_block_matmul(out, lhs, rhs, rows, inner, cols); +} + +extern "C" int gpu_block_mul_timed( + GpuPoly *const *out, + const GpuPoly *const *lhs, + const GpuPoly *const *rhs, + size_t rows, + size_t inner, + size_t cols, + double *out_kernel_ms) +{ + if (!out_kernel_ms) + { + return set_error("null out_kernel_ms in gpu_block_mul_timed"); + } + return gpu_block_matmul(out, lhs, rhs, rows, inner, cols, out_kernel_ms); +} + +extern "C" int gpu_matrix_create( + GpuContext *ctx, + int level, + size_t rows, + size_t cols, + int format, + GpuMatrix **out) +{ + if (!ctx || !out) 
+ { + return set_error("invalid gpu_matrix_create arguments"); + } + PolyFormat fmt; + if (!parse_format(format, fmt)) + { + return set_error("invalid format in gpu_matrix_create"); + } + + auto *mat = new GpuMatrix{ctx, rows, cols, level, fmt, {}}; + const size_t count = rows * cols; + mat->polys.reserve(count); + for (size_t i = 0; i < count; ++i) + { + GpuPoly *poly = nullptr; + int status = gpu_poly_create(ctx, level, &poly); + if (status != 0) + { + for (auto *p : mat->polys) + { + gpu_poly_destroy(p); + } + delete mat; + return status; + } + poly->format = fmt; + mat->polys.push_back(poly); + } + *out = mat; + return 0; +} + +extern "C" void gpu_matrix_destroy(GpuMatrix *mat) +{ + if (!mat) + { + return; + } + for (auto *poly : mat->polys) + { + gpu_poly_destroy(poly); + } + delete mat; +} + +extern "C" int gpu_matrix_copy(GpuMatrix *dst, const GpuMatrix *src) +{ + if (!dst || !src) + { + return set_error("invalid gpu_matrix_copy arguments"); + } + if (dst->rows != src->rows || dst->cols != src->cols) + { + return set_error("size mismatch in gpu_matrix_copy"); + } + if (dst->level != src->level || dst->ctx != src->ctx) + { + return set_error("context mismatch in gpu_matrix_copy"); + } + const size_t count = src->rows * src->cols; + for (size_t i = 0; i < count; ++i) + { + int status = gpu_poly_copy(dst->polys[i], src->polys[i]); + if (status != 0) + { + return status; + } + } + dst->format = src->format; + return 0; +} + +extern "C" int gpu_matrix_entry_clone( + const GpuMatrix *mat, + size_t row, + size_t col, + GpuPoly **out_poly) +{ + if (!mat || !out_poly) + { + return set_error("invalid gpu_matrix_entry_clone arguments"); + } + if (row >= mat->rows || col >= mat->cols) + { + return set_error("index out of bounds in gpu_matrix_entry_clone"); + } + const size_t idx = matrix_index(row, col, mat->cols); + return gpu_poly_clone(mat->polys[idx], out_poly); +} extern "C" int gpu_matrix_copy_entry( GpuMatrix *mat, @@ -1232,627 +2458,1864 @@ extern "C" int gpu_matrix_copy_entry( size_t col, const GpuPoly *src) { - if (!mat || !src) + if (!mat || !src) + { + return set_error("invalid gpu_matrix_copy_entry arguments"); + } + if (row >= mat->rows || col >= mat->cols) + { + return set_error("index out of bounds in gpu_matrix_copy_entry"); + } + if (src->ctx != mat->ctx || src->level != mat->level) + { + return set_error("context mismatch in gpu_matrix_copy_entry"); + } + const size_t idx = matrix_index(row, col, mat->cols); + return gpu_poly_copy(mat->polys[idx], src); +} + +extern "C" int gpu_matrix_load_rns_batch( + GpuMatrix *mat, + const uint8_t *bytes, + size_t bytes_per_poly, + int format) +{ + if (!mat) + { + return set_error("invalid gpu_matrix_load_rns_batch arguments"); + } + const size_t count = mat->rows * mat->cols; + int status = gpu_poly_load_rns_batch( + mat->polys.data(), + count, + bytes, + bytes_per_poly, + format); + if (status != 0) + { + return status; + } + PolyFormat fmt; + if (!parse_format(format, fmt)) + { + return set_error("invalid format in gpu_matrix_load_rns_batch"); + } + mat->format = fmt; + return 0; +} + +extern "C" int gpu_matrix_store_rns_batch( + const GpuMatrix *mat, + uint8_t *bytes_out, + size_t bytes_per_poly, + int format, + GpuEventSet **out_events) +{ + if (!mat) + { + return set_error("invalid gpu_matrix_store_rns_batch arguments"); + } + const size_t count = mat->rows * mat->cols; + return gpu_poly_store_rns_batch( + const_cast(mat->polys.data()), + count, + bytes_out, + bytes_per_poly, + format, + out_events); +} + +extern "C" int 
gpu_matrix_add(GpuMatrix *out, const GpuMatrix *lhs, const GpuMatrix *rhs) +{ + if (!out || !lhs || !rhs) + { + return set_error("invalid gpu_matrix_add arguments"); + } + if (lhs->rows != rhs->rows || lhs->cols != rhs->cols) + { + return set_error("size mismatch in gpu_matrix_add"); + } + if (out->rows != lhs->rows || out->cols != lhs->cols) + { + return set_error("output size mismatch in gpu_matrix_add"); + } + if (lhs->ctx != rhs->ctx || lhs->ctx != out->ctx || lhs->level != rhs->level || + lhs->level != out->level) + { + return set_error("context mismatch in gpu_matrix_add"); + } + const size_t count = lhs->rows * lhs->cols; + int status = gpu_block_add( + out->polys.data(), + const_cast(lhs->polys.data()), + const_cast(rhs->polys.data()), + count); + if (status != 0) + { + return status; + } + out->format = PolyFormat::Eval; + return 0; +} + +extern "C" int gpu_matrix_sub(GpuMatrix *out, const GpuMatrix *lhs, const GpuMatrix *rhs) +{ + if (!out || !lhs || !rhs) + { + return set_error("invalid gpu_matrix_sub arguments"); + } + if (lhs->rows != rhs->rows || lhs->cols != rhs->cols) + { + return set_error("size mismatch in gpu_matrix_sub"); + } + if (out->rows != lhs->rows || out->cols != lhs->cols) + { + return set_error("output size mismatch in gpu_matrix_sub"); + } + if (lhs->ctx != rhs->ctx || lhs->ctx != out->ctx || lhs->level != rhs->level || + lhs->level != out->level) + { + return set_error("context mismatch in gpu_matrix_sub"); + } + const size_t count = lhs->rows * lhs->cols; + int status = gpu_block_sub( + out->polys.data(), + const_cast(lhs->polys.data()), + const_cast(rhs->polys.data()), + count); + if (status != 0) + { + return status; + } + out->format = lhs->format; + return 0; +} + +extern "C" int gpu_matrix_mul(GpuMatrix *out, const GpuMatrix *lhs, const GpuMatrix *rhs) +{ + if (!out || !lhs || !rhs) + { + return set_error("invalid gpu_matrix_mul arguments"); + } + if (lhs->cols != rhs->rows) + { + return set_error("size mismatch in gpu_matrix_mul"); + } + if (out->rows != lhs->rows || out->cols != rhs->cols) + { + return set_error("output size mismatch in gpu_matrix_mul"); + } + if (lhs->ctx != rhs->ctx || lhs->ctx != out->ctx || lhs->level != rhs->level || + lhs->level != out->level) + { + return set_error("context mismatch in gpu_matrix_mul"); + } + int status = gpu_block_mul( + out->polys.data(), + const_cast(lhs->polys.data()), + const_cast(rhs->polys.data()), + lhs->rows, + lhs->cols, + rhs->cols); + if (status != 0) + { + return status; + } + out->format = PolyFormat::Eval; + return 0; +} + +extern "C" int gpu_matrix_equal(const GpuMatrix *lhs, const GpuMatrix *rhs, int *out_equal) +{ + if (!lhs || !rhs || !out_equal) + { + return set_error("invalid gpu_matrix_equal arguments"); + } + *out_equal = 0; + + if (lhs == rhs) + { + *out_equal = 1; + return 0; + } + if (lhs->rows != rhs->rows || lhs->cols != rhs->cols) + { + return 0; + } + if (lhs->ctx != rhs->ctx || lhs->level != rhs->level) + { + return 0; + } + + const size_t count = lhs->rows * lhs->cols; + for (size_t i = 0; i < count; ++i) + { + int poly_equal = 0; + int status = gpu_poly_equal(lhs->polys[i], rhs->polys[i], &poly_equal); + if (status != 0) + { + return status; + } + if (poly_equal == 0) + { + return 0; + } + } + + *out_equal = 1; + return 0; +} + +extern "C" int gpu_matrix_mul_timed( + GpuMatrix *out, + const GpuMatrix *lhs, + const GpuMatrix *rhs, + double *out_kernel_ms) +{ + if (!out_kernel_ms) + { + return set_error("null out_kernel_ms in gpu_matrix_mul_timed"); + } + if (!out || !lhs || !rhs) 
+ { + return set_error("invalid gpu_matrix_mul_timed arguments"); + } + if (lhs->cols != rhs->rows) + { + return set_error("size mismatch in gpu_matrix_mul_timed"); + } + if (out->rows != lhs->rows || out->cols != rhs->cols) + { + return set_error("output size mismatch in gpu_matrix_mul_timed"); + } + if (lhs->ctx != rhs->ctx || lhs->ctx != out->ctx || lhs->level != rhs->level || + lhs->level != out->level) + { + return set_error("context mismatch in gpu_matrix_mul_timed"); + } + int status = gpu_block_mul_timed( + out->polys.data(), + const_cast(lhs->polys.data()), + const_cast(rhs->polys.data()), + lhs->rows, + lhs->cols, + rhs->cols, + out_kernel_ms); + if (status != 0) + { + return status; + } + out->format = PolyFormat::Eval; + return 0; +} + +extern "C" int gpu_matrix_mul_scalar( + GpuMatrix *out, + const GpuMatrix *lhs, + const GpuPoly *scalar) +{ + if (!out || !lhs || !scalar) + { + return set_error("invalid gpu_matrix_mul_scalar arguments"); + } + if (out->rows != lhs->rows || out->cols != lhs->cols) + { + return set_error("output size mismatch in gpu_matrix_mul_scalar"); + } + if (lhs->ctx != out->ctx || lhs->level != out->level) + { + return set_error("context mismatch in gpu_matrix_mul_scalar"); + } + if (scalar->ctx != lhs->ctx || scalar->level != lhs->level) + { + return set_error("scalar context mismatch in gpu_matrix_mul_scalar"); + } + + const size_t count = lhs->rows * lhs->cols; + std::vector rhs(count, scalar); + int status = gpu_block_entrywise_mul( + out->polys.data(), + const_cast(lhs->polys.data()), + rhs.data(), + count); + if (status != 0) + { + return status; + } + out->format = lhs->format; + return 0; +} + +extern "C" int gpu_matrix_copy_block( + GpuMatrix *out, + const GpuMatrix *src, + size_t dst_row, + size_t dst_col, + size_t src_row, + size_t src_col, + size_t rows, + size_t cols) +{ + if (!out || !src) + { + return set_error("invalid gpu_matrix_copy_block arguments"); + } + if (src_row + rows > src->rows || src_col + cols > src->cols) + { + return set_error("source bounds exceeded in gpu_matrix_copy_block"); + } + if (dst_row + rows > out->rows || dst_col + cols > out->cols) + { + return set_error("dest bounds exceeded in gpu_matrix_copy_block"); + } + if (src->ctx != out->ctx || src->level != out->level) + { + return set_error("context mismatch in gpu_matrix_copy_block"); + } + + for (size_t i = 0; i < rows; ++i) + { + for (size_t j = 0; j < cols; ++j) + { + const size_t src_idx = matrix_index(src_row + i, src_col + j, src->cols); + const size_t dst_idx = matrix_index(dst_row + i, dst_col + j, out->cols); + int status = gpu_poly_copy(out->polys[dst_idx], src->polys[src_idx]); + if (status != 0) + { + return status; + } + } + } + out->format = src->format; + return 0; +} + +extern "C" int gpu_matrix_fill_gadget( + GpuMatrix *out, + uint32_t base_bits) +{ + if (!out) + { + return set_error("invalid gpu_matrix_fill_gadget arguments"); + } + if (base_bits == 0 || base_bits >= 63) + { + return set_error("invalid base_bits in gpu_matrix_fill_gadget"); + } + + const size_t rows = out->rows; + const size_t cols = out->cols; + const size_t count = rows * cols; + if (count == 0) + { + out->format = PolyFormat::Eval; + return 0; + } + + const int level = out->level; + if (level < 0) + { + return set_error("invalid level in gpu_matrix_fill_gadget"); + } + const size_t crt_depth = static_cast(level + 1); + if (out->ctx->moduli.size() < crt_depth) + { + return set_error("unexpected modulus count in gpu_matrix_fill_gadget"); + } + auto &limb_map = 
out->ctx->ctx->limbGPUid; + if (limb_map.size() < crt_depth) + { + return set_error("unexpected limb mapping size in gpu_matrix_fill_gadget"); + } + + uint32_t crt_bits = 0; + for (size_t i = 0; i < crt_depth; ++i) + { + crt_bits = std::max(crt_bits, bit_width_u64(out->ctx->moduli[i])); + } + if (crt_bits == 0) + { + return set_error("invalid crt_bits in gpu_matrix_fill_gadget"); + } + const uint32_t digits_per_tower = static_cast((crt_bits + base_bits - 1) / base_bits); + if (digits_per_tower == 0) + { + return set_error("invalid digits_per_tower in gpu_matrix_fill_gadget"); + } + const size_t log_base_q = static_cast(digits_per_tower) * crt_depth; + if (cols != rows * log_base_q) + { + return set_error("output size mismatch in gpu_matrix_fill_gadget"); + } + + for (int limb = 0; limb <= level; ++limb) + { + const dim3 limb_id = limb_map[static_cast(limb)]; + std::vector dst_ptrs; + dst_ptrs.reserve(count); + cudaStream_t out_stream = nullptr; + + for (size_t idx = 0; idx < count; ++idx) + { + GpuPoly *poly = out->polys[idx]; + if (!poly || poly->ctx != out->ctx || poly->level != level) + { + return set_error("invalid output poly in gpu_matrix_fill_gadget"); + } + if (limb_id.x >= poly->poly->GPU.size()) + { + return set_error("unexpected limb GPU partition in gpu_matrix_fill_gadget"); + } + auto &partition = poly->poly->GPU[limb_id.x]; + if (limb_id.y >= partition.limb.size()) + { + return set_error("unexpected limb index in gpu_matrix_fill_gadget"); + } + auto &limb_impl = partition.limb[limb_id.y]; + if (limb_impl.index() != FIDESlib::U64) + { + return set_error("unsupported limb type in gpu_matrix_fill_gadget"); + } + auto &limb_u64 = std::get(limb_impl); + if (!out_stream) + { + out_stream = limb_u64.stream.ptr; + } + dst_ptrs.push_back(limb_u64.v.data); + } + + int status = launch_fill_gadget_kernel( + dst_ptrs, + static_cast(out->ctx->N), + out->ctx->moduli[static_cast(limb)], + rows, + cols, + log_base_q, + digits_per_tower, + static_cast(limb), + base_bits, + out_stream); + if (status != 0) + { + return status; + } + } + + const int batch = default_batch(out->ctx); + for (auto *poly : out->polys) + { + poly->format = PolyFormat::Coeff; + int status = gpu_poly_ntt(poly, batch); + if (status != 0) + { + return status; + } + poly->format = PolyFormat::Eval; + } + out->format = PolyFormat::Eval; + return 0; +} + +extern "C" int gpu_matrix_decompose_base(const GpuMatrix *src, uint32_t base_bits, GpuMatrix *out) +{ + if (!src || !out) + { + return set_error("invalid gpu_matrix_decompose_base arguments"); + } + if (base_bits == 0) + { + return set_error("base_bits must be non-zero in gpu_matrix_decompose_base"); + } + if (src->ctx != out->ctx || src->level != out->level) + { + return set_error("context mismatch in gpu_matrix_decompose_base"); + } + + const size_t rows = src->rows; + const size_t cols = src->cols; + const size_t count = rows * cols; + const int level = src->level; + if (level < 0) + { + return set_error("invalid level in gpu_matrix_decompose_base"); + } + const size_t crt_depth = static_cast(level + 1); + uint32_t crt_bits = 0; + for (const auto &modulus : src->ctx->moduli) + { + crt_bits = std::max(crt_bits, bit_width_u64(modulus)); + } + if (crt_bits == 0) + { + return set_error("invalid crt_bits in gpu_matrix_decompose_base"); + } + const uint32_t digits_per_tower = + static_cast((crt_bits + base_bits - 1) / base_bits); + if (digits_per_tower == 0) + { + return set_error("invalid digits_per_tower in gpu_matrix_decompose_base"); + } + const size_t log_base_q = 
static_cast(digits_per_tower) * crt_depth; + if (out->rows != rows * log_base_q || out->cols != cols) + { + return set_error("output size mismatch in gpu_matrix_decompose_base"); + } + if (count == 0) + { + out->format = PolyFormat::Eval; + return 0; + } + + std::vector tmp_inputs; + std::vector inputs; + inputs.reserve(count); + const int batch = default_batch(src->ctx); + if (src->format == PolyFormat::Eval) + { + tmp_inputs.reserve(count); + for (size_t i = 0; i < count; ++i) + { + GpuPoly *clone = nullptr; + int status = gpu_poly_clone(src->polys[i], &clone); + if (status != 0) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return status; + } + status = gpu_poly_intt(clone, batch); + if (status != 0) + { + gpu_poly_destroy(clone); + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return status; + } + tmp_inputs.push_back(clone); + inputs.push_back(clone); + } + } + else + { + for (size_t i = 0; i < count; ++i) + { + inputs.push_back(src->polys[i]); + } + } + + // Ensure all pending INTT work has completed before reading source limbs + // from different streams in the sampling kernels. + for (int device : src->ctx->gpu_ids) + { + cudaError_t err = cudaSetDevice(device); + if (err != cudaSuccess) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error(err); + } + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error(err); + } + } + + auto &limb_map = src->ctx->ctx->limbGPUid; + if (limb_map.size() < crt_depth) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unexpected limb mapping size in gpu_matrix_decompose_base"); + } + + for (size_t idx = 0; idx < out->polys.size(); ++idx) + { + GpuPoly *poly = out->polys[idx]; + if (!poly || poly->ctx != src->ctx || poly->level != level) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("invalid output poly in gpu_matrix_decompose_base"); + } + + for (int limb = 0; limb <= level; ++limb) + { + const dim3 limb_id = limb_map[static_cast(limb)]; + if (limb_id.x >= poly->poly->GPU.size()) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unexpected limb GPU partition in gpu_matrix_decompose_base"); + } + auto &partition = poly->poly->GPU[limb_id.x]; + if (limb_id.y >= partition.limb.size()) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unexpected limb index in gpu_matrix_decompose_base"); + } + auto &limb_impl = partition.limb[limb_id.y]; + if (limb_impl.index() != FIDESlib::U64) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unsupported limb type in gpu_matrix_decompose_base"); + } + auto &limb_u64 = std::get(limb_impl); + + cudaError_t err = cudaSetDevice(partition.device); + if (err != cudaSuccess) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error(err); + } + err = cudaMemsetAsync( + limb_u64.v.data, + 0, + static_cast(src->ctx->N) * sizeof(uint64_t), + limb_u64.stream.ptr); + if (err != cudaSuccess) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error(err); + } + } + poly->format = PolyFormat::Coeff; + poly->level = level; + } + + if (src->ctx->moduli.size() < crt_depth) { - return set_error("invalid gpu_matrix_copy_entry arguments"); + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unexpected modulus count in gpu_matrix_decompose_base"); } 
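For orientation, the per-digit (shift, mask) pairs prepared in the loop below are what launch_decompose_kernel consumes. Note that digits_per_tower is derived from the widest CRT tower, so the trailing digits of a narrower tower simply get a zero mask. On a single 64-bit CRT residue the intended digit extraction reduces to the host-side sketch here; this is an illustrative reference only (split_residue_base_b is not part of this patch) and it leaves out the kernel's redistribution of each digit across the output limbs modulo out_modulus.

#include <algorithm>
#include <cstdint>
#include <vector>

// Illustrative host-side reference: split one CRT residue into
// base-2^base_bits digits using the same shift/mask formulas as the loop below.
static std::vector<uint64_t> split_residue_base_b(uint64_t coeff, uint32_t src_bits, uint32_t base_bits)
{
    const uint32_t digits = (src_bits + base_bits - 1) / base_bits; // matches digits_per_tower
    std::vector<uint64_t> out(digits, 0);
    for (uint32_t d = 0; d < digits; ++d)
    {
        const uint32_t shift = d * base_bits;
        const uint32_t digit_bits = std::min(base_bits, src_bits - shift);
        const uint64_t mask =
            digit_bits >= 64 ? ~uint64_t{0} : ((uint64_t{1} << digit_bits) - 1);
        out[d] = (coeff >> shift) & mask;
    }
    return out;
}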
- if (row >= mat->rows || col >= mat->cols) + + for (int src_limb = 0; src_limb <= level; ++src_limb) { - return set_error("index out of bounds in gpu_matrix_copy_entry"); + const dim3 src_limb_id = limb_map[static_cast(src_limb)]; + if (src_limb_id.x >= inputs[0]->poly->GPU.size()) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unexpected source limb GPU partition in gpu_matrix_decompose_base"); + } + const auto &src_partition0 = inputs[0]->poly->GPU[src_limb_id.x]; + if (src_limb_id.y >= src_partition0.limb.size()) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unexpected source limb index in gpu_matrix_decompose_base"); + } + const auto &src_limb_impl0 = src_partition0.limb[src_limb_id.y]; + if (src_limb_impl0.index() != FIDESlib::U64) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unsupported source limb type in gpu_matrix_decompose_base"); + } + + const uint32_t src_bits = bit_width_u64(src->ctx->moduli[static_cast(src_limb)]); + cudaError_t err = cudaSetDevice(src_partition0.device); + if (err != cudaSuccess) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error(err); + } + + for (uint32_t digit_idx = 0; digit_idx < digits_per_tower; ++digit_idx) + { + const uint32_t shift = digit_idx * base_bits; + uint64_t mask = 0; + if (shift < src_bits) + { + const uint32_t remaining = src_bits - shift; + const uint32_t digit_bits = std::min(base_bits, remaining); + mask = digit_bits >= 64 ? std::numeric_limits::max() + : ((uint64_t{1} << digit_bits) - 1); + } + + const size_t digit_offset = + static_cast(src_limb) * static_cast(digits_per_tower) + + static_cast(digit_idx); + std::vector src_ptrs; + src_ptrs.reserve(count); + for (size_t idx = 0; idx < count; ++idx) + { + const auto &in_partition = inputs[idx]->poly->GPU[src_limb_id.x]; + if (src_limb_id.y >= in_partition.limb.size()) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unexpected input source limb index in gpu_matrix_decompose_base"); + } + const auto &in_limb_impl = in_partition.limb[src_limb_id.y]; + if (in_limb_impl.index() != FIDESlib::U64) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unsupported input source limb type in gpu_matrix_decompose_base"); + } + const auto &in_limb_u64 = std::get(in_limb_impl); + src_ptrs.push_back(in_limb_u64.v.data); + } + + for (int out_limb = 0; out_limb <= level; ++out_limb) + { + const dim3 out_limb_id = limb_map[static_cast(out_limb)]; + std::vector dst_ptrs; + dst_ptrs.reserve(count); + cudaStream_t out_stream = nullptr; + + for (size_t idx = 0; idx < count; ++idx) + { + const auto &in_partition = inputs[idx]->poly->GPU[src_limb_id.x]; + const size_t row = idx / cols; + const size_t col = idx % cols; + const size_t out_row = row * log_base_q + digit_offset; + const size_t out_idx = matrix_index(out_row, col, out->cols); + auto &out_partition = out->polys[out_idx]->poly->GPU[out_limb_id.x]; + if (out_partition.device != in_partition.device) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("input/output limb device mismatch in gpu_matrix_decompose_base"); + } + if (out_limb_id.y >= out_partition.limb.size()) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unexpected output limb index in gpu_matrix_decompose_base"); + } + const auto &out_limb_impl = out_partition.limb[out_limb_id.y]; + if (out_limb_impl.index() != 
FIDESlib::U64) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unsupported output limb type in gpu_matrix_decompose_base"); + } + const auto &out_limb_u64 = std::get(out_limb_impl); + if (!out_stream) + { + out_stream = out_limb_u64.stream.ptr; + } + dst_ptrs.push_back(out_limb_u64.v.data); + } + + int status = launch_decompose_kernel( + src_ptrs, + dst_ptrs, + static_cast(src->ctx->N), + shift, + mask, + src->ctx->moduli[static_cast(out_limb)], + out_stream); + if (status != 0) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return status; + } + } + } } - if (src->ctx != mat->ctx || src->level != mat->level) + + for (auto *poly : out->polys) { - return set_error("context mismatch in gpu_matrix_copy_entry"); + int status = gpu_poly_ntt(poly, batch); + if (status != 0) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return status; + } + poly->format = PolyFormat::Eval; } - const size_t idx = matrix_index(row, col, mat->cols); - return gpu_poly_copy(mat->polys[idx], src); + out->format = PolyFormat::Eval; + + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return 0; } -extern "C" int gpu_matrix_load_rns_batch( - GpuMatrix *mat, - const uint8_t *bytes, - size_t bytes_per_poly, - int format) +extern "C" int gpu_matrix_gauss_samp_gq_arb_base( + const GpuMatrix *src, + uint32_t base_bits, + double c, + double dgg_stddev, + uint64_t seed, + GpuMatrix *out) { - if (!mat) + (void)dgg_stddev; + if (!src || !out) { - return set_error("invalid gpu_matrix_load_rns_batch arguments"); + return set_error("invalid gpu_matrix_gauss_samp_gq_arb_base arguments"); } - const size_t count = mat->rows * mat->cols; - int status = gpu_poly_load_rns_batch( - mat->polys.data(), - count, - bytes, - bytes_per_poly, - format); - if (status != 0) + if (base_bits == 0 || base_bits >= 63) { - return status; + return set_error("invalid base_bits in gpu_matrix_gauss_samp_gq_arb_base"); + } + if (!(c > 0.0)) + { + return set_error("c must be positive in gpu_matrix_gauss_samp_gq_arb_base"); + } + if (src->ctx != out->ctx || src->level != out->level) + { + return set_error("context mismatch in gpu_matrix_gauss_samp_gq_arb_base"); + } + + const size_t rows = src->rows; + const size_t cols = src->cols; + const size_t count = rows * cols; + const int level = src->level; + if (level < 0) + { + return set_error("invalid level in gpu_matrix_gauss_samp_gq_arb_base"); + } + const size_t crt_depth = static_cast(level + 1); + uint32_t crt_bits = 0; + for (const auto &modulus : src->ctx->moduli) + { + crt_bits = std::max(crt_bits, bit_width_u64(modulus)); + } + if (crt_bits == 0) + { + return set_error("invalid crt_bits in gpu_matrix_gauss_samp_gq_arb_base"); + } + const uint32_t digits_per_tower = static_cast((crt_bits + base_bits - 1) / base_bits); + if (digits_per_tower == 0 || digits_per_tower > kGaussMaxDigits) + { + return set_error("invalid digits_per_tower in gpu_matrix_gauss_samp_gq_arb_base"); + } + const size_t log_base_q = static_cast(digits_per_tower) * crt_depth; + if (out->rows != rows * log_base_q || out->cols != cols) + { + return set_error("output size mismatch in gpu_matrix_gauss_samp_gq_arb_base"); + } + if (count == 0) + { + out->format = PolyFormat::Eval; + return 0; + } + + std::vector tmp_inputs; + std::vector inputs; + inputs.reserve(count); + const int batch = default_batch(src->ctx); + if (src->format == PolyFormat::Eval) + { + tmp_inputs.reserve(count); + for (size_t i = 0; i < count; ++i) + { + int sync_status = 
sync_poly_partition_streams( + src->polys[i], + "failed to synchronize source partition stream in gpu_matrix_gauss_samp_gq_arb_base"); + if (sync_status != 0) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return sync_status; + } + sync_status = sync_poly_limb_streams( + src->polys[i], + "failed to synchronize source limb stream in gpu_matrix_gauss_samp_gq_arb_base"); + if (sync_status != 0) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return sync_status; + } + + GpuPoly *clone = nullptr; + int status = gpu_poly_clone(src->polys[i], &clone); + if (status != 0) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return status; + } + sync_status = sync_poly_partition_streams( + clone, + "failed to synchronize clone partition stream in gpu_matrix_gauss_samp_gq_arb_base"); + if (sync_status != 0) + { + gpu_poly_destroy(clone); + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return sync_status; + } + sync_status = sync_poly_limb_streams( + clone, + "failed to synchronize clone limb stream in gpu_matrix_gauss_samp_gq_arb_base"); + if (sync_status != 0) + { + gpu_poly_destroy(clone); + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return sync_status; + } + status = gpu_poly_intt(clone, batch); + if (status != 0) + { + gpu_poly_destroy(clone); + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return status; + } + tmp_inputs.push_back(clone); + inputs.push_back(clone); + } } - PolyFormat fmt; - if (!parse_format(format, fmt)) + else { - return set_error("invalid format in gpu_matrix_load_rns_batch"); + for (size_t i = 0; i < count; ++i) + { + inputs.push_back(src->polys[i]); + } } - mat->format = fmt; - return 0; -} -extern "C" int gpu_matrix_store_rns_batch( - const GpuMatrix *mat, - uint8_t *bytes_out, - size_t bytes_per_poly, - int format, - GpuEventSet **out_events) -{ - if (!mat) + auto &limb_map = src->ctx->ctx->limbGPUid; + if (limb_map.size() < crt_depth) { - return set_error("invalid gpu_matrix_store_rns_batch arguments"); + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unexpected limb mapping size in gpu_matrix_gauss_samp_gq_arb_base"); } - const size_t count = mat->rows * mat->cols; - return gpu_poly_store_rns_batch( - const_cast(mat->polys.data()), - count, - bytes_out, - bytes_per_poly, - format, - out_events); -} -extern "C" int gpu_matrix_add(GpuMatrix *out, const GpuMatrix *lhs, const GpuMatrix *rhs) -{ - if (!out || !lhs || !rhs) - { - return set_error("invalid gpu_matrix_add arguments"); - } - if (lhs->rows != rhs->rows || lhs->cols != rhs->cols) - { - return set_error("size mismatch in gpu_matrix_add"); - } - if (out->rows != lhs->rows || out->cols != lhs->cols) - { - return set_error("output size mismatch in gpu_matrix_add"); - } - if (lhs->ctx != rhs->ctx || lhs->ctx != out->ctx || lhs->level != rhs->level || - lhs->level != out->level) - { - return set_error("context mismatch in gpu_matrix_add"); - } - const size_t count = lhs->rows * lhs->cols; - int status = gpu_block_add( - out->polys.data(), - const_cast(lhs->polys.data()), - const_cast(rhs->polys.data()), - count); - if (status != 0) + for (size_t idx = 0; idx < out->polys.size(); ++idx) { - return status; - } - out->format = PolyFormat::Eval; - return 0; -} + GpuPoly *poly = out->polys[idx]; + if (!poly || poly->ctx != src->ctx || poly->level != level) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("invalid output poly in gpu_matrix_gauss_samp_gq_arb_base"); + } 
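The per-digit launches below repeatedly serialize the source and destination limb streams against the launch stream using short-lived, timing-disabled events. The inlined record/wait/destroy sequence is equivalent to a small helper along these lines (illustrative only; stream_wait_on is not part of this patch):

#include <cuda_runtime.h>

// Make `waiter` wait for all work currently enqueued on `producer`.
static cudaError_t stream_wait_on(cudaStream_t waiter, cudaStream_t producer)
{
    cudaEvent_t ev = nullptr;
    cudaError_t err = cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
    if (err != cudaSuccess)
    {
        return err;
    }
    err = cudaEventRecord(ev, producer);
    if (err == cudaSuccess)
    {
        err = cudaStreamWaitEvent(waiter, ev, 0);
    }
    const cudaError_t destroy_err = cudaEventDestroy(ev);
    return err != cudaSuccess ? err : destroy_err;
}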
-extern "C" int gpu_matrix_sub(GpuMatrix *out, const GpuMatrix *lhs, const GpuMatrix *rhs) -{ - if (!out || !lhs || !rhs) - { - return set_error("invalid gpu_matrix_sub arguments"); - } - if (lhs->rows != rhs->rows || lhs->cols != rhs->cols) - { - return set_error("size mismatch in gpu_matrix_sub"); - } - if (out->rows != lhs->rows || out->cols != lhs->cols) - { - return set_error("output size mismatch in gpu_matrix_sub"); + for (int limb = 0; limb <= level; ++limb) + { + const dim3 limb_id = limb_map[static_cast(limb)]; + if (limb_id.x >= poly->poly->GPU.size()) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unexpected limb GPU partition in gpu_matrix_gauss_samp_gq_arb_base"); + } + auto &partition = poly->poly->GPU[limb_id.x]; + if (limb_id.y >= partition.limb.size()) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unexpected limb index in gpu_matrix_gauss_samp_gq_arb_base"); + } + auto &limb_impl = partition.limb[limb_id.y]; + if (limb_impl.index() != FIDESlib::U64) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unsupported limb type in gpu_matrix_gauss_samp_gq_arb_base"); + } + auto &limb_u64 = std::get(limb_impl); + + cudaError_t err = cudaSetDevice(partition.device); + if (err != cudaSuccess) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error(err); + } + err = cudaMemsetAsync( + limb_u64.v.data, + 0, + static_cast(src->ctx->N) * sizeof(uint64_t), + limb_u64.stream.ptr); + if (err != cudaSuccess) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error(err); + } + } + poly->format = PolyFormat::Coeff; + poly->level = level; } - if (lhs->ctx != rhs->ctx || lhs->ctx != out->ctx || lhs->level != rhs->level || - lhs->level != out->level) + + if (src->ctx->moduli.size() < crt_depth) { - return set_error("context mismatch in gpu_matrix_sub"); + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unexpected modulus count in gpu_matrix_gauss_samp_gq_arb_base"); } - const size_t count = lhs->rows * lhs->cols; - int status = gpu_block_sub( - out->polys.data(), - const_cast(lhs->polys.data()), - const_cast(rhs->polys.data()), - count); - if (status != 0) + + for (int src_limb = 0; src_limb <= level; ++src_limb) { - return status; - } - out->format = lhs->format; - return 0; -} + const dim3 src_limb_id = limb_map[static_cast(src_limb)]; + if (src_limb_id.x >= inputs[0]->poly->GPU.size()) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unexpected source limb GPU partition in gpu_matrix_gauss_samp_gq_arb_base"); + } + const auto &src_partition0 = inputs[0]->poly->GPU[src_limb_id.x]; + if (src_limb_id.y >= src_partition0.limb.size()) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unexpected source limb index in gpu_matrix_gauss_samp_gq_arb_base"); + } + const auto &src_limb_impl0 = src_partition0.limb[src_limb_id.y]; + if (src_limb_impl0.index() != FIDESlib::U64) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unsupported source limb type in gpu_matrix_gauss_samp_gq_arb_base"); + } + + for (uint32_t digit_idx = 0; digit_idx < digits_per_tower; ++digit_idx) + { + const size_t digit_offset = + static_cast(src_limb) * static_cast(digits_per_tower) + + static_cast(digit_idx); + std::vector src_ptrs; + src_ptrs.reserve(count); + std::vector src_streams; + src_streams.reserve(count); + for (size_t idx = 0; 
idx < count; ++idx) + { + const auto &in_partition = inputs[idx]->poly->GPU[src_limb_id.x]; + if (src_limb_id.y >= in_partition.limb.size()) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unexpected input source limb index in gpu_matrix_gauss_samp_gq_arb_base"); + } + const auto &in_limb_impl = in_partition.limb[src_limb_id.y]; + if (in_limb_impl.index() != FIDESlib::U64) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unsupported input source limb type in gpu_matrix_gauss_samp_gq_arb_base"); + } + const auto &in_limb_u64 = std::get(in_limb_impl); + cudaStream_t src_stream = in_limb_u64.stream.ptr; + bool seen_src_stream = false; + for (cudaStream_t s : src_streams) + { + if (s == src_stream) + { + seen_src_stream = true; + break; + } + } + if (!seen_src_stream) + { + src_streams.push_back(src_stream); + } + cudaStream_t src_partition_stream = in_partition.s.ptr; + seen_src_stream = false; + for (cudaStream_t s : src_streams) + { + if (s == src_partition_stream) + { + seen_src_stream = true; + break; + } + } + if (!seen_src_stream) + { + src_streams.push_back(src_partition_stream); + } + src_ptrs.push_back(in_limb_u64.v.data); + } + + for (int out_limb = 0; out_limb <= level; ++out_limb) + { + const dim3 out_limb_id = limb_map[static_cast(out_limb)]; + std::vector dst_ptrs; + dst_ptrs.reserve(count); + cudaStream_t out_stream = nullptr; + std::vector dst_streams; + dst_streams.reserve(count); + + for (size_t idx = 0; idx < count; ++idx) + { + const auto &in_partition = inputs[idx]->poly->GPU[src_limb_id.x]; + const size_t row = idx / cols; + const size_t col = idx % cols; + const size_t out_row = row * log_base_q + digit_offset; + const size_t out_idx = matrix_index(out_row, col, out->cols); + auto &out_partition = out->polys[out_idx]->poly->GPU[out_limb_id.x]; + if (out_partition.device != in_partition.device) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("input/output limb device mismatch in gpu_matrix_gauss_samp_gq_arb_base"); + } + if (out_limb_id.y >= out_partition.limb.size()) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unexpected output limb index in gpu_matrix_gauss_samp_gq_arb_base"); + } + const auto &out_limb_impl = out_partition.limb[out_limb_id.y]; + if (out_limb_impl.index() != FIDESlib::U64) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error("unsupported output limb type in gpu_matrix_gauss_samp_gq_arb_base"); + } + const auto &out_limb_u64 = std::get(out_limb_impl); + if (!out_stream) + { + out_stream = out_limb_u64.stream.ptr; + } + cudaStream_t dst_stream = out_limb_u64.stream.ptr; + bool seen_stream = false; + for (cudaStream_t s : dst_streams) + { + if (s == dst_stream) + { + seen_stream = true; + break; + } + } + if (!seen_stream) + { + dst_streams.push_back(dst_stream); + } + cudaStream_t dst_partition_stream = out_partition.s.ptr; + seen_stream = false; + for (cudaStream_t s : dst_streams) + { + if (s == dst_partition_stream) + { + seen_stream = true; + break; + } + } + if (!seen_stream) + { + dst_streams.push_back(dst_partition_stream); + } + dst_ptrs.push_back(out_limb_u64.v.data); + } + + for (cudaStream_t dst_stream : dst_streams) + { + if (dst_stream == out_stream) + { + continue; + } + cudaEvent_t ready = nullptr; + cudaError_t err = cudaEventCreateWithFlags(&ready, cudaEventDisableTiming); + if (err != cudaSuccess) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + 
return set_error(err); + } + err = cudaEventRecord(ready, dst_stream); + if (err != cudaSuccess) + { + cudaEventDestroy(ready); + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error(err); + } + err = cudaStreamWaitEvent(out_stream, ready, 0); + cudaError_t destroy_err = cudaEventDestroy(ready); + if (err != cudaSuccess) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error(err); + } + if (destroy_err != cudaSuccess) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error(destroy_err); + } + } + for (cudaStream_t src_stream : src_streams) + { + if (src_stream == out_stream) + { + continue; + } + cudaEvent_t ready = nullptr; + cudaError_t err = cudaEventCreateWithFlags(&ready, cudaEventDisableTiming); + if (err != cudaSuccess) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error(err); + } + err = cudaEventRecord(ready, src_stream); + if (err != cudaSuccess) + { + cudaEventDestroy(ready); + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error(err); + } + err = cudaStreamWaitEvent(out_stream, ready, 0); + cudaError_t destroy_err = cudaEventDestroy(ready); + if (err != cudaSuccess) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error(err); + } + if (destroy_err != cudaSuccess) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error(destroy_err); + } + } -extern "C" int gpu_matrix_mul(GpuMatrix *out, const GpuMatrix *lhs, const GpuMatrix *rhs) -{ - if (!out || !lhs || !rhs) - { - return set_error("invalid gpu_matrix_mul arguments"); - } - if (lhs->cols != rhs->rows) - { - return set_error("size mismatch in gpu_matrix_mul"); - } - if (out->rows != lhs->rows || out->cols != rhs->cols) - { - return set_error("output size mismatch in gpu_matrix_mul"); - } - if (lhs->ctx != rhs->ctx || lhs->ctx != out->ctx || lhs->level != rhs->level || - lhs->level != out->level) - { - return set_error("context mismatch in gpu_matrix_mul"); - } - int status = gpu_block_mul( - out->polys.data(), - const_cast(lhs->polys.data()), - const_cast(rhs->polys.data()), - lhs->rows, - lhs->cols, - rhs->cols); - if (status != 0) - { - return status; + int status = launch_gauss_samp_gq_arb_base_kernel( + src_ptrs, + dst_ptrs, + static_cast(src->ctx->N), + src->ctx->moduli[static_cast(src_limb)], + base_bits, + digits_per_tower, + digit_idx, + c, + static_cast(src_limb), + seed, + src->ctx->moduli[static_cast(out_limb)], + out_stream); + if (status != 0) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return status; + } + cudaEvent_t done = nullptr; + cudaError_t err = cudaEventCreateWithFlags(&done, cudaEventDisableTiming); + if (err != cudaSuccess) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error(err); + } + err = cudaEventRecord(done, out_stream); + if (err != cudaSuccess) + { + cudaEventDestroy(done); + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error(err); + } + for (cudaStream_t dst_stream : dst_streams) + { + if (dst_stream == out_stream) + { + continue; + } + err = cudaStreamWaitEvent(dst_stream, done, 0); + if (err != cudaSuccess) + { + cudaEventDestroy(done); + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error(err); + } + } + cudaError_t destroy_err = cudaEventDestroy(done); + if (destroy_err != cudaSuccess) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return set_error(destroy_err); + } + } + } } - 
out->format = PolyFormat::Eval; - return 0; -} -extern "C" int gpu_matrix_mul_timed( - GpuMatrix *out, - const GpuMatrix *lhs, - const GpuMatrix *rhs, - double *out_kernel_ms) -{ - if (!out_kernel_ms) - { - return set_error("null out_kernel_ms in gpu_matrix_mul_timed"); - } - if (!out || !lhs || !rhs) - { - return set_error("invalid gpu_matrix_mul_timed arguments"); - } - if (lhs->cols != rhs->rows) - { - return set_error("size mismatch in gpu_matrix_mul_timed"); - } - if (out->rows != lhs->rows || out->cols != rhs->cols) - { - return set_error("output size mismatch in gpu_matrix_mul_timed"); - } - if (lhs->ctx != rhs->ctx || lhs->ctx != out->ctx || lhs->level != rhs->level || - lhs->level != out->level) - { - return set_error("context mismatch in gpu_matrix_mul_timed"); - } - int status = gpu_block_mul_timed( - out->polys.data(), - const_cast(lhs->polys.data()), - const_cast(rhs->polys.data()), - lhs->rows, - lhs->cols, - rhs->cols, - out_kernel_ms); - if (status != 0) + for (auto *poly : out->polys) { - return status; + int status = gpu_poly_ntt(poly, batch); + if (status != 0) + { + for (auto *p : tmp_inputs) + { + gpu_poly_destroy(p); + } + return status; + } + poly->format = PolyFormat::Eval; } out->format = PolyFormat::Eval; - return 0; -} - -extern "C" int gpu_matrix_mul_scalar( - GpuMatrix *out, - const GpuMatrix *lhs, - const GpuPoly *scalar) -{ - if (!out || !lhs || !scalar) - { - return set_error("invalid gpu_matrix_mul_scalar arguments"); - } - if (out->rows != lhs->rows || out->cols != lhs->cols) - { - return set_error("output size mismatch in gpu_matrix_mul_scalar"); - } - if (lhs->ctx != out->ctx || lhs->level != out->level) - { - return set_error("context mismatch in gpu_matrix_mul_scalar"); - } - if (scalar->ctx != lhs->ctx || scalar->level != lhs->level) - { - return set_error("scalar context mismatch in gpu_matrix_mul_scalar"); - } - const size_t count = lhs->rows * lhs->cols; - std::vector rhs(count, scalar); - int status = gpu_block_entrywise_mul( - out->polys.data(), - const_cast(lhs->polys.data()), - rhs.data(), - count); - if (status != 0) + for (auto *p : tmp_inputs) { - return status; + gpu_poly_destroy(p); } - out->format = lhs->format; return 0; } -extern "C" int gpu_matrix_copy_block( - GpuMatrix *out, - const GpuMatrix *src, - size_t dst_row, - size_t dst_col, - size_t src_row, - size_t src_col, - size_t rows, - size_t cols) +extern "C" int gpu_matrix_sample_p1_full( + const GpuMatrix *a_mat, + const GpuMatrix *b_mat, + const GpuMatrix *d_mat, + const GpuMatrix *tp2, + double sigma, + double s, + double dgg_stddev, + uint64_t seed, + GpuMatrix *out) { - if (!out || !src) + if (!a_mat || !b_mat || !d_mat || !tp2 || !out) { - return set_error("invalid gpu_matrix_copy_block arguments"); + return set_error("invalid gpu_matrix_sample_p1_full arguments"); } - if (src_row + rows > src->rows || src_col + cols > src->cols) + if (!(sigma > 0.0) || !(s > sigma)) { - return set_error("source bounds exceeded in gpu_matrix_copy_block"); + return set_error("invalid sigma/s in gpu_matrix_sample_p1_full"); } - if (dst_row + rows > out->rows || dst_col + cols > out->cols) + if (!(dgg_stddev > 0.0)) { - return set_error("dest bounds exceeded in gpu_matrix_copy_block"); + return set_error("dgg_stddev must be positive in gpu_matrix_sample_p1_full"); } - if (src->ctx != out->ctx || src->level != out->level) + if (a_mat->ctx != b_mat->ctx || a_mat->ctx != d_mat->ctx || a_mat->ctx != tp2->ctx || a_mat->ctx != out->ctx) { - return set_error("context mismatch in 
gpu_matrix_copy_block"); + return set_error("context mismatch in gpu_matrix_sample_p1_full"); } - - for (size_t i = 0; i < rows; ++i) + if (a_mat->level != b_mat->level || a_mat->level != d_mat->level || + a_mat->level != tp2->level || a_mat->level != out->level) { - for (size_t j = 0; j < cols; ++j) - { - const size_t src_idx = matrix_index(src_row + i, src_col + j, src->cols); - const size_t dst_idx = matrix_index(dst_row + i, dst_col + j, out->cols); - int status = gpu_poly_copy(out->polys[dst_idx], src->polys[src_idx]); - if (status != 0) - { - return status; - } - } + return set_error("level mismatch in gpu_matrix_sample_p1_full"); } - out->format = src->format; - return 0; -} -extern "C" int gpu_matrix_decompose_base(const GpuMatrix *src, uint32_t base_bits, GpuMatrix *out) -{ - if (!src || !out) + const size_t d_rows = a_mat->rows; + if (a_mat->cols != d_rows || b_mat->rows != d_rows || b_mat->cols != d_rows || + d_mat->rows != d_rows || d_mat->cols != d_rows) { - return set_error("invalid gpu_matrix_decompose_base arguments"); + return set_error("A/B/D must be dxd in gpu_matrix_sample_p1_full"); } - if (base_bits == 0) + const size_t cols = tp2->cols; + if (tp2->rows != 2 * d_rows || out->rows != 2 * d_rows || out->cols != cols) { - return set_error("base_bits must be non-zero in gpu_matrix_decompose_base"); + return set_error("tp2/out shape mismatch in gpu_matrix_sample_p1_full"); } - if (src->ctx != out->ctx || src->level != out->level) + if (cols == 0 || d_rows == 0) { - return set_error("context mismatch in gpu_matrix_decompose_base"); + out->format = PolyFormat::Eval; + return 0; } - const size_t rows = src->rows; - const size_t cols = src->cols; - const size_t count = rows * cols; - const int level = src->level; + const int level = a_mat->level; if (level < 0) { - return set_error("invalid level in gpu_matrix_decompose_base"); + return set_error("invalid level in gpu_matrix_sample_p1_full"); } const size_t crt_depth = static_cast(level + 1); - uint32_t crt_bits = 0; - for (const auto &modulus : src->ctx->moduli) - { - crt_bits = std::max(crt_bits, bit_width_u64(modulus)); - } - if (crt_bits == 0) - { - return set_error("invalid crt_bits in gpu_matrix_decompose_base"); - } - const uint32_t digits_per_tower = - static_cast((crt_bits + base_bits - 1) / base_bits); - if (digits_per_tower == 0) - { - return set_error("invalid digits_per_tower in gpu_matrix_decompose_base"); - } - const size_t log_base_q = static_cast(digits_per_tower) * crt_depth; - if (out->rows != rows * log_base_q || out->cols != cols) + if (a_mat->ctx->moduli.size() < crt_depth) { - return set_error("output size mismatch in gpu_matrix_decompose_base"); + return set_error("unexpected modulus count in gpu_matrix_sample_p1_full"); } - if (count == 0) + auto &limb_map = a_mat->ctx->ctx->limbGPUid; + if (limb_map.size() < crt_depth) { - out->format = PolyFormat::Eval; - return 0; + return set_error("unexpected limb mapping size in gpu_matrix_sample_p1_full"); } - std::vector tmp_inputs; - std::vector inputs; - inputs.reserve(count); - const int batch = default_batch(src->ctx); - if (src->format == PolyFormat::Eval) - { - tmp_inputs.reserve(count); - for (size_t i = 0; i < count; ++i) + std::vector tmp_a; + std::vector tmp_b; + std::vector tmp_d; + std::vector tmp_tp2; + std::vector a_inputs; + std::vector b_inputs; + std::vector d_inputs; + std::vector tp2_inputs; + + auto cleanup = [&]() { + for (auto *p : tmp_a) { - GpuPoly *clone = nullptr; - int status = gpu_poly_clone(src->polys[i], &clone); - if (status != 0) 
+ gpu_poly_destroy(p); + } + for (auto *p : tmp_b) + { + gpu_poly_destroy(p); + } + for (auto *p : tmp_d) + { + gpu_poly_destroy(p); + } + for (auto *p : tmp_tp2) + { + gpu_poly_destroy(p); + } + }; + + auto collect_coeff_inputs = [&](const GpuMatrix *src, std::vector &owned, std::vector &inputs) -> int { + const size_t count = src->rows * src->cols; + inputs.clear(); + inputs.reserve(count); + const int batch = default_batch(src->ctx); + if (src->format == PolyFormat::Eval) + { + owned.reserve(count); + for (size_t i = 0; i < count; ++i) { - for (auto *p : tmp_inputs) + GpuPoly *clone = nullptr; + int status = gpu_poly_clone(src->polys[i], &clone); + if (status != 0) { - gpu_poly_destroy(p); + return status; } - return status; - } - status = gpu_poly_intt(clone, batch); - if (status != 0) - { - gpu_poly_destroy(clone); - for (auto *p : tmp_inputs) + status = gpu_poly_intt(clone, batch); + if (status != 0) { - gpu_poly_destroy(p); + gpu_poly_destroy(clone); + return status; } - return status; + owned.push_back(clone); + inputs.push_back(clone); } - tmp_inputs.push_back(clone); - inputs.push_back(clone); } - } - else - { - for (size_t i = 0; i < count; ++i) + else { - inputs.push_back(src->polys[i]); + for (size_t i = 0; i < count; ++i) + { + inputs.push_back(src->polys[i]); + } } + return 0; + }; + + int status = collect_coeff_inputs(a_mat, tmp_a, a_inputs); + if (status != 0) + { + cleanup(); + return status; + } + status = collect_coeff_inputs(b_mat, tmp_b, b_inputs); + if (status != 0) + { + cleanup(); + return status; + } + status = collect_coeff_inputs(d_mat, tmp_d, d_inputs); + if (status != 0) + { + cleanup(); + return status; + } + status = collect_coeff_inputs(tp2, tmp_tp2, tp2_inputs); + if (status != 0) + { + cleanup(); + return status; } - auto &limb_map = src->ctx->ctx->limbGPUid; - if (limb_map.size() < crt_depth) + // Ensure all pending INTT work has completed before cross-stream reads. 
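+ // collect_coeff_inputs may have queued INTTs on each input's own limb
+ // streams, while launch_sample_p1_full_kernel below reads those buffers
+ // from the output limb stream, so a device-wide sync is the simplest way
+ // to order the two before the kernel launches.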
+ for (int device : a_mat->ctx->gpu_ids) { - for (auto *p : tmp_inputs) + cudaError_t err = cudaSetDevice(device); + if (err != cudaSuccess) { - gpu_poly_destroy(p); + cleanup(); + return set_error(err); + } + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) + { + cleanup(); + return set_error(err); } - return set_error("unexpected limb mapping size in gpu_matrix_decompose_base"); } for (size_t idx = 0; idx < out->polys.size(); ++idx) { GpuPoly *poly = out->polys[idx]; - if (!poly || poly->ctx != src->ctx || poly->level != level) + if (!poly || poly->ctx != a_mat->ctx || poly->level != level) { - for (auto *p : tmp_inputs) - { - gpu_poly_destroy(p); - } - return set_error("invalid output poly in gpu_matrix_decompose_base"); + cleanup(); + return set_error("invalid output poly in gpu_matrix_sample_p1_full"); } + poly->format = PolyFormat::Coeff; + poly->level = level; + } + out->format = PolyFormat::Coeff; - for (int limb = 0; limb <= level; ++limb) + for (int limb = 0; limb <= level; ++limb) + { + const dim3 limb_id = limb_map[static_cast(limb)]; + + std::vector a_entry_ptrs; + std::vector b_entry_ptrs; + std::vector d_entry_ptrs; + std::vector tp2_entry_ptrs; + std::vector out_entry_ptrs; + a_entry_ptrs.reserve(d_rows * d_rows); + b_entry_ptrs.reserve(d_rows * d_rows); + d_entry_ptrs.reserve(d_rows * d_rows); + tp2_entry_ptrs.reserve(2 * d_rows * cols); + out_entry_ptrs.reserve(2 * d_rows * cols); + + cudaStream_t out_stream = nullptr; + int out_device = -1; + for (size_t i = 0; i < d_rows; ++i) { - const dim3 limb_id = limb_map[static_cast(limb)]; - if (limb_id.x >= poly->poly->GPU.size()) + for (size_t j = 0; j < d_rows; ++j) { - for (auto *p : tmp_inputs) + const size_t idx = matrix_index(i, j, d_rows); + if (limb_id.x >= a_inputs[idx]->poly->GPU.size() || + limb_id.x >= b_inputs[idx]->poly->GPU.size() || + limb_id.x >= d_inputs[idx]->poly->GPU.size()) { - gpu_poly_destroy(p); + cleanup(); + return set_error("unexpected A/B/D limb GPU partition in gpu_matrix_sample_p1_full"); } - return set_error("unexpected limb GPU partition in gpu_matrix_decompose_base"); - } - auto &partition = poly->poly->GPU[limb_id.x]; - if (limb_id.y >= partition.limb.size()) - { - for (auto *p : tmp_inputs) + const auto &a_part = a_inputs[idx]->poly->GPU[limb_id.x]; + const auto &b_part = b_inputs[idx]->poly->GPU[limb_id.x]; + const auto &d_part = d_inputs[idx]->poly->GPU[limb_id.x]; + if (limb_id.y >= a_part.limb.size() || + limb_id.y >= b_part.limb.size() || + limb_id.y >= d_part.limb.size()) { - gpu_poly_destroy(p); + cleanup(); + return set_error("unexpected A/B/D limb index in gpu_matrix_sample_p1_full"); } - return set_error("unexpected limb index in gpu_matrix_decompose_base"); - } - auto &limb_impl = partition.limb[limb_id.y]; - if (limb_impl.index() != FIDESlib::U64) - { - for (auto *p : tmp_inputs) + const auto &a_impl = a_part.limb[limb_id.y]; + const auto &b_impl = b_part.limb[limb_id.y]; + const auto &d_impl = d_part.limb[limb_id.y]; + if (a_impl.index() != FIDESlib::U64 || + b_impl.index() != FIDESlib::U64 || + d_impl.index() != FIDESlib::U64) { - gpu_poly_destroy(p); + cleanup(); + return set_error("unsupported A/B/D limb type in gpu_matrix_sample_p1_full"); } - return set_error("unsupported limb type in gpu_matrix_decompose_base"); + a_entry_ptrs.push_back(std::get(a_impl).v.data); + b_entry_ptrs.push_back(std::get(b_impl).v.data); + d_entry_ptrs.push_back(std::get(d_impl).v.data); } - auto &limb_u64 = std::get(limb_impl); - - cudaError_t err = cudaSetDevice(partition.device); - if (err 
!= cudaSuccess) + } + for (size_t row = 0; row < 2 * d_rows; ++row) + { + for (size_t col = 0; col < cols; ++col) { - for (auto *p : tmp_inputs) + const size_t idx = matrix_index(row, col, cols); + if (limb_id.x >= tp2_inputs[idx]->poly->GPU.size() || + limb_id.x >= out->polys[idx]->poly->GPU.size()) { - gpu_poly_destroy(p); + cleanup(); + return set_error("unexpected tp2/output limb GPU partition in gpu_matrix_sample_p1_full"); } - return set_error(err); - } - err = cudaMemsetAsync( - limb_u64.v.data, - 0, - static_cast(src->ctx->N) * sizeof(uint64_t), - limb_u64.stream.ptr); - if (err != cudaSuccess) - { - for (auto *p : tmp_inputs) + const auto &tp2_part = tp2_inputs[idx]->poly->GPU[limb_id.x]; + auto &out_part = out->polys[idx]->poly->GPU[limb_id.x]; + if (tp2_part.device != out_part.device) { - gpu_poly_destroy(p); + cleanup(); + return set_error("input/output limb device mismatch in gpu_matrix_sample_p1_full"); } - return set_error(err); + if (out_device < 0) + { + out_device = out_part.device; + } + else if (out_device != out_part.device) + { + cleanup(); + return set_error("mixed output devices in gpu_matrix_sample_p1_full"); + } + if (limb_id.y >= tp2_part.limb.size() || limb_id.y >= out_part.limb.size()) + { + cleanup(); + return set_error("unexpected tp2/output limb index in gpu_matrix_sample_p1_full"); + } + const auto &tp2_impl = tp2_part.limb[limb_id.y]; + auto &out_impl = out_part.limb[limb_id.y]; + if (tp2_impl.index() != FIDESlib::U64 || out_impl.index() != FIDESlib::U64) + { + cleanup(); + return set_error("unsupported tp2/output limb type in gpu_matrix_sample_p1_full"); + } + const auto &tp2_u64 = std::get(tp2_impl); + auto &out_u64 = std::get(out_impl); + if (!out_stream) + { + out_stream = out_u64.stream.ptr; + } + tp2_entry_ptrs.push_back(tp2_u64.v.data); + out_entry_ptrs.push_back(out_u64.v.data); } } - poly->format = PolyFormat::Coeff; - poly->level = level; + + status = launch_sample_p1_full_kernel( + a_entry_ptrs, + b_entry_ptrs, + d_entry_ptrs, + tp2_entry_ptrs, + out_entry_ptrs, + d_rows, + cols, + static_cast(a_mat->ctx->N), + a_mat->ctx->moduli[static_cast(limb)], + sigma, + s, + dgg_stddev, + static_cast(limb), + seed, + out_stream, + out_device); + if (status != 0) + { + cleanup(); + return status; + } + } + + const int batch = default_batch(out->ctx); + for (auto *poly : out->polys) + { + status = gpu_poly_ntt(poly, batch); + if (status != 0) + { + cleanup(); + return status; + } + poly->format = PolyFormat::Eval; + } + out->format = PolyFormat::Eval; + + cleanup(); + return 0; +} + +extern "C" int gpu_matrix_sample_distribution( + GpuMatrix *out, + int dist_type, + double sigma, + uint64_t seed) +{ + if (!out) + { + return set_error("invalid gpu_matrix_sample_distribution arguments"); + } + if (dist_type < GPU_MATRIX_DIST_UNIFORM || dist_type > GPU_MATRIX_DIST_TERNARY) + { + return set_error("invalid dist_type in gpu_matrix_sample_distribution"); + } + if (dist_type == GPU_MATRIX_DIST_GAUSS && !(sigma > 0.0)) + { + return set_error("sigma must be positive in gpu_matrix_sample_distribution"); + } + + const size_t count = out->rows * out->cols; + if (count == 0) + { + out->format = PolyFormat::Eval; + return 0; } - const uint64_t base_mask = - base_bits >= 64 ? std::numeric_limits::max() : ((1ULL << base_bits) - 1); - const uint32_t last_bits = - static_cast(crt_bits - base_bits * (digits_per_tower - 1)); - const uint64_t last_mask = - last_bits >= 64 ? 
std::numeric_limits::max() : ((1ULL << last_bits) - 1); + const int level = out->level; + if (level < 0) + { + return set_error("invalid level in gpu_matrix_sample_distribution"); + } + if (out->ctx->moduli.size() < static_cast(level + 1)) + { + return set_error("unexpected modulus count in gpu_matrix_sample_distribution"); + } + + auto &limb_map = out->ctx->ctx->limbGPUid; + if (limb_map.size() < static_cast(level + 1)) + { + return set_error("unexpected limb mapping size in gpu_matrix_sample_distribution"); + } for (int limb = 0; limb <= level; ++limb) { const dim3 limb_id = limb_map[static_cast(limb)]; - if (limb_id.x >= inputs[0]->poly->GPU.size()) + std::vector dst_ptrs; + dst_ptrs.reserve(count); + cudaStream_t out_stream = nullptr; + + for (size_t idx = 0; idx < count; ++idx) { - for (auto *p : tmp_inputs) + GpuPoly *poly = out->polys[idx]; + if (!poly || poly->ctx != out->ctx || poly->level != level) { - gpu_poly_destroy(p); + return set_error("invalid output poly in gpu_matrix_sample_distribution"); } - return set_error("unexpected limb GPU partition in gpu_matrix_decompose_base"); - } - const auto &partition = inputs[0]->poly->GPU[limb_id.x]; - if (limb_id.y >= partition.limb.size()) - { - for (auto *p : tmp_inputs) + if (limb_id.x >= poly->poly->GPU.size()) { - gpu_poly_destroy(p); + return set_error("unexpected limb GPU partition in gpu_matrix_sample_distribution"); } - return set_error("unexpected limb index in gpu_matrix_decompose_base"); - } - const auto &limb_impl = partition.limb[limb_id.y]; - if (limb_impl.index() != FIDESlib::U64) - { - for (auto *p : tmp_inputs) + auto &partition = poly->poly->GPU[limb_id.x]; + if (limb_id.y >= partition.limb.size()) { - gpu_poly_destroy(p); + return set_error("unexpected limb index in gpu_matrix_sample_distribution"); } - return set_error("unsupported limb type in gpu_matrix_decompose_base"); - } - - cudaError_t err = cudaSetDevice(partition.device); - if (err != cudaSuccess) - { - for (auto *p : tmp_inputs) + auto &limb_impl = partition.limb[limb_id.y]; + if (limb_impl.index() != FIDESlib::U64) { - gpu_poly_destroy(p); + return set_error("unsupported limb type in gpu_matrix_sample_distribution"); } - return set_error(err); - } - - for (uint32_t digit_idx = 0; digit_idx < digits_per_tower; ++digit_idx) - { - const size_t digit_offset = - static_cast(limb) * static_cast(digits_per_tower) + - static_cast(digit_idx); - - std::vector src_ptrs; - std::vector dst_ptrs; - src_ptrs.reserve(count); - dst_ptrs.reserve(count); - - cudaStream_t stream = nullptr; - for (size_t idx = 0; idx < count; ++idx) + auto &limb_u64 = std::get(limb_impl); + if (!out_stream) { - const auto &in_partition = inputs[idx]->poly->GPU[limb_id.x]; - if (limb_id.y >= in_partition.limb.size()) - { - for (auto *p : tmp_inputs) - { - gpu_poly_destroy(p); - } - return set_error("unexpected input limb index in gpu_matrix_decompose_base"); - } - const auto &in_limb_impl = in_partition.limb[limb_id.y]; - if (in_limb_impl.index() != FIDESlib::U64) - { - for (auto *p : tmp_inputs) - { - gpu_poly_destroy(p); - } - return set_error("unsupported limb type in gpu_matrix_decompose_base"); - } - const auto &in_limb_u64 = std::get(in_limb_impl); - if (!stream) - { - stream = in_limb_u64.stream.ptr; - } - src_ptrs.push_back(in_limb_u64.v.data); - - const size_t row = idx / cols; - const size_t col = idx % cols; - const size_t out_row = row * log_base_q + digit_offset; - const size_t out_idx = matrix_index(out_row, col, out->cols); - auto &out_partition = 
out->polys[out_idx]->poly->GPU[limb_id.x]; - if (out_partition.device != in_partition.device) - { - for (auto *p : tmp_inputs) - { - gpu_poly_destroy(p); - } - return set_error("input/output limb device mismatch in gpu_matrix_decompose_base"); - } - if (limb_id.y >= out_partition.limb.size()) - { - for (auto *p : tmp_inputs) - { - gpu_poly_destroy(p); - } - return set_error("unexpected output limb index in gpu_matrix_decompose_base"); - } - const auto &out_limb_impl = out_partition.limb[limb_id.y]; - if (out_limb_impl.index() != FIDESlib::U64) - { - for (auto *p : tmp_inputs) - { - gpu_poly_destroy(p); - } - return set_error("unsupported output limb type in gpu_matrix_decompose_base"); - } - const auto &out_limb_u64 = std::get(out_limb_impl); - dst_ptrs.push_back(out_limb_u64.v.data); + out_stream = limb_u64.stream.ptr; } + dst_ptrs.push_back(limb_u64.v.data); + } - const uint64_t mask = - (digit_idx + 1 == digits_per_tower && last_mask != 0) ? last_mask : base_mask; - const uint32_t shift = digit_idx * base_bits; - int status = launch_decompose_kernel( - src_ptrs, - dst_ptrs, - static_cast(src->ctx->N), - shift, - mask, - stream); - if (status != 0) - { - for (auto *p : tmp_inputs) - { - gpu_poly_destroy(p); - } - return status; - } + int status = launch_sample_distribution_kernel( + dst_ptrs, + static_cast(out->ctx->N), + out->ctx->moduli[static_cast(limb)], + dist_type, + sigma, + static_cast(limb), + seed, + out_stream); + if (status != 0) + { + return status; } } + const int batch = default_batch(out->ctx); for (auto *poly : out->polys) { + poly->format = PolyFormat::Coeff; int status = gpu_poly_ntt(poly, batch); if (status != 0) { - for (auto *p : tmp_inputs) - { - gpu_poly_destroy(p); - } return status; } poly->format = PolyFormat::Eval; } out->format = PolyFormat::Eval; - - for (auto *p : tmp_inputs) - { - gpu_poly_destroy(p); - } return 0; } diff --git a/cuda/GpuPoly.cu b/cuda/GpuPoly.cu index 5a23ae6..ec2fa86 100644 --- a/cuda/GpuPoly.cu +++ b/cuda/GpuPoly.cu @@ -61,11 +61,63 @@ namespace return static_cast(level + 1) * static_cast(N); } + void propagate_partition_stream_to_limbs(CKKS::RNSPoly *poly) + { + for (auto &partition : poly->GPU) + { + cudaError_t err = cudaSetDevice(partition.device); + if (err != cudaSuccess) + { + throw std::runtime_error(cudaGetErrorString(err)); + } + for (auto &limb_impl : partition.limb) + { + cudaStream_t limb_stream = nullptr; + if (limb_impl.index() == FIDESlib::U64) + { + limb_stream = std::get(limb_impl).stream.ptr; + } + else if (limb_impl.index() == FIDESlib::U32) + { + limb_stream = std::get(limb_impl).stream.ptr; + } + if (!limb_stream || limb_stream == partition.s.ptr) + { + continue; + } + + cudaEvent_t ready = nullptr; + err = cudaEventCreateWithFlags(&ready, cudaEventDisableTiming); + if (err != cudaSuccess) + { + throw std::runtime_error(cudaGetErrorString(err)); + } + err = cudaEventRecord(ready, partition.s.ptr); + if (err != cudaSuccess) + { + cudaEventDestroy(ready); + throw std::runtime_error(cudaGetErrorString(err)); + } + err = cudaStreamWaitEvent(limb_stream, ready, 0); + cudaError_t destroy_err = cudaEventDestroy(ready); + if (err != cudaSuccess) + { + throw std::runtime_error(cudaGetErrorString(err)); + } + if (destroy_err != cudaSuccess) + { + throw std::runtime_error(cudaGetErrorString(destroy_err)); + } + } + } + } + void ensure_eval(GpuPoly *poly, int batch) { if (poly->format == PolyFormat::Coeff) { poly->poly->NTT(batch); + propagate_partition_stream_to_limbs(poly->poly); poly->format = PolyFormat::Eval; } } @@ -75,6 
+127,7 @@ namespace if (poly->format == PolyFormat::Eval) { poly->poly->INTT(batch); + propagate_partition_stream_to_limbs(poly->poly); poly->format = PolyFormat::Coeff; } } @@ -123,16 +176,108 @@ namespace uint64_t *dst, size_t n, uint32_t shift, - uint64_t mask) + uint64_t mask, + uint64_t out_modulus) { size_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) { uint64_t residue = src[idx]; uint64_t digit = shift >= 64 ? 0 : ((residue >> shift) & mask); + if (out_modulus != 0 && digit >= out_modulus) + { + digit %= out_modulus; + } dst[idx] = digit; } } + + __global__ void compare_u64_kernel( + const uint64_t *lhs, + const uint64_t *rhs, + size_t n, + int *out_equal) + { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n && lhs[idx] != rhs[idx]) + { + atomicExch(out_equal, 0); + } + } + + int compare_u64_arrays_on_device( + const uint64_t *lhs, + const uint64_t *rhs, + size_t n, + cudaStream_t lhs_stream, + cudaStream_t rhs_stream, + int device, + bool &is_equal) + { + if (n == 0) + { + is_equal = true; + return 0; + } + cudaError_t err = cudaSetDevice(device); + if (err != cudaSuccess) + { + return set_error(cudaGetErrorString(err)); + } + + err = cudaStreamSynchronize(lhs_stream); + if (err != cudaSuccess) + { + return set_error(cudaGetErrorString(err)); + } + err = cudaStreamSynchronize(rhs_stream); + if (err != cudaSuccess) + { + return set_error(cudaGetErrorString(err)); + } + + int *d_equal = nullptr; + err = cudaMalloc(&d_equal, sizeof(int)); + if (err != cudaSuccess) + { + return set_error(cudaGetErrorString(err)); + } + + int h_equal = 1; + err = cudaMemcpyAsync(d_equal, &h_equal, sizeof(int), cudaMemcpyHostToDevice, lhs_stream); + if (err != cudaSuccess) + { + cudaFree(d_equal); + return set_error(cudaGetErrorString(err)); + } + + const int threads = 256; + const int blocks = static_cast((n + threads - 1) / threads); + compare_u64_kernel<<>>(lhs, rhs, n, d_equal); + err = cudaGetLastError(); + if (err != cudaSuccess) + { + cudaFree(d_equal); + return set_error(cudaGetErrorString(err)); + } + + err = cudaMemcpyAsync(&h_equal, d_equal, sizeof(int), cudaMemcpyDeviceToHost, lhs_stream); + if (err != cudaSuccess) + { + cudaFree(d_equal); + return set_error(cudaGetErrorString(err)); + } + + err = cudaStreamSynchronize(lhs_stream); + if (err != cudaSuccess) + { + cudaFree(d_equal); + return set_error(cudaGetErrorString(err)); + } + cudaFree(d_equal); + is_equal = h_equal != 0; + return 0; + } } extern "C" @@ -933,6 +1078,110 @@ extern "C" } } + int gpu_poly_equal(const GpuPoly *lhs, const GpuPoly *rhs, int *out_equal) + { + try + { + if (!lhs || !rhs || !out_equal) + { + return set_error("invalid gpu_poly_equal arguments"); + } + *out_equal = 0; + + if (lhs == rhs) + { + *out_equal = 1; + return 0; + } + if (lhs->ctx != rhs->ctx || lhs->level != rhs->level) + { + return 0; + } + + const int batch = default_batch(lhs->ctx); + std::unique_ptr rhs_tmp; + const CKKS::RNSPoly *lhs_poly = lhs->poly; + const CKKS::RNSPoly *rhs_poly = rhs->poly; + if (lhs->format != rhs->format) + { + rhs_tmp = std::make_unique(rhs->poly->clone()); + convert_format(*rhs_tmp, rhs->format, lhs->format, batch); + rhs_poly = rhs_tmp.get(); + } + + auto *ctx = lhs->ctx->ctx; + const int level = lhs->level; + const int n = lhs->ctx->N; + if (n < 0 || level < 0) + { + return set_error("invalid poly metadata in gpu_poly_equal"); + } + if (ctx->limbGPUid.size() < static_cast(level + 1)) + { + return set_error("unexpected limb mapping size in gpu_poly_equal"); + } + + for (int limb = 0; limb <= 
level; ++limb) + { + const dim3 limb_id = ctx->limbGPUid[static_cast(limb)]; + if (limb_id.x >= lhs_poly->GPU.size() || limb_id.x >= rhs_poly->GPU.size()) + { + return set_error("unexpected limb GPU partition in gpu_poly_equal"); + } + + auto &lhs_partition = lhs_poly->GPU[limb_id.x]; + auto &rhs_partition = rhs_poly->GPU[limb_id.x]; + if (limb_id.y >= lhs_partition.limb.size() || limb_id.y >= rhs_partition.limb.size()) + { + return set_error("unexpected limb index in gpu_poly_equal"); + } + if (lhs_partition.device != rhs_partition.device) + { + return set_error("device mismatch in gpu_poly_equal"); + } + + auto &lhs_limb = lhs_partition.limb[limb_id.y]; + auto &rhs_limb = rhs_partition.limb[limb_id.y]; + if (lhs_limb.index() != FIDESlib::U64 || rhs_limb.index() != FIDESlib::U64) + { + return set_error("unsupported limb type in gpu_poly_equal"); + } + + auto &lhs_u64 = std::get(lhs_limb); + auto &rhs_u64 = std::get(rhs_limb); + bool limb_equal = true; + int status = compare_u64_arrays_on_device( + lhs_u64.v.data, + rhs_u64.v.data, + static_cast(n), + lhs_u64.stream.ptr, + rhs_u64.stream.ptr, + lhs_partition.device, + limb_equal); + if (status != 0) + { + return status; + } + if (!limb_equal) + { + *out_equal = 0; + return 0; + } + } + + *out_equal = 1; + return 0; + } + catch (const std::exception &e) + { + return set_error(e); + } + catch (...) + { + return set_error("unknown exception in gpu_poly_equal"); + } + } + int gpu_poly_decompose_base( const GpuPoly *src, uint32_t base_bits, @@ -1052,70 +1301,88 @@ extern "C" out->level = level; } - const uint64_t base_mask = - base_bits >= 64 ? std::numeric_limits::max() : ((1ULL << base_bits) - 1); const int threads = 256; const int blocks = static_cast((static_cast(N) + threads - 1) / threads); + if (input->ctx->moduli.size() < crt_depth) + { + return set_error("unexpected modulus count in gpu_poly_decompose_base"); + } - for (int limb = 0; limb <= level; ++limb) + for (int src_limb = 0; src_limb <= level; ++src_limb) { - const dim3 limb_id = ctx->limbGPUid[static_cast(limb)]; - if (limb_id.x >= input->poly->GPU.size()) + const dim3 src_limb_id = ctx->limbGPUid[static_cast(src_limb)]; + if (src_limb_id.x >= input->poly->GPU.size()) { - return set_error("unexpected limb GPU partition in gpu_poly_decompose_base"); + return set_error("unexpected source limb GPU partition in gpu_poly_decompose_base"); } - auto &in_partition = input->poly->GPU[limb_id.x]; - if (limb_id.y >= in_partition.limb.size()) + auto &in_partition = input->poly->GPU[src_limb_id.x]; + if (src_limb_id.y >= in_partition.limb.size()) { - return set_error("unexpected limb index in gpu_poly_decompose_base"); + return set_error("unexpected source limb index in gpu_poly_decompose_base"); } - auto &in_limb_impl = in_partition.limb[limb_id.y]; + auto &in_limb_impl = in_partition.limb[src_limb_id.y]; if (in_limb_impl.index() != FIDESlib::U64) { - return set_error("unsupported limb type in gpu_poly_decompose_base"); + return set_error("unsupported source limb type in gpu_poly_decompose_base"); } auto &in_limb_u64 = std::get(in_limb_impl); + const uint32_t src_bits = + bit_width_u64(input->ctx->moduli[static_cast(src_limb)]); for (uint32_t digit_idx = 0; digit_idx < digits_per_tower; ++digit_idx) { - const size_t out_idx = - static_cast(limb) * static_cast(digits_per_tower) + - static_cast(digit_idx); - GpuPoly *out = out_polys[out_idx]; - auto &out_partition = out->poly->GPU[limb_id.x]; - if (limb_id.y >= out_partition.limb.size()) - { - return set_error("unexpected output limb index in 
gpu_poly_decompose_base"); - } - auto &out_limb_impl = out_partition.limb[limb_id.y]; - if (out_limb_impl.index() != FIDESlib::U64) - { - return set_error("unsupported output limb type in gpu_poly_decompose_base"); - } - auto &out_limb_u64 = std::get(out_limb_impl); - - if (out_partition.device != in_partition.device) - { - return set_error("input/output limb device mismatch in gpu_poly_decompose_base"); - } - - cudaError_t err = cudaSetDevice(out_partition.device); - if (err != cudaSuccess) + const uint32_t shift = digit_idx * base_bits; + uint64_t mask = 0; + if (shift < src_bits) { - return set_error(cudaGetErrorString(err)); + const uint32_t remaining = src_bits - shift; + const uint32_t digit_bits = std::min(base_bits, remaining); + mask = digit_bits >= 64 ? std::numeric_limits::max() + : ((uint64_t{1} << digit_bits) - 1); } - const uint32_t shift = digit_idx * base_bits; - decompose_base_kernel<<>>( - in_limb_u64.v.data, - out_limb_u64.v.data, - static_cast(N), - shift, - base_mask); - err = cudaGetLastError(); - if (err != cudaSuccess) + const size_t out_idx = + static_cast(src_limb) * static_cast(digits_per_tower) + + static_cast(digit_idx); + GpuPoly *out = out_polys[out_idx]; + for (int out_limb = 0; out_limb <= level; ++out_limb) { - return set_error(cudaGetErrorString(err)); + const dim3 out_limb_id = ctx->limbGPUid[static_cast(out_limb)]; + auto &out_partition = out->poly->GPU[out_limb_id.x]; + if (out_limb_id.y >= out_partition.limb.size()) + { + return set_error("unexpected output limb index in gpu_poly_decompose_base"); + } + auto &out_limb_impl = out_partition.limb[out_limb_id.y]; + if (out_limb_impl.index() != FIDESlib::U64) + { + return set_error("unsupported output limb type in gpu_poly_decompose_base"); + } + auto &out_limb_u64 = std::get(out_limb_impl); + + if (out_partition.device != in_partition.device) + { + return set_error("input/output limb device mismatch in gpu_poly_decompose_base"); + } + + cudaError_t err = cudaSetDevice(out_partition.device); + if (err != cudaSuccess) + { + return set_error(cudaGetErrorString(err)); + } + + decompose_base_kernel<<>>( + in_limb_u64.v.data, + out_limb_u64.v.data, + static_cast(N), + shift, + mask, + input->ctx->moduli[static_cast(out_limb)]); + err = cudaGetLastError(); + if (err != cudaSuccess) + { + return set_error(cudaGetErrorString(err)); + } } } } diff --git a/cuda/GpuPoly.h b/cuda/GpuPoly.h index 4b7d99d..3275107 100644 --- a/cuda/GpuPoly.h +++ b/cuda/GpuPoly.h @@ -17,6 +17,14 @@ typedef enum GpuPolyFormat GPU_POLY_FORMAT_EVAL = 1, } GpuPolyFormat; +typedef enum GpuMatrixSampleDist +{ + GPU_MATRIX_DIST_UNIFORM = 0, + GPU_MATRIX_DIST_GAUSS = 1, + GPU_MATRIX_DIST_BIT = 2, + GPU_MATRIX_DIST_TERNARY = 3, +} GpuMatrixSampleDist; + int gpu_context_create( uint32_t logN, uint32_t L, @@ -72,6 +80,7 @@ void gpu_event_set_destroy(GpuEventSet* events); int gpu_poly_add(GpuPoly* out, const GpuPoly* a, const GpuPoly* b); int gpu_poly_sub(GpuPoly* out, const GpuPoly* a, const GpuPoly* b); int gpu_poly_mul(GpuPoly* out, const GpuPoly* a, const GpuPoly* b); +int gpu_poly_equal(const GpuPoly* lhs, const GpuPoly* rhs, int* out_equal); int gpu_block_add(GpuPoly* const* out, const GpuPoly* const* lhs, const GpuPoly* const* rhs, size_t count); int gpu_block_sub(GpuPoly* const* out, const GpuPoly* const* lhs, const GpuPoly* const* rhs, size_t count); int gpu_block_entrywise_mul( @@ -100,6 +109,7 @@ int gpu_matrix_store_rns_batch( int gpu_matrix_add(GpuMatrix* out, const GpuMatrix* lhs, const GpuMatrix* rhs); int gpu_matrix_sub(GpuMatrix* 
out, const GpuMatrix* lhs, const GpuMatrix* rhs); int gpu_matrix_mul(GpuMatrix* out, const GpuMatrix* lhs, const GpuMatrix* rhs); +int gpu_matrix_equal(const GpuMatrix* lhs, const GpuMatrix* rhs, int* out_equal); int gpu_matrix_mul_timed(GpuMatrix* out, const GpuMatrix* lhs, const GpuMatrix* rhs, double* out_kernel_ms); int gpu_matrix_mul_scalar(GpuMatrix* out, const GpuMatrix* lhs, const GpuPoly* scalar); int gpu_matrix_copy_block( @@ -111,7 +121,32 @@ int gpu_matrix_copy_block( size_t src_col, size_t rows, size_t cols); +int gpu_matrix_fill_gadget( + GpuMatrix* out, + uint32_t base_bits); int gpu_matrix_decompose_base(const GpuMatrix* src, uint32_t base_bits, GpuMatrix* out); +int gpu_matrix_gauss_samp_gq_arb_base( + const GpuMatrix* src, + uint32_t base_bits, + double c, + double dgg_stddev, + uint64_t seed, + GpuMatrix* out); +int gpu_matrix_sample_p1_full( + const GpuMatrix* a_mat, + const GpuMatrix* b_mat, + const GpuMatrix* d_mat, + const GpuMatrix* tp2, + double sigma, + double s, + double dgg_stddev, + uint64_t seed, + GpuMatrix* out); +int gpu_matrix_sample_distribution( + GpuMatrix* out, + int dist_type, + double sigma, + uint64_t seed); int gpu_poly_ntt(GpuPoly* poly, int batch); int gpu_poly_intt(GpuPoly* poly, int batch); diff --git a/src/matrix/gpu_dcrt_poly.rs b/src/matrix/gpu_dcrt_poly.rs index 5827174..dfe3f7d 100644 --- a/src/matrix/gpu_dcrt_poly.rs +++ b/src/matrix/gpu_dcrt_poly.rs @@ -6,13 +6,16 @@ use crate::{ Poly, PolyParams, dcrt::{ gpu::{ - GPU_POLY_FORMAT_EVAL, GpuDCRTPoly, GpuDCRTPolyParams, GpuEventSetOpaque, - GpuMatrixOpaque, check_status, gpu_event_set_destroy, gpu_event_set_wait, - gpu_matrix_add, gpu_matrix_copy, gpu_matrix_copy_block, gpu_matrix_copy_entry, - gpu_matrix_create, gpu_matrix_decompose_base, gpu_matrix_destroy, - gpu_matrix_entry_clone, gpu_matrix_load_rns_batch, gpu_matrix_mul, - gpu_matrix_mul_scalar, gpu_matrix_mul_timed, gpu_matrix_store_rns_batch, - gpu_matrix_sub, + GPU_MATRIX_DIST_BIT, GPU_MATRIX_DIST_GAUSS, GPU_MATRIX_DIST_TERNARY, + GPU_MATRIX_DIST_UNIFORM, GPU_POLY_FORMAT_EVAL, GpuDCRTPoly, GpuDCRTPolyParams, + GpuEventSetOpaque, GpuMatrixOpaque, check_status, gpu_event_set_destroy, + gpu_event_set_wait, gpu_matrix_add, gpu_matrix_copy, gpu_matrix_copy_block, + gpu_matrix_copy_entry, gpu_matrix_create, gpu_matrix_decompose_base, + gpu_matrix_destroy, gpu_matrix_entry_clone, gpu_matrix_equal, + gpu_matrix_fill_gadget, gpu_matrix_gauss_samp_gq_arb_base, + gpu_matrix_load_rns_batch, gpu_matrix_mul, gpu_matrix_mul_scalar, + gpu_matrix_mul_timed, gpu_matrix_sample_distribution, gpu_matrix_sample_p1_full, + gpu_matrix_store_rns_batch, gpu_matrix_sub, }, params::DCRTPolyParams, poly::DCRTPoly, @@ -40,6 +43,25 @@ pub struct GpuDCRTPolyMatrix { raw: *mut GpuMatrixOpaque, } +#[derive(Clone, Copy, Debug)] +pub(crate) enum GpuMatrixSampleDist { + Uniform, + Gauss, + Bit, + Ternary, +} + +impl GpuMatrixSampleDist { + fn as_ffi(self) -> i32 { + match self { + Self::Uniform => GPU_MATRIX_DIST_UNIFORM, + Self::Gauss => GPU_MATRIX_DIST_GAUSS, + Self::Bit => GPU_MATRIX_DIST_BIT, + Self::Ternary => GPU_MATRIX_DIST_TERNARY, + } + } +} + /// # Safety /// GpuDCRTPolyMatrix owns an opaque GPU handle managed on the C++ side. 
unsafe impl Send for GpuDCRTPolyMatrix {} @@ -98,9 +120,10 @@ impl PartialEq for GpuDCRTPolyMatrix { if self.raw == other.raw { return true; } - let lhs = self.to_cpu_matrix(); - let rhs = other.to_cpu_matrix(); - lhs == rhs + let mut out_equal: i32 = 0; + let status = unsafe { gpu_matrix_equal(self.raw, other.raw, &mut out_equal as *mut i32) }; + check_status(status, "gpu_matrix_equal"); + out_equal != 0 } } @@ -158,6 +181,73 @@ impl GpuDCRTPolyMatrix { check_status(status, "gpu_matrix_copy_block"); } + pub(crate) fn sample_distribution( + params: &GpuDCRTPolyParams, + nrow: usize, + ncol: usize, + dist: GpuMatrixSampleDist, + sigma: f64, + seed: u64, + ) -> Self { + let out = Self::new_empty(params, nrow, ncol); + if nrow == 0 || ncol == 0 { + return out; + } + let status = unsafe { gpu_matrix_sample_distribution(out.raw, dist.as_ffi(), sigma, seed) }; + check_status(status, "gpu_matrix_sample_distribution"); + out + } + + pub fn gauss_samp_gq_arb_base(&self, c: f64, dgg_stddev: f64, seed: u64) -> Self { + let log_base_q = self.params.modulus_digits(); + let out_nrow = self.nrow.saturating_mul(log_base_q); + let out = Self::new_empty(&self.params, out_nrow, self.ncol); + let status = unsafe { + gpu_matrix_gauss_samp_gq_arb_base( + self.raw, + self.params.base_bits(), + c, + dgg_stddev, + seed, + out.raw, + ) + }; + check_status(status, "gpu_matrix_gauss_samp_gq_arb_base"); + out + } + + pub(crate) fn sample_p1_full( + a_mat: &Self, + b_mat: &Self, + d_mat: &Self, + tp2: &Self, + sigma: f64, + s: f64, + dgg_stddev: f64, + seed: u64, + ) -> Self { + debug_assert_eq!(a_mat.params, b_mat.params, "A/B params mismatch"); + debug_assert_eq!(a_mat.params, d_mat.params, "A/D params mismatch"); + debug_assert_eq!(a_mat.params, tp2.params, "A/tp2 params mismatch"); + debug_assert_eq!(a_mat.nrow, a_mat.ncol, "A must be square"); + debug_assert_eq!(b_mat.nrow, a_mat.nrow, "B row size mismatch"); + debug_assert_eq!(b_mat.ncol, a_mat.ncol, "B col size mismatch"); + debug_assert_eq!(d_mat.nrow, a_mat.nrow, "D row size mismatch"); + debug_assert_eq!(d_mat.ncol, a_mat.ncol, "D col size mismatch"); + debug_assert_eq!(tp2.nrow, 2 * a_mat.nrow, "tp2 must have 2d rows"); + let out = Self::new_empty(&tp2.params, tp2.nrow, tp2.ncol); + if tp2.nrow == 0 || tp2.ncol == 0 { + return out; + } + let status = unsafe { + gpu_matrix_sample_p1_full( + a_mat.raw, b_mat.raw, d_mat.raw, tp2.raw, sigma, s, dgg_stddev, seed, out.raw, + ) + }; + check_status(status, "gpu_matrix_sample_p1_full"); + out + } + fn store_rns_bytes(&self, bytes_out: &mut [u8], bytes_per_poly: usize) { if bytes_out.is_empty() || bytes_per_poly == 0 { return; @@ -611,14 +701,14 @@ impl PolyMatrix for GpuDCRTPolyMatrix { } fn gadget_matrix(params: &::Params, size: usize) -> Self { - let cpu_params = DCRTPolyParams::new( - params.ring_dimension(), - params.crt_depth(), - params.crt_bits(), - params.base_bits(), - ); - let cpu_matrix = super::dcrt_poly::DCRTPolyMatrix::gadget_matrix(&cpu_params, size); - Self::from_cpu_matrix(params, &cpu_matrix) + if size == 0 { + return Self::new_zero(params, 0, 0); + } + let log_base_q = params.modulus_digits(); + let out = Self::new_empty(params, size, size * log_base_q); + let status = unsafe { gpu_matrix_fill_gadget(out.raw, params.base_bits()) }; + check_status(status, "gpu_matrix_fill_gadget"); + out } fn decompose(&self) -> Self { @@ -635,9 +725,22 @@ impl PolyMatrix for GpuDCRTPolyMatrix { &self, new_modulus: &<::Params as PolyParams>::Modulus, ) -> Self { - let cpu_matrix = self.to_cpu_matrix(); - let 
switched = cpu_matrix.modulus_switch(new_modulus); - Self::from_cpu_matrix(&self.params, &switched) + let polys = parallel_iter!(0..self.nrow) + .map(|i| { + parallel_iter!(0..self.ncol) + .map(|j| { + let coeffs = self.entry(i, j); + let switched_coeffs = coeffs + .coeffs() + .into_iter() + .map(|c| c.modulus_switch(new_modulus.clone())) + .collect::>(); + GpuDCRTPoly::from_coeffs(&self.params, &switched_coeffs) + }) + .collect::>() + }) + .collect::>(); + Self::from_poly_vec(&self.params, polys) } fn mul_tensor_identity(&self, other: &Self, identity_size: usize) -> Self { @@ -1051,6 +1154,7 @@ mod tests { poly::dcrt::gpu::gpu_device_sync, }; use num_bigint::BigUint; + use rand::{Rng, rng}; use std::sync::Arc; fn gpu_test_params() -> DCRTPolyParams { @@ -1165,6 +1269,157 @@ mod tests { assert_eq!(matrix, expected_matrix); } + #[test] + #[sequential] + fn test_gpu_matrix_gauss_samp_gq_arb_base_relation() { + gpu_device_sync(); + let params = DCRTPolyParams::new(128, 2, 16, 8); + let gpu_params = gpu_params_from_cpu(¶ms); + + let value_a = 5usize; + let value_b = 9usize; + let matrix = GpuDCRTPolyMatrix::from_poly_vec( + &gpu_params, + vec![vec![ + GpuDCRTPoly::from_usize_to_constant(&gpu_params, value_a), + GpuDCRTPoly::from_usize_to_constant(&gpu_params, value_b), + ]], + ); + let base = 1u32 << gpu_params.base_bits(); + let c = (base as f64 + 1.0) * 4.578; + let gadget = GpuDCRTPolyMatrix::gadget_matrix(&gpu_params, matrix.row_size()); + for offset in 0..16u64 { + let sampled = + matrix.gauss_samp_gq_arb_base(c, 4.578, 0x1234_5678_9abc_def0u64 + offset); + let reconstructed = &gadget * &sampled; + assert_eq!(reconstructed, matrix); + } + + let modulus = gpu_params.modulus(); + let varied_coeffs = (0..gpu_params.ring_dimension() as usize) + .map(|i| { + let value = ((i as u64) * 7919u64 + 12345u64) as u32; + FinRingElem::new(value, modulus.clone()) + }) + .collect::>(); + let varied_poly = GpuDCRTPoly::from_coeffs(&gpu_params, &varied_coeffs); + let varied_matrix = GpuDCRTPolyMatrix::from_poly_vec(&gpu_params, vec![vec![varied_poly]]); + let varied_gadget = GpuDCRTPolyMatrix::gadget_matrix(&gpu_params, 1); + for offset in 0..16u64 { + let sampled = + varied_matrix.gauss_samp_gq_arb_base(c, 4.578, 0x00de_adbe_efu64 + offset); + let reconstructed = &varied_gadget * &sampled; + assert_eq!(reconstructed, varied_matrix); + } + + let wide_matrix = GpuDCRTPolyMatrix::from_poly_vec( + &gpu_params, + vec![ + vec![ + GpuDCRTPoly::from_usize_to_constant(&gpu_params, 17), + GpuDCRTPoly::from_usize_to_constant(&gpu_params, 345), + ], + vec![ + GpuDCRTPoly::from_usize_to_constant(&gpu_params, 777), + GpuDCRTPoly::from_usize_to_constant(&gpu_params, 1201), + ], + vec![ + GpuDCRTPoly::from_usize_to_constant(&gpu_params, 4095), + GpuDCRTPoly::from_usize_to_constant(&gpu_params, 65535), + ], + ], + ); + let wide_gadget = GpuDCRTPolyMatrix::gadget_matrix(&gpu_params, wide_matrix.row_size()); + for offset in 0..16u64 { + let sampled = + wide_matrix.gauss_samp_gq_arb_base(c, 4.578, 0x55aa_aa55_1357_2468u64 + offset); + let reconstructed = &wide_gadget * &sampled; + assert_eq!(reconstructed, wide_matrix); + } + + let mut prng = rng(); + let random_matrix_vec = (0..3) + .map(|_| { + (0..3) + .map(|_| { + let coeffs = (0..gpu_params.ring_dimension() as usize) + .map(|_| FinRingElem::new(prng.random::(), modulus.clone())) + .collect::>(); + GpuDCRTPoly::from_coeffs(&gpu_params, &coeffs) + }) + .collect::>() + }) + .collect::>(); + let random_matrix = GpuDCRTPolyMatrix::from_poly_vec(&gpu_params, 
random_matrix_vec); + let random_gadget = GpuDCRTPolyMatrix::gadget_matrix(&gpu_params, random_matrix.row_size()); + for offset in 0..8u64 { + let sampled = + random_matrix.gauss_samp_gq_arb_base(c, 4.578, 0x0f0f_f0f0_2468_1357u64 + offset); + let reconstructed = &random_gadget * &sampled; + if reconstructed != random_matrix { + let sampled_cpu = sampled.to_cpu_matrix(); + let src_cpu = random_matrix.to_cpu_matrix(); + let rows = random_matrix.row_size(); + let cols = random_matrix.col_size(); + let depth = gpu_params.crt_depth(); + let digits_per_tower = + gpu_params.crt_bits().div_ceil(gpu_params.base_bits() as usize); + let log_base_q = gpu_params.modulus_digits(); + let base_u64 = 1u64 << gpu_params.base_bits(); + let moduli = gpu_params.moduli().to_vec(); + let moduli_big = moduli.iter().map(|q| BigUint::from(*q)).collect::>(); + let mut violation = String::new(); + + 'search: for row in 0..rows { + for col in 0..cols { + let src_poly = src_cpu.entry(row, col); + let src_coeffs = src_poly.coeffs(); + for tower in 0..depth { + let q = moduli[tower]; + let q_big = &moduli_big[tower]; + for coeff_idx in 0..gpu_params.ring_dimension() as usize { + let src_res = (&*src_coeffs[coeff_idx].value() % q_big) + .to_u64_digits() + .first() + .copied() + .unwrap_or(0); + let mut accum = 0u64; + let mut base_pow = 1u64 % q; + for digit in 0..digits_per_tower { + let sampled_row = + row * log_base_q + tower * digits_per_tower + digit; + let digit_poly = sampled_cpu.entry(sampled_row, col); + let digit_coeff = + digit_poly.coeffs()[coeff_idx].value().clone(); + let digit_res = (&digit_coeff % q_big) + .to_u64_digits() + .first() + .copied() + .unwrap_or(0); + let term = ((u128::from(base_pow) * u128::from(digit_res)) % + u128::from(q)) + as u64; + accum = (accum + term) % q; + base_pow = ((u128::from(base_pow) * u128::from(base_u64)) % + u128::from(q)) + as u64; + } + if accum != src_res { + violation = format!( + "relation violated: offset={offset}, row={row}, col={col}, tower={tower}, coeff={coeff_idx}, lhs={accum}, rhs={src_res}, q={q}" + ); + break 'search; + } + } + } + } + } + + panic!("gauss_samp reconstruction mismatch; {violation}"); + } + } + } + #[test] #[sequential] fn test_gpu_matrix_basic_operations() { diff --git a/src/poly/dcrt/gpu.rs b/src/poly/dcrt/gpu.rs index 5eb1c33..e82e000 100644 --- a/src/poly/dcrt/gpu.rs +++ b/src/poly/dcrt/gpu.rs @@ -112,6 +112,11 @@ unsafe extern "C" { a: *const GpuPolyOpaque, b: *const GpuPolyOpaque, ) -> c_int; + fn gpu_poly_equal( + lhs: *const GpuPolyOpaque, + rhs: *const GpuPolyOpaque, + out_equal: *mut c_int, + ) -> c_int; fn gpu_poly_decompose_base( src: *const GpuPolyOpaque, base_bits: u32, @@ -169,6 +174,11 @@ unsafe extern "C" { lhs: *const GpuMatrixOpaque, rhs: *const GpuMatrixOpaque, ) -> c_int; + pub(crate) fn gpu_matrix_equal( + lhs: *const GpuMatrixOpaque, + rhs: *const GpuMatrixOpaque, + out_equal: *mut c_int, + ) -> c_int; pub(crate) fn gpu_matrix_mul_timed( out: *mut GpuMatrixOpaque, lhs: *const GpuMatrixOpaque, @@ -190,11 +200,37 @@ unsafe extern "C" { rows: usize, cols: usize, ) -> c_int; + pub(crate) fn gpu_matrix_fill_gadget(out: *mut GpuMatrixOpaque, base_bits: u32) -> c_int; pub(crate) fn gpu_matrix_decompose_base( src: *const GpuMatrixOpaque, base_bits: u32, out: *mut GpuMatrixOpaque, ) -> c_int; + pub(crate) fn gpu_matrix_gauss_samp_gq_arb_base( + src: *const GpuMatrixOpaque, + base_bits: u32, + c: f64, + dgg_stddev: f64, + seed: u64, + out: *mut GpuMatrixOpaque, + ) -> c_int; + pub(crate) fn gpu_matrix_sample_p1_full( + a_mat: 
*const GpuMatrixOpaque, + b_mat: *const GpuMatrixOpaque, + d_mat: *const GpuMatrixOpaque, + tp2: *const GpuMatrixOpaque, + sigma: f64, + s: f64, + dgg_stddev: f64, + seed: u64, + out: *mut GpuMatrixOpaque, + ) -> c_int; + pub(crate) fn gpu_matrix_sample_distribution( + out: *mut GpuMatrixOpaque, + dist_type: c_int, + sigma: f64, + seed: u64, + ) -> c_int; fn gpu_poly_ntt(poly: *mut GpuPolyOpaque, batch: c_int) -> c_int; fn gpu_poly_intt(poly: *mut GpuPolyOpaque, batch: c_int) -> c_int; @@ -210,6 +246,10 @@ unsafe extern "C" { pub(crate) const GPU_POLY_FORMAT_COEFF: c_int = 0; pub(crate) const GPU_POLY_FORMAT_EVAL: c_int = 1; +pub(crate) const GPU_MATRIX_DIST_UNIFORM: c_int = 0; +pub(crate) const GPU_MATRIX_DIST_GAUSS: c_int = 1; +pub(crate) const GPU_MATRIX_DIST_BIT: c_int = 2; +pub(crate) const GPU_MATRIX_DIST_TERNARY: c_int = 3; pub(crate) fn last_error_string() -> String { unsafe { @@ -959,10 +999,24 @@ impl Drop for GpuDCRTPoly { impl PartialEq for GpuDCRTPoly { fn eq(&self, other: &Self) -> bool { - if self.params.as_ref() != other.params.as_ref() { + if self.params.as_ref() != other.params.as_ref() || self.level != other.level { return false; } - self.coeffs() == other.coeffs() + if self.raw == other.raw { + return true; + } + let mut out_equal: c_int = 0; + if self.is_ntt == other.is_ntt { + let status = + unsafe { gpu_poly_equal(self.raw, other.raw, &mut out_equal as *mut c_int) }; + check_status(status, "gpu_poly_equal"); + return out_equal != 0; + } + let lhs = self.ensure_coeff_domain(); + let rhs = other.ensure_coeff_domain(); + let status = unsafe { gpu_poly_equal(lhs.raw, rhs.raw, &mut out_equal as *mut c_int) }; + check_status(status, "gpu_poly_equal"); + out_equal != 0 } } @@ -1453,6 +1507,20 @@ mod tests { ); } + #[test] + #[sequential] + fn test_gpu_dcrtpoly_partial_eq_across_domains() { + gpu_device_sync(); + let params = gpu_test_params(); + let gpu_params = gpu_params_from_cpu(¶ms); + let sampler = DCRTPolyUniformSampler::new(); + let cpu_poly = sampler.sample_poly(¶ms, &DistType::FinRingDist); + let coeff_poly = gpu_poly_from_cpu(&cpu_poly, &gpu_params); + let mut eval_poly = coeff_poly.clone(); + eval_poly.ntt_in_place(); + assert_eq!(coeff_poly, eval_poly, "PartialEq should match across coeff/eval domains"); + } + #[test] #[sequential] fn test_gpu_dcrtpoly_decompose() { diff --git a/src/sampler/trapdoor/gpu.rs b/src/sampler/trapdoor/gpu.rs new file mode 100644 index 0000000..6408f85 --- /dev/null +++ b/src/sampler/trapdoor/gpu.rs @@ -0,0 +1,402 @@ +use super::DCRTTrapdoor; +use crate::{ + matrix::{ + PolyMatrix, + gpu_dcrt_poly::{GpuDCRTPolyMatrix, GpuMatrixSampleDist}, + }, + poly::{ + Poly, PolyParams, + dcrt::{gpu::GpuDCRTPolyParams, params::DCRTPolyParams}, + }, + sampler::{DistType, PolyTrapdoorSampler}, +}; +use rand::{Rng, rng}; + +const SIGMA: f64 = 4.578; +const SPECTRAL_CONSTANT: f64 = 1.8; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct GpuDCRTTrapdoor { + pub r: GpuDCRTPolyMatrix, + pub e: GpuDCRTPolyMatrix, + pub a_mat: GpuDCRTPolyMatrix, + pub b_mat: GpuDCRTPolyMatrix, + pub d_mat: GpuDCRTPolyMatrix, + pub re: GpuDCRTPolyMatrix, +} + +impl GpuDCRTTrapdoor { + pub fn new(params: &GpuDCRTPolyParams, size: usize, sigma: f64) -> Self { + let log_base_q = params.modulus_digits(); + let dist = DistType::GaussDist { sigma }; + let r = sample_gpu_matrix_native(params, size, size * log_base_q, dist); + let e = sample_gpu_matrix_native(params, size, size * log_base_q, dist); + let a_mat = &r * &r.transpose(); // d x d + let b_mat = &r * &e.transpose(); // 
d x d + let d_mat = &e * &e.transpose(); // d x d + let re = r.concat_rows(&[&e]); + Self { r, e, a_mat, b_mat, d_mat, re } + } + + pub fn to_compact_bytes(&self) -> Vec { + self.to_cpu_trapdoor().to_compact_bytes() + } + + pub fn from_compact_bytes(params: &GpuDCRTPolyParams, bytes: &[u8]) -> Option { + let cpu_params = cpu_params_from_gpu(params); + let cpu = DCRTTrapdoor::from_compact_bytes(&cpu_params, bytes)?; + Some(Self { + r: GpuDCRTPolyMatrix::from_cpu_matrix(params, &cpu.r), + e: GpuDCRTPolyMatrix::from_cpu_matrix(params, &cpu.e), + a_mat: GpuDCRTPolyMatrix::from_cpu_matrix(params, &cpu.a_mat), + b_mat: GpuDCRTPolyMatrix::from_cpu_matrix(params, &cpu.b_mat), + d_mat: GpuDCRTPolyMatrix::from_cpu_matrix(params, &cpu.d_mat), + re: GpuDCRTPolyMatrix::from_cpu_matrix(params, &cpu.re), + }) + } + + pub(crate) fn to_cpu_trapdoor(&self) -> DCRTTrapdoor { + DCRTTrapdoor { + r: self.r.to_cpu_matrix(), + e: self.e.to_cpu_matrix(), + a_mat: self.a_mat.to_cpu_matrix(), + b_mat: self.b_mat.to_cpu_matrix(), + d_mat: self.d_mat.to_cpu_matrix(), + re: self.re.to_cpu_matrix(), + } + } +} + +#[derive(Debug, Clone)] +pub struct GpuDCRTPolyTrapdoorSampler { + sigma: f64, + base: u32, + c: f64, +} + +impl PolyTrapdoorSampler for GpuDCRTPolyTrapdoorSampler { + type M = GpuDCRTPolyMatrix; + type Trapdoor = GpuDCRTTrapdoor; + + fn new(params: &<::P as Poly>::Params, sigma: f64) -> Self { + let base = 1 << params.base_bits(); + let c = (base as f64 + 1.0) * SIGMA; + Self { sigma, base, c } + } + + fn trapdoor( + &self, + params: &<::P as Poly>::Params, + size: usize, + ) -> (Self::Trapdoor, Self::M) { + let trapdoor = GpuDCRTTrapdoor::new(params, size, self.sigma); + let a_bar = sample_gpu_matrix_native(params, size, size, DistType::FinRingDist); + let g = GpuDCRTPolyMatrix::gadget_matrix(params, size); + let a0 = a_bar.concat_columns(&[&GpuDCRTPolyMatrix::identity(params, size, None)]); + let a1 = &g - &(&a_bar * &trapdoor.r + &trapdoor.e); + let a = a0.concat_columns(&[&a1]); + (trapdoor, a) + } + + fn trapdoor_to_bytes(trapdoor: &Self::Trapdoor) -> Vec { + trapdoor.to_compact_bytes() + } + + fn trapdoor_from_bytes( + params: &<::P as Poly>::Params, + bytes: &[u8], + ) -> Option { + GpuDCRTTrapdoor::from_compact_bytes(params, bytes) + } + + fn preimage( + &self, + params: &<::P as Poly>::Params, + trapdoor: &Self::Trapdoor, + public_matrix: &Self::M, + target: &Self::M, + ) -> Self::M { + let d = public_matrix.row_size(); + let target_cols = target.col_size(); + debug_assert_eq!( + target.row_size(), + d, + "Target matrix should have the same number of rows as the public matrix", + ); + + let n = params.ring_dimension() as usize; + let k = params.modulus_digits(); + let s = SPECTRAL_CONSTANT * + (self.base as f64 + 1.0) * + SIGMA * + SIGMA * + (((d * n * k) as f64).sqrt() + ((2 * n) as f64).sqrt() + 4.7); + let dgg_large_std = (s * s - self.c * self.c).sqrt(); + let p_hat = sample_pert_square_mat_gpu_native( + params, + trapdoor, + s, + self.c, + self.sigma, + dgg_large_std, + target_cols, + ); + + let perturbed_syndrome = target - &(public_matrix * &p_hat); + // OpenFHE-equivalent GaussSampGqArbBase path on GPU: + // this keeps the perturbation + gadget preimage step randomized. 
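+ // z_hat_mat has d * log_base_q rows and satisfies
+ //   gadget_matrix(params, d) * z_hat_mat == perturbed_syndrome (mod q)
+ // with short Gaussian digits; this is exactly the relation asserted by the
+ // GPU unit test for gauss_samp_gq_arb_base.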
+ let mut rng = rng(); + let z_seed: u64 = rng.random(); + let z_hat_mat = perturbed_syndrome.gauss_samp_gq_arb_base(self.c, self.sigma, z_seed); + + let r_z_hat = &trapdoor.r * &z_hat_mat; + let e_z_hat = &trapdoor.e * &z_hat_mat; + let z_hat_former = (p_hat.slice_rows(0, d) + r_z_hat) + .concat_rows(&[&(p_hat.slice_rows(d, 2 * d) + e_z_hat)]); + let z_hat_latter = p_hat.slice_rows(2 * d, d * (k + 2)) + z_hat_mat; + z_hat_former.concat_rows(&[&z_hat_latter]) + } + + fn preimage_extend( + &self, + params: &<::P as Poly>::Params, + trapdoor: &Self::Trapdoor, + public_matrix: &Self::M, + ext_matrix: &Self::M, + target: &Self::M, + ) -> Self::M { + let d = public_matrix.row_size(); + let ext_ncol = ext_matrix.col_size(); + let target_ncol = target.col_size(); + let n = params.ring_dimension() as usize; + let k = params.modulus_digits(); + let s = SPECTRAL_CONSTANT * + (self.base as f64 + 1.0) * + SIGMA * + SIGMA * + (((d * n * k) as f64).sqrt() + ((2 * n) as f64).sqrt() + 4.7); + + let dist = DistType::GaussDist { sigma: s }; + let preimage_right = sample_gpu_matrix_native(params, ext_ncol, target_ncol, dist); + let t = target - &(ext_matrix * &preimage_right); + let preimage_left = self.preimage(params, trapdoor, public_matrix, &t); + preimage_left.concat_rows(&[&preimage_right]) + } +} + +fn cpu_params_from_gpu(params: &GpuDCRTPolyParams) -> DCRTPolyParams { + DCRTPolyParams::new( + params.ring_dimension(), + params.crt_depth(), + params.crt_bits(), + params.base_bits(), + ) +} + +fn sample_pert_square_mat_gpu_native( + params: &GpuDCRTPolyParams, + trapdoor: &GpuDCRTTrapdoor, + s: f64, + c: f64, + dgg_stddev: f64, + sigma_large: f64, + total_ncol: usize, +) -> GpuDCRTPolyMatrix { + let d = trapdoor.r.row_size(); + let dk = trapdoor.r.col_size(); + let num_blocks = total_ncol.div_ceil(d); + let padded_ncol = num_blocks * d; + let padding_ncol = padded_ncol - total_ncol; + + // p2 is sampled directly on GPU as in the Karney branch of OpenFHE. + let p2 = sample_gpu_matrix_native( + params, + dk, + padded_ncol, + DistType::GaussDist { sigma: sigma_large }, + ); + let tp2 = &trapdoor.re * &p2; + + // Keep perturbation generation on device: this sampler uses the full + // 2d x 2d covariance induced by (A, B, D) and Tp2. 
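+ // a_mat = R*R^T, b_mat = R*E^T and d_mat = E*E^T are the Gram blocks cached
+ // in GpuDCRTTrapdoor::new, and tp2 = [R; E] * p2 is computed above;
+ // sample_p1_full draws p1 so that the stacked [p1; p2] perturbation matches
+ // the spectral parameter s, mirroring the CPU path built on OpenFHE's
+ // SampleP1ForPertMat.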
+ let mut prng = rng(); + let p1_seed: u64 = prng.random(); + let p1 = GpuDCRTPolyMatrix::sample_p1_full( + &trapdoor.a_mat, + &trapdoor.b_mat, + &trapdoor.d_mat, + &tp2, + c, + s, + dgg_stddev, + p1_seed, + ); + + let mut p_hat = p1.concat_rows(&[&p2]); + if padding_ncol > 0 { + p_hat = p_hat.slice_columns(0, total_ncol); + } + p_hat +} + +fn sample_gpu_matrix_native( + params: &GpuDCRTPolyParams, + nrow: usize, + ncol: usize, + dist: DistType, +) -> GpuDCRTPolyMatrix { + if nrow == 0 || ncol == 0 { + return GpuDCRTPolyMatrix::zero(params, nrow, ncol); + } + let mut prng = rng(); + let seed: u64 = prng.random(); + match dist { + DistType::FinRingDist => GpuDCRTPolyMatrix::sample_distribution( + params, + nrow, + ncol, + GpuMatrixSampleDist::Uniform, + 0.0, + seed, + ), + DistType::GaussDist { sigma } => GpuDCRTPolyMatrix::sample_distribution( + params, + nrow, + ncol, + GpuMatrixSampleDist::Gauss, + sigma, + seed, + ), + DistType::BitDist => GpuDCRTPolyMatrix::sample_distribution( + params, + nrow, + ncol, + GpuMatrixSampleDist::Bit, + 0.0, + seed, + ), + DistType::TernaryDist => GpuDCRTPolyMatrix::sample_distribution( + params, + nrow, + ncol, + GpuMatrixSampleDist::Ternary, + 0.0, + seed, + ), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + __PAIR, __TestState, + matrix::PolyMatrix, + poly::{ + PolyParams, + dcrt::{gpu::gpu_device_sync, params::DCRTPolyParams}, + }, + }; + use sequential_test::sequential; + + const SIGMA: f64 = 4.578; + + fn gpu_test_params() -> DCRTPolyParams { + DCRTPolyParams::new(128, 2, 16, 8) + } + + fn gpu_params_from_cpu(params: &DCRTPolyParams) -> GpuDCRTPolyParams { + let (moduli, _, _) = params.to_crt(); + GpuDCRTPolyParams::new(params.ring_dimension(), moduli, params.base_bits()) + } + + #[test] + #[sequential] + fn test_gpu_trapdoor_generation() { + gpu_device_sync(); + let size: usize = 3; + let cpu_params = gpu_test_params(); + let params = gpu_params_from_cpu(&cpu_params); + let trapdoor_sampler = GpuDCRTPolyTrapdoorSampler::new(¶ms, SIGMA); + + let (trapdoor, public_matrix) = trapdoor_sampler.trapdoor(¶ms, size); + + let expected_rows = size; + let expected_cols = (params.modulus_digits() + 2) * size; + assert_eq!(public_matrix.row_size(), expected_rows); + assert_eq!(public_matrix.col_size(), expected_cols); + + let k = params.modulus_digits(); + let identity = GpuDCRTPolyMatrix::identity(¶ms, size * k, None); + let trapdoor_matrix = trapdoor.r.concat_rows(&[&trapdoor.e, &identity]); + let muled = public_matrix * trapdoor_matrix; + let gadget_matrix = GpuDCRTPolyMatrix::gadget_matrix(¶ms, size); + assert_eq!(muled, gadget_matrix); + } + + #[test] + #[sequential] + fn test_gpu_trapdoor_round_trip_bytes() { + gpu_device_sync(); + let size: usize = 3; + let cpu_params = gpu_test_params(); + let params = gpu_params_from_cpu(&cpu_params); + let trapdoor_sampler = GpuDCRTPolyTrapdoorSampler::new(¶ms, SIGMA); + let (trapdoor, _public_matrix) = trapdoor_sampler.trapdoor(¶ms, size); + + let bytes = + ::trapdoor_to_bytes(&trapdoor); + let decoded = ::trapdoor_from_bytes( + ¶ms, &bytes, + ) + .expect("trapdoor bytes should decode"); + let reencoded = + ::trapdoor_to_bytes(&decoded); + assert_eq!( + bytes, reencoded, + "trapdoor compact bytes should be stable across decode/encode" + ); + } + + #[test] + #[sequential] + fn test_gpu_preimage_generation_square() { + gpu_device_sync(); + let size = 3usize; + let cpu_params = gpu_test_params(); + let params = gpu_params_from_cpu(&cpu_params); + let trapdoor_sampler = 
+        let (trapdoor, public_matrix) = trapdoor_sampler.trapdoor(&params, size);
+        let target = sample_gpu_matrix_native(&params, size, size, DistType::FinRingDist);
+
+        let preimage = trapdoor_sampler.preimage(&params, &trapdoor, &public_matrix, &target);
+        let product = &public_matrix * &preimage;
+        assert_eq!(product, target);
+    }
+
+    #[test]
+    #[sequential]
+    fn test_gpu_preimage_generation_square_not_plain_gadget_solution() {
+        gpu_device_sync();
+        let size = 3usize;
+        let cpu_params = gpu_test_params();
+        let params = gpu_params_from_cpu(&cpu_params);
+        let trapdoor_sampler = GpuDCRTPolyTrapdoorSampler::new(&params, SIGMA);
+        let (trapdoor, public_matrix) = trapdoor_sampler.trapdoor(&params, size);
+        let target = sample_gpu_matrix_native(&params, size, size, DistType::FinRingDist);
+
+        // Deterministic gadget preimage baseline:
+        // z_plain = [R*z; E*z; z], where z = decompose(target).
+        let z_plain = target.decompose();
+        let z_plain_former = (&trapdoor.r * &z_plain).concat_rows(&[&(&trapdoor.e * &z_plain)]);
+        let z_plain_full = z_plain_former.concat_rows(&[&z_plain]);
+        assert_eq!(&public_matrix * &z_plain_full, target);
+
+        let sampled = trapdoor_sampler.preimage(&params, &trapdoor, &public_matrix, &target);
+        assert_eq!(&public_matrix * &sampled, target);
+        assert_ne!(
+            sampled, z_plain_full,
+            "preimage sampler should not collapse to the plain deterministic gadget preimage"
+        );
+    }
+}
diff --git a/src/sampler/trapdoor/mod.rs b/src/sampler/trapdoor/mod.rs
index 2291603..fd46e91 100644
--- a/src/sampler/trapdoor/mod.rs
+++ b/src/sampler/trapdoor/mod.rs
@@ -10,6 +10,8 @@ use crate::{
     poly::{PolyParams, dcrt::params::DCRTPolyParams},
     sampler::{DistType, PolyUniformSampler, uniform::DCRTPolyUniformSampler},
 };
+#[cfg(feature = "gpu")]
+pub use gpu::{GpuDCRTPolyTrapdoorSampler, GpuDCRTTrapdoor};
 use openfhe::ffi::{FormatMatrixCoefficient, SampleP1ForPertMat};
 use rayon::iter::ParallelIterator;
 pub use sampler::DCRTPolyTrapdoorSampler;
@@ -20,6 +22,8 @@ use std::{
 use tracing::debug;
 use utils::{gen_dgg_int_vec, gen_int_karney, split_int64_mat_to_elems};
 
+#[cfg(feature = "gpu")]
+pub mod gpu;
 pub mod sampler;
 pub mod utils;
 
diff --git a/src/sampler/trapdoor/sampler.rs b/src/sampler/trapdoor/sampler.rs
index 4ac6555..48b706e 100644
--- a/src/sampler/trapdoor/sampler.rs
+++ b/src/sampler/trapdoor/sampler.rs
@@ -14,11 +14,16 @@ use crate::{
 };
 use openfhe::ffi::DCRTGaussSampGqArbBase;
 use rayon::iter::ParallelIterator;
-use std::{ops::Range, time::Instant};
+use std::{
+    ops::Range,
+    sync::{Mutex, OnceLock},
+    time::Instant,
+};
 use tracing::debug;
 
 const SIGMA: f64 = 4.578;
 const SPECTRAL_CONSTANT: f64 = 1.8;
+static GAUSS_SAMP_GQ_ARB_BASE_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
 
 #[derive(Debug, Clone)]
 pub struct DCRTPolyTrapdoorSampler {
@@ -237,17 +242,22 @@ pub(crate) fn gauss_samp_gq_arb_base(
     let depth = params.crt_depth();
     let k_res_bits = params.crt_bits();
     let k_res_digits = params.modulus_digits() / depth;
-    let result = DCRTGaussSampGqArbBase(
-        syndrome.get_poly(),
-        c,
-        n,
-        depth,
-        k_res_bits,
-        k_res_digits,
-        base as i64,
-        sigma,
-        tower_idx,
-    );
+    // OpenFHE's GaussSampGqArbBase can race across threads depending on backend state.
+    // Keep this FFI call serialized for stability.
+    let result = {
+        let _guard = GAUSS_SAMP_GQ_ARB_BASE_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap();
+        DCRTGaussSampGqArbBase(
+            syndrome.get_poly(),
+            c,
+            n,
+            depth,
+            k_res_bits,
+            k_res_digits,
+            base as i64,
+            sigma,
+            tower_idx,
+        )
+    };
     debug_assert_eq!(result.len(), n as usize * k_res_digits);
     // let mut matrix = I64Matrix::new_empty(&I64MatrixParams, k_res, n as usize);
     parallel_iter!(0..k_res_digits)
diff --git a/src/sampler/trapdoor/utils.rs b/src/sampler/trapdoor/utils.rs
index a5d33f1..fecd6b1 100644
--- a/src/sampler/trapdoor/utils.rs
+++ b/src/sampler/trapdoor/utils.rs
@@ -13,11 +13,18 @@ use crate::{
 use openfhe::ffi::GenerateIntegerKarney;
 use rand::{Rng, distr::Uniform, rng};
 use rayon::prelude::*;
-use std::ops::Range;
+use std::{
+    ops::Range,
+    sync::{Mutex, OnceLock},
+};
+
+static KARNEY_SAMPLER_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
 
 pub(crate) fn gen_int_karney(mean: f64, stddev: f64) -> i64 {
-    let out = GenerateIntegerKarney(mean, stddev);
-    out
+    // OpenFHE's Karney sampler can touch shared global state in some builds.
+    // Serialize calls to avoid rare races while preserving the exact sampler primitive.
+    let _guard = KARNEY_SAMPLER_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap();
+    GenerateIntegerKarney(mean, stddev)
 }
 
 fn find_in_vec(vec: &[f64], search: f64) -> u32 {
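For reference, the serialization pattern shared by gauss_samp_gq_arb_base and gen_int_karney above can be exercised in isolation. The sketch below stands alone and uses a hypothetical ffi_call in place of the real OpenFHE bindings, so only the OnceLock<Mutex<()>> guard itself reflects the patch; the names FFI_LOCK, ffi_call, and serialized_ffi_call are illustrative.

    // Minimal sketch of the OnceLock<Mutex<()>>-guarded FFI call used in the hunks above.
    // `ffi_call` is a placeholder for a non-thread-safe foreign function (e.g. a sampler
    // backed by hidden global RNG state), not an OpenFHE API.
    use std::sync::{Mutex, OnceLock};

    static FFI_LOCK: OnceLock<Mutex<()>> = OnceLock::new();

    fn ffi_call(mean: f64, stddev: f64) -> i64 {
        // Stand-in body so the example runs on its own.
        (mean + stddev) as i64
    }

    fn serialized_ffi_call(mean: f64, stddev: f64) -> i64 {
        // The first caller initializes the mutex; every caller then holds it for the
        // duration of the call, so rayon workers never enter the FFI concurrently.
        let _guard = FFI_LOCK.get_or_init(|| Mutex::new(())).lock().unwrap();
        ffi_call(mean, stddev)
    }

    fn main() {
        println!("{}", serialized_ffi_call(0.0, 4.578));
    }

Using a plain std::sync::Mutex inside a OnceLock keeps the change dependency-free and limits the critical section to the single FFI call, leaving the surrounding parallel iteration untouched.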