Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 66 additions & 18 deletions crates/gpu-prover/src/cuda_bindings/async_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ use bellman::PrimeField;
use core::ops::Range;
use std::io::{Read, Write};

pub struct AsyncVec<T, #[cfg(feature = "allocator")] A: Allocator = CudaAllocator> {
pub struct AsyncVec<T: Copy, #[cfg(feature = "allocator")] A: Allocator = CudaAllocator> {
#[cfg(feature = "allocator")]
pub values: Option<Vec<T, A>>,
values: Option<Vec<T, A>>,
#[cfg(not(feature = "allocator"))]
pub values: Option<Vec<T>>,
values: Option<Vec<T>>,
pub(crate) data_is_set: bool,
pub(crate) read_event: Event,
pub(crate) write_event: Event,
}
Expand All @@ -17,10 +18,10 @@ use std::fmt;
macro_rules! impl_async_vec {
(impl AsyncVec $inherent:tt) => {
#[cfg(feature = "allocator")]
impl<T, A: Allocator + Default> AsyncVec<T, A> $inherent
impl<T: Copy, A: Allocator + Default> AsyncVec<T, A> $inherent

#[cfg(not(feature = "allocator"))]
impl<T> AsyncVec<T> $inherent
impl<T: Copy> AsyncVec<T> $inherent
};
}

Expand All @@ -37,17 +38,20 @@ impl_async_vec! {

Self {
values: Some(values),
data_is_set: false,
read_event: Event::new(),
write_event: Event::new(),
}
}

pub fn get_values(&self) -> GpuResult<&[T]> {
assert!(self.data_is_set, "AsyncVec should be filled with some data");
self.write_event.sync()?;
Ok(self.values.as_ref().expect("async_vec inner is none"))
}

pub fn get_values_mut(&mut self) -> GpuResult<&mut [T]> {
assert!(self.data_is_set, "AsyncVec should be filled with some data");
self.read_event.sync()?;
self.write_event.sync()?;
Ok(self.values.as_mut().expect("async_vec inner is none"))
Expand All @@ -60,6 +64,7 @@ impl_async_vec! {
this_range: Range<usize>,
other_range: Range<usize>,
) -> GpuResult<()> {
assert!(self.data_is_set, "AsyncVec should be filled with some data");
assert_eq!(this_range.len(), other_range.len());
let length = std::mem::size_of::<T>() * this_range.len();
set_device(ctx.device_id())?;
Expand All @@ -76,6 +81,7 @@ impl_async_vec! {
ctx.h2d_stream().inner,
)
};
other.data_is_set = true;

if result != 0 {
return Err(GpuError::AsyncH2DErr(result));
Expand All @@ -94,6 +100,7 @@ impl_async_vec! {
this_range: Range<usize>,
other_range: Range<usize>,
) -> GpuResult<()> {
assert!(other.data_is_set, "DeviceBuf should be filled with some data");
assert_eq!(this_range.len(), other_range.len());
let length = std::mem::size_of::<T>() * this_range.len();
set_device(ctx.device_id())?;
Expand All @@ -117,6 +124,7 @@ impl_async_vec! {

self.write_event.record(ctx.d2h_stream())?;
other.read_event.record(ctx.d2h_stream())?;
self.data_is_set = true;

Ok(())
}
Expand All @@ -126,6 +134,7 @@ impl_async_vec! {
}
#[cfg(feature = "allocator")]
pub fn into_inner(mut self) -> GpuResult<std::vec::Vec<T, A>> {
assert!(self.data_is_set, "AsyncVec should be filled with some data");
self.read_event.sync()?;
self.write_event.sync()?;

Expand All @@ -134,6 +143,7 @@ impl_async_vec! {

#[cfg(not(feature = "allocator"))]
pub fn into_inner(mut self) -> GpuResult<std::vec::Vec<T>> {
assert!(self.data_is_set, "AsyncVec should be filled with some data");
self.read_event.sync()?;
self.write_event.sync()?;

Expand All @@ -156,77 +166,112 @@ impl_async_vec! {
self.values.as_mut().expect("async_vec inner is none")[range].as_mut_ptr()
}

pub fn zeroize(&mut self){
let unit_len = std::mem::size_of::<T>();
let total_len = unit_len * self.len();
let dst = self.as_mut_ptr(0..self.len()) as *mut u8;
unsafe{std::ptr::write_bytes(dst, 0, total_len)};
pub fn fill(&mut self, value: T) -> GpuResult<()> {
self.read_event.sync()?;
self.write_event.sync()?;

self.values.as_mut().unwrap().fill(value);
self.data_is_set = true;

Ok(())
}

pub fn copy_from_slice(&mut self, src: &[T]) -> GpuResult<()> {
self.read_event.sync()?;
self.write_event.sync()?;

// copy_from_slice checks the equality of lengths
self.values.as_mut().unwrap().copy_from_slice(src);
self.data_is_set = true;

Ok(())
}

pub fn async_copy_from_slice(&mut self, worker: &Worker, src: &[T]) -> GpuResult<()>
where T: Send + Sync
{
self.read_event.sync()?;
self.write_event.sync()?;

// async_copy checks the equality of lengths
async_copy(
worker,
self.values.as_mut().unwrap(),
src,
);
self.data_is_set = true;

Ok(())
}
}
}

#[cfg(feature = "allocator")]
impl<T: fmt::Debug, A: Allocator + Default> fmt::Debug for AsyncVec<T, A> {
impl<T: fmt::Debug + Copy, A: Allocator + Default> fmt::Debug for AsyncVec<T, A> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
assert!(self.data_is_set, "AsyncVec should be filled with some data");
f.debug_struct("AsyncVec")
.field("Values", &self.get_values().unwrap())
.finish()
}
}
#[cfg(not(feature = "allocator"))]
impl<T: fmt::Debug> fmt::Debug for AsyncVec<T> {
impl<T: fmt::Debug + Copy> fmt::Debug for AsyncVec<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
assert!(self.data_is_set, "AsyncVec should be filled with some data");
f.debug_struct("AsyncVec")
.field("Values", &self.get_values().unwrap())
.finish()
}
}

#[cfg(feature = "allocator")]
impl<T, A: Allocator> From<Vec<T, A>> for AsyncVec<T, A> {
impl<T: Copy, A: Allocator> From<Vec<T, A>> for AsyncVec<T, A> {
fn from(values: Vec<T, A>) -> Self {
Self {
values: Some(values),
data_is_set: true,
read_event: Event::new(),
write_event: Event::new(),
}
}
}
#[cfg(not(feature = "allocator"))]
impl<T> From<Vec<T>> for AsyncVec<T> {
impl<T: Copy> From<Vec<T>> for AsyncVec<T> {
fn from(values: Vec<T>) -> Self {
Self {
values: Some(values),
data_is_set: true,
read_event: Event::new(),
write_event: Event::new(),
}
}
}

#[cfg(feature = "allocator")]
impl<T, A: Allocator + Default> From<AsyncVec<T, A>> for Vec<T, A> {
impl<T: Copy, A: Allocator + Default> From<AsyncVec<T, A>> for Vec<T, A> {
fn from(vector: AsyncVec<T, A>) -> Self {
vector.into_inner().unwrap()
}
}

#[cfg(not(feature = "allocator"))]
impl<T> From<AsyncVec<T>> for Vec<T> {
impl<T: Copy> From<AsyncVec<T>> for Vec<T> {
fn from(vector: AsyncVec<T>) -> Self {
vector.into_inner().unwrap()
}
}

#[cfg(feature = "allocator")]
impl<T, A: Allocator> Drop for AsyncVec<T, A> {
impl<T: Copy, A: Allocator> Drop for AsyncVec<T, A> {
fn drop(&mut self) {
self.read_event.sync().unwrap();
self.write_event.sync().unwrap();
}
}

#[cfg(not(feature = "allocator"))]
impl<T> Drop for AsyncVec<T> {
impl<T: Copy> Drop for AsyncVec<T> {
fn drop(&mut self) {
self.read_event.sync().unwrap();
self.write_event.sync().unwrap();
Expand Down Expand Up @@ -263,6 +308,7 @@ impl_async_vec_for_field! {
}

pub fn to_bytes(&self, dst: &mut [u8]) -> GpuResult<()> {
assert!(self.data_is_set, "AsyncVec should be filled with some data");
let length = self.len();
let F_SIZE = F::zero().into_raw_repr().as_ref().len() * 8;
assert_eq!(length * F_SIZE, dst.len(), "Wrong destination length");
Expand Down Expand Up @@ -306,6 +352,8 @@ impl_async_vec_for_field! {
)
};

self.data_is_set = true;

Ok(())
}
}
Expand Down
6 changes: 6 additions & 0 deletions crates/gpu-prover/src/cuda_bindings/device_arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ impl DeviceBuf<Fr> {
);
let constant = constant.expect("constant should be Some in SetValue operation");

self.data_is_set = true;

ff_set_value(
self.as_mut_ptr(range) as *mut c_void,
&constant as *const Fr as *const c_void,
Expand All @@ -245,8 +247,10 @@ impl DeviceBuf<Fr> {
return Err(GpuError::ArithmeticErr(result));
}

assert!(self.data_is_set, "DeviceBuf should be filled with some data");
self.write_event.record(&ctx.exec_stream)?;
if let Some(other) = other {
assert!(other.data_is_set, "DeviceBuf should be filled with some data");
other.read_event.record(&ctx.exec_stream)?;
}

Expand All @@ -263,6 +267,8 @@ impl DeviceBuf<Fr> {
shift: usize,
inverse: bool,
) -> GpuResult<()> {
assert!(self.data_is_set, "DeviceBuf should be filled with some data");

assert!(
ctx.ff,
"ff is not set up on GpuContext with id {}",
Expand Down
14 changes: 11 additions & 3 deletions crates/gpu-prover/src/cuda_bindings/device_buf.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
use super::*;
use core::ops::Range;

pub struct DeviceBuf<T> {
pub struct DeviceBuf<T: Copy> {
pub(crate) ptr: *mut T,
pub(crate) len: usize,
pub(crate) device_id: usize,

pub(crate) data_is_set: bool,
pub(crate) is_static_mem: bool,
pub(crate) is_freed: bool,

pub(crate) read_event: Event,
pub(crate) write_event: Event,
}

impl<T> DeviceBuf<T> {
impl<T: Copy> DeviceBuf<T> {
pub fn alloc_static(ctx: &GpuContext, len: usize) -> GpuResult<Self> {
set_device(ctx.device_id())?;
assert!(ctx.mem_pool.is_none(), "mem pool is allocated");
Expand All @@ -32,6 +33,7 @@ impl<T> DeviceBuf<T> {
len: len,
device_id: ctx.device_id(),

data_is_set: false,
is_static_mem: true,
is_freed: false,

Expand Down Expand Up @@ -64,6 +66,7 @@ impl<T> DeviceBuf<T> {
len: len,
device_id: ctx.device_id(),

data_is_set: false,
is_static_mem: false,
is_freed: false,

Expand Down Expand Up @@ -96,6 +99,7 @@ impl<T> DeviceBuf<T> {
len: len,
device_id: ctx.device_id(),

data_is_set: false,
is_static_mem: false,
is_freed: false,

Expand Down Expand Up @@ -127,6 +131,7 @@ impl<T> DeviceBuf<T> {
len: chunk_len,
device_id: self.device_id,

data_is_set: self.data_is_set,
is_static_mem: self.is_static_mem,
is_freed: true,

Expand Down Expand Up @@ -174,6 +179,7 @@ impl<T> DeviceBuf<T> {
ctx.exec_stream.wait(other.read_event())?;
ctx.exec_stream.wait(other.write_event())?;

assert!(self.data_is_set, "DeviceBuf should be filled with some data");
let result = unsafe {
bc_memcpy_async(
other.as_mut_ptr(other_range) as *mut c_void,
Expand All @@ -182,6 +188,7 @@ impl<T> DeviceBuf<T> {
ctx.exec_stream().inner,
)
};
other.data_is_set = true;

if result != 0 {
return Err(GpuError::AsyncH2DErr(result));
Expand Down Expand Up @@ -214,6 +221,7 @@ impl<T> DeviceBuf<T> {
length as u64,
ctx.h2d_stream().inner,
);
self.data_is_set = true;

if result != 0 {
return Err(GpuError::AsyncMemcopyErr(result));
Expand Down Expand Up @@ -287,7 +295,7 @@ impl<T> DeviceBuf<T> {
}
}

impl<T> Drop for DeviceBuf<T> {
impl<T: Copy> Drop for DeviceBuf<T> {
fn drop(&mut self) {
if !self.is_freed {
self.read_event.sync().unwrap();
Expand Down
Loading