diff --git a/Cargo.toml b/Cargo.toml
index c8d815042f..1b7da7c402 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -43,7 +43,7 @@ candle-onnx = { path = "./candle-onnx", version = "0.9.1" }
 candle-transformers = { path = "./candle-transformers", version = "0.9.1" }
 clap = { version = "4.2.4", features = ["derive"] }
 criterion = { version = "0.5.1", default-features = false }
-cudarc = { version = "0.16.3", features = [
+cudarc = { version = "0.17.3", features = [
   "std",
   "cublas",
   "cublaslt",
@@ -62,7 +62,7 @@ half = { version = "2.5.0", features = [
   "use-intrinsics",
   "rand_distr",
 ] }
-float8 = { git = "https://github.com/zackangelo/float8", branch = "cudarc_0_16", features = [
+float8 = { version = "0.4.2", features = [
   "num-traits",
   "rand_distr",
 ] }
diff --git a/candle-core/src/cuda_backend/graph.rs b/candle-core/src/cuda_backend/graph.rs
new file mode 100644
index 0000000000..faf4063bdf
--- /dev/null
+++ b/candle-core/src/cuda_backend/graph.rs
@@ -0,0 +1,191 @@
+use std::fmt;
+use cudarc::driver::{CudaGraph, sys};
+use crate::{CudaDevice, Result, Shape, Error};
+use crate::cuda_backend::CudaError;
+
+pub struct CudaGraphHandle {
+    graph: Option<CudaGraph>,
+    device: CudaDevice,
+    captured_shape: Shape,
+}
+
+impl fmt::Debug for CudaGraphHandle {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::result::Result<(), fmt::Error> {
+        #[derive(Debug)]
+        #[allow(dead_code)]
+        struct CudaGraphHandle<'a> {
+            device: &'a CudaDevice,
+            captured_shape: &'a Shape,
+        }
+
+        let Self {
+            graph: _,
+            device,
+            captured_shape,
+        } = self;
+
+        fmt::Debug::fmt(&CudaGraphHandle { device, captured_shape }, f)
+    }
+}
+
+impl CudaGraphHandle {
+    pub fn capture<F>(
+        device: &CudaDevice,
+        shape: &Shape,
+        capture_fn: F,
+    ) -> Result<Self>
+    where
+        F: FnOnce() -> Result<()>,
+    {
+        let stream = device.cuda_stream();
+
+        // Begin capture
+        let flags = sys::CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_RELAXED;
+        stream.begin_capture(flags).map_err(CudaError::from)?;
+
+        // Execute the capture function
+        capture_fn()?;
+
+        // End capture and create the graph
+        let flags =
+            sys::CUgraphInstantiate_flags_enum::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH;
+        let graph = stream.end_capture(flags).map_err(CudaError::from)?;
+
+        Ok(CudaGraphHandle {
+            graph,
+            device: device.clone(),
+            captured_shape: shape.clone(),
+        })
+    }
+
+    pub fn launch(&mut self) -> Result<()> {
+        match &mut self.graph {
+            Some(g) => g.launch().map_err(CudaError::from),
+            None => Err(CudaError::InternalError(
+                "failed to create graph during capture",
+            )),
+        }
+        .map_err(Error::from)
+    }
+
+    pub fn shape(&self) -> &Shape {
+        &self.captured_shape
+    }
+
+    pub fn device(&self) -> &CudaDevice {
+        &self.device
+    }
+}
+
+impl Drop for CudaGraphHandle {
+    fn drop(&mut self) {
+        if self.graph.is_some() {
+            std::mem::drop(self.graph.take());
+        }
+    }
+}
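+
+// Illustrative usage sketch (comment only, not compiled): capture a fixed-shape
+// workload once, then replay it repeatedly. `run_decode_step` is a hypothetical
+// stand-in for whatever kernel sequence the caller wants to record; shapes and
+// device buffers must stay identical between capture and replay.
+//
+//     let mut graph = CudaGraphHandle::capture(&device, &shape, || {
+//         run_decode_step()?;
+//         Ok(())
+//     })?;
+//     for _ in 0..steps {
+//         graph.launch()?;
+//     }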
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{Device, DType, Tensor};
+    use crate::backend::BackendDevice;
+
+    // Helper function to create a CUDA device for testing
+    fn create_cuda_device() -> Result<CudaDevice> {
+        CudaDevice::new_with_stream(0)
+    }
+
+    #[test]
+    fn test_cuda_graph_capture_basic() -> Result<()> {
+        let device = create_cuda_device()?;
+        let shape = Shape::from_dims(&[2, 3]);
+
+        // Create a simple capture function that performs a basic operation
+        let graph = CudaGraphHandle::capture(&device, &shape, || {
+            // In a real scenario, this would contain actual CUDA operations.
+            // For now, we just test that the capture mechanism works.
+            Ok(())
+        })?;
+
+        // Verify the shape was captured correctly
+        assert_eq!(graph.shape(), &shape);
+        Ok(())
+    }
+
+    #[test]
+    fn test_cuda_graph_launch() -> Result<()> {
+        let device = create_cuda_device()?;
+        let shape = Shape::from_dims(&[4, 4]);
+
+        // Capture a graph
+        let mut graph = CudaGraphHandle::capture(&device, &shape, || {
+            // Simulate some work during capture
+            Ok(())
+        })?;
+
+        // Test that we can launch the graph without errors
+        graph.launch()?;
+        Ok(())
+    }
+
+    #[test]
+    fn test_cuda_graph_multiple_launches() -> Result<()> {
+        let device = create_cuda_device()?;
+        let shape = Shape::from_dims(&[8, 8]);
+
+        let mut graph = CudaGraphHandle::capture(&device, &shape, || Ok(()))?;
+
+        // Test multiple launches of the same graph
+        for _ in 0..5 {
+            graph.launch()?;
+        }
+        Ok(())
+    }
+
+    #[test]
+    fn test_cuda_graph_error_handling() -> Result<()> {
+        let device = create_cuda_device()?;
+        let shape = Shape::from_dims(&[5, 5]);
+
+        // Test that errors in the capture function are properly propagated
+        let result = CudaGraphHandle::capture(&device, &shape, || {
+            Err(crate::Error::Msg("Test error during capture".into()))
+        });
+
+        assert!(result.is_err());
+        Ok(())
+    }
+
+    #[test]
+    fn test_cuda_graph_with_multiple_operations() -> Result<()> {
+        let device = create_cuda_device()?;
+        let cuda_device = Device::Cuda(device.clone());
+        let shape = Shape::from_dims(&[16, 16]);
+
+        // Create multiple tensors for complex operations
+        let a = Tensor::ones((16, 16), DType::F32, &cuda_device)?;
+        let b = Tensor::full(2.0f32, (16, 16), &cuda_device)?;
+
+        // Capture a graph with multiple chained operations
+        let mut graph = CudaGraphHandle::capture(&device, &shape, || {
+            // Chain multiple operations together
+            let step1 = a.add(&b)?; // 1 + 2 = 3
+            let step2 = step1.sum_keepdim(1)?; // Sum along axis 1
+            let _final_result = step2.relu()?; // Apply ReLU
+            Ok(())
+        })?;
+
+        // Verify graph creation and launch it
+        assert_eq!(graph.shape(), &shape);
+
+        // Test multiple launches to ensure stability
+        for _ in 0..5 {
+            graph.launch()?;
+            device.synchronize()?;
+        }
+
+        Ok(())
+    }
+}
\ No newline at end of file
diff --git a/candle-core/src/cuda_backend/graph_module.rs b/candle-core/src/cuda_backend/graph_module.rs
new file mode 100644
index 0000000000..b9d4096fef
--- /dev/null
+++ b/candle-core/src/cuda_backend/graph_module.rs
@@ -0,0 +1,238 @@
+use std::cell::RefCell;
+use std::collections::HashMap;
+
+use crate::cuda_backend::graph::CudaGraphHandle;
+use crate::{CudaDevice, Module, Result, Shape, Tensor};
+
+pub trait CudaGraphModule: Module {
+    /// Replay the captured graph for the given input
+    fn replay(&self, xs: &Tensor) -> Result<Tensor>;
+
+    /// Check if a graph is captured for the given shape
+    fn has_graph_for_shape(&self, shape: &Shape) -> bool;
+
+    /// Clear captured graphs (useful for memory management)
+    fn clear_graphs(&mut self);
+}
+
+pub struct CudaGraphWrapper<M: Module> {
+    module: M,
+    captured_graphs: RefCell<HashMap<Shape, CudaGraphHandle>>,
+    warmup_iterations: usize,
+}
+
+impl<M: Module> CudaGraphWrapper<M> {
+    pub fn new(module: M) -> Self {
+        Self {
+            module,
+            captured_graphs: RefCell::new(HashMap::new()),
+            warmup_iterations: 3, // Default warmup iterations
+        }
+    }
+
+    /// Automatically capture a graph for the given input if not already captured
+    pub fn capture_if_needed(&mut self, input: &Tensor) -> Result<()> {
+        let shape = input.shape().clone();
+
+        // Check if already captured
+        {
+            let graphs = self.captured_graphs.borrow();
+            if graphs.contains_key(&shape) {
+                return Ok(());
+            }
+        }
+
+        // Perform warmup iterations
+        for _ in 0..self.warmup_iterations {
+            let _ = self.module.forward(input)?;
+        }
+
+        // Capture the graph
+        let device = CudaDevice::new_with_stream(0)?;
+        let graph = CudaGraphHandle::capture(&device, &shape, || {
+            self.module.forward(input)?;
+            Ok(())
+        })?;
+        self.captured_graphs.borrow_mut().insert(shape, graph);
+
+        Ok(())
+    }
+}
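+
+// Illustrative usage sketch (comment only, not compiled): `model` is any value
+// implementing `Module` and `input` is a hypothetical tensor whose shape stays
+// fixed across calls.
+//
+//     let mut wrapped = CudaGraphWrapper::new(model);
+//     wrapped.capture_if_needed(&input)?; // warm up, then record a graph for this shape
+//     let out = wrapped.forward(&input)?; // uses the graph path when the shape matches,
+//                                         // otherwise falls back to the plain forward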
+
+impl<M: Module> Module for CudaGraphWrapper<M> {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let shape = xs.shape();
+
+        // Try to use a captured graph first
+        if let Some(graph) = self.captured_graphs.borrow_mut().get_mut(shape) {
+            return self.replay_with_graph(xs, graph);
+        }
+
+        // Fallback
+        self.module.forward(xs)
+    }
+}
+
+impl<M: Module> CudaGraphWrapper<M> {
+    fn replay_with_graph(&self, xs: &Tensor, graph: &mut CudaGraphHandle) -> Result<Tensor> {
+        graph.launch()?;
+
+        // For now, fall back to regular execution.
+        // In a full implementation, you'd need to track input/output tensor locations
+        // and update them appropriately.
+        self.module.forward(xs)
+    }
+}
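+
+// Sketch of the intended replay path (assumed design, not what the code above does
+// yet): with pre-allocated input/output buffers whose device pointers match the
+// ones recorded during capture, replay could skip re-running the module entirely.
+//
+//     copy_into(&static_input, xs)?;   // hypothetical helper writing xs into the captured input buffer
+//     graph.launch()?;                 // kernels reuse the pointers recorded at capture time
+//     Ok(static_output.clone())        // hypothetical pre-allocated output tensor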
+
+impl<M: Module> CudaGraphModule for CudaGraphWrapper<M> {
+    fn replay(&self, xs: &Tensor) -> Result<Tensor> {
+        let shape = xs.shape();
+        let mut graphs = self.captured_graphs.borrow_mut();
+
+        if let Some(graph) = graphs.get_mut(shape) {
+            self.replay_with_graph(xs, graph)
+        } else {
+            Err(crate::Error::Msg("No captured graph for this shape".into()).bt())
+        }
+    }
+
+    fn has_graph_for_shape(&self, shape: &Shape) -> bool {
+        let graphs = self.captured_graphs.borrow();
+        graphs.contains_key(shape)
+    }
+
+    fn clear_graphs(&mut self) {
+        let mut graphs = self.captured_graphs.borrow_mut();
+        graphs.clear();
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{Device, DType, Tensor};
+
+    /// Simple Transformer block implementation for testing
+    pub struct MockTransformerBlock {
+        device: Device,
+        hidden_size: usize,
+    }
+
+    impl MockTransformerBlock {
+        pub fn new(device: Device, hidden_size: usize) -> Result<Self> {
+            Ok(Self {
+                device,
+                hidden_size,
+            })
+        }
+
+        fn layer_norm(&self, x: &Tensor) -> Result<Tensor> {
+            // Simplified layer normalization: just normalize along the last dimension
+            let mean = x.mean_keepdim(x.dims().len() - 1)?;
+            let variance = x.var_keepdim(x.dims().len() - 1)?;
+            let normalized = x.broadcast_sub(&mean)?.broadcast_div(&(variance + 1e-5)?)?;
+            Ok(normalized)
+        }
+
+        fn multi_head_attention(&self, x: &Tensor) -> Result<Tensor> {
+            // Simplified self-attention: just a linear transformation.
+            // In practice, this would involve Q, K, V projections and attention computation.
+            let batch_size = x.dim(0)?;
+            let seq_len = x.dim(1)?;
+
+            // Create a simple "attention" weight matrix
+            let attention_weights =
+                Tensor::ones((self.hidden_size, self.hidden_size), DType::F32, &self.device)?;
+
+            // Apply attention (simplified as a matrix multiplication)
+            let reshaped = x.reshape((batch_size * seq_len, self.hidden_size))?;
+            let attended = reshaped.matmul(&attention_weights)?;
+            let output = attended.reshape((batch_size, seq_len, self.hidden_size))?;
+
+            Ok(output)
+        }
+
+        fn mlp(&self, x: &Tensor) -> Result<Tensor> {
+            // Simplified MLP: Linear -> ReLU -> Linear
+            let batch_size = x.dim(0)?;
+            let seq_len = x.dim(1)?;
+            let intermediate_size = self.hidden_size * 4;
+
+            // First linear layer (expand)
+            let w1 = Tensor::randn(0f32, 0.1, (self.hidden_size, intermediate_size), &self.device)?;
+            let reshaped = x.reshape((batch_size * seq_len, self.hidden_size))?;
+            let hidden = reshaped.matmul(&w1)?.relu()?;
+
+            // Second linear layer (project back)
+            let w2 = Tensor::randn(0f32, 0.1, (intermediate_size, self.hidden_size), &self.device)?;
+            let output = hidden.matmul(&w2)?;
+            let final_output = output.reshape((batch_size, seq_len, self.hidden_size))?;
+
+            Ok(final_output)
+        }
+    }
+
+    impl Module for MockTransformerBlock {
+        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+            // Standard Transformer block:
+            // LayerNorm -> Attention -> Residual -> LayerNorm -> MLP -> Residual
+
+            // Pre-attention layer norm
+            let normed1 = self.layer_norm(xs)?;
+
+            // Multi-head attention with residual connection
+            let attended = self.multi_head_attention(&normed1)?;
+            let residual1 = xs.add(&attended)?;
+
+            // Pre-MLP layer norm
+            let normed2 = self.layer_norm(&residual1)?;
+
+            // MLP with residual connection
+            let mlp_output = self.mlp(&normed2)?;
+            let final_output = residual1.add(&mlp_output)?;
+
+            Ok(final_output)
+        }
+    }
+
+    #[test]
+    fn test_multiple_shapes_capture() -> Result<()> {
+        let device = Device::Cuda(CudaDevice::new_with_stream(0)?);
+        let hidden_size = 128;
+
+        let model = MockTransformerBlock::new(device.clone(), hidden_size)?;
+        let mut model_with_graph = CudaGraphWrapper::new(model);
+
+        // Test with different input shapes
+        let shapes = vec![
+            (2, 32, hidden_size),  // Small batch, short sequence
+            (4, 64, hidden_size),  // Medium batch, medium sequence
+            (1, 128, hidden_size), // Single batch, long sequence
+        ];
+
+        for (batch_size, seq_len, hidden_size) in shapes {
+            let input = Tensor::randn(0f32, 1.0, (batch_size, seq_len, hidden_size), &device)?;
+            let shape = input.shape().clone();
+
+            // Capture a graph for this shape
+            model_with_graph.capture_if_needed(&input)?;
+
+            // Verify the capture
+            assert!(model_with_graph.has_graph_for_shape(&shape));
+
+            // Test replay
+            let _output = model_with_graph.replay(&input)?;
+        }
+
+        // Test clearing graphs
+        model_with_graph.clear_graphs();
+
+        // Verify all graphs are cleared
+        for (batch_size, seq_len, hidden_size) in [(2, 32, 128), (4, 64, 128), (1, 128, 128)] {
+            let shape = Shape::from_dims(&[batch_size, seq_len, hidden_size]);
+            assert!(!model_with_graph.has_graph_for_shape(&shape));
+        }
+
+        Ok(())
+    }
+}
\ No newline at end of file
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs
index b1f166a6ac..450d620151 100644
--- a/candle-core/src/cuda_backend/mod.rs
+++ b/candle-core/src/cuda_backend/mod.rs
@@ -17,9 +17,13 @@ pub mod cudnn;
 mod device;
 mod error;
 mod utils;
+mod graph;
+mod graph_module;
+
 pub use device::{CudaDevice, DeviceId};
 pub use error::{CudaError, WrapErr};
 pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, Map3, S};
+pub use graph::CudaGraphHandle;
 
 pub enum SlicePtrOrNull<T> {
     Ptr(CudaSlice<T>),
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index e6fcc05a73..30a2428c3b 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -3,6 +3,7 @@
 use crate::{Error, Result};
 
 #[derive(Clone, PartialEq, Eq)]
+#[derive(Hash)]
 pub struct Shape(Vec<usize>);
 
 pub const SCALAR: Shape = Shape(vec![]);