4 changes: 2 additions & 2 deletions Cargo.toml
@@ -43,7 +43,7 @@ candle-onnx = { path = "./candle-onnx", version = "0.9.1" }
candle-transformers = { path = "./candle-transformers", version = "0.9.1" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features = false }
cudarc = { version = "0.16.3", features = [
cudarc = { version = "0.17.3", features = [
"std",
"cublas",
"cublaslt",
@@ -62,7 +62,7 @@ half = { version = "2.5.0", features = [
"use-intrinsics",
"rand_distr",
] }
-float8 = { git = "https://github.com/zackangelo/float8", branch = "cudarc_0_16", features = [
+float8 = { version = "0.4.2", features = [
"num-traits",
"rand_distr",
] }
191 changes: 191 additions & 0 deletions candle-core/src/cuda_backend/graph.rs
@@ -0,0 +1,191 @@
use std::fmt;
use cudarc::driver::{CudaGraph, sys};
use crate::{CudaDevice, Result, Shape, Error};
use crate::cuda_backend::CudaError;

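/// Owns a CUDA graph captured from work enqueued on a device's stream,
/// together with the device and the shape supplied at capture time.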
pub struct CudaGraphHandle {
graph: Option<CudaGraph>,
device: CudaDevice,
captured_shape: Shape,
}

impl fmt::Debug for CudaGraphHandle {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
#[derive(Debug)]
#[allow(dead_code)]
struct CudaGraphHandle<'a> {
device: &'a CudaDevice,
captured_shape: &'a Shape,
}

let Self {
graph: _,
device,
captured_shape,
} = self;

fmt::Debug::fmt(&CudaGraphHandle { device, captured_shape }, f)
}
}

impl CudaGraphHandle {
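/// Captures the CUDA work enqueued by `capture_fn` on the device's stream
/// and instantiates it as a graph that can be re-launched via `launch`.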
pub fn capture<F>(
device: &CudaDevice,
shape: &Shape,
capture_fn: F
) -> Result<Self>
where
F: FnOnce() -> Result<()>
{
let stream = device.cuda_stream();

// Begin stream capture in relaxed mode.
let flags = sys::CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_RELAXED;
stream.begin_capture(flags).map_err(CudaError::from)?;

// Run the capture function. Do not return early on failure: the stream
// must leave capture mode before any error is propagated.
let capture_result = capture_fn();

// End capture and instantiate the graph.
let flags = sys::CUgraphInstantiate_flags_enum::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH;
let end_result = stream.end_capture(flags).map_err(CudaError::from);

// Surface the capture function's error first, then any capture/instantiation error.
capture_result?;
let graph = end_result?;

Ok(CudaGraphHandle {
graph,
device: device.clone(),
captured_shape: shape.clone(),
})
}

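/// Launches the captured graph; returns an error if capture did not produce a graph.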
pub fn launch(&mut self) -> Result<()> {
match &mut self.graph {
Some(graph) => graph.launch().map_err(CudaError::from),
None => Err(CudaError::InternalError("failed to create graph during capture")),
}
.map_err(Error::from)
}

pub fn shape(&self) -> &Shape {
&self.captured_shape
}

pub fn device(&self) -> &CudaDevice { &self.device }
}

impl Drop for CudaGraphHandle {
fn drop(&mut self) {
// Dropping the Option releases the underlying CUDA graph, if any; taking it
// explicitly just makes the teardown order clear.
drop(self.graph.take());
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{Device, DType, Tensor};
use crate::backend::BackendDevice;

// Helper function to create a CUDA device for testing
fn create_cuda_device() -> Result<CudaDevice> {
CudaDevice::new_with_stream(0)
}

#[test]
fn test_cuda_graph_capture_basic() -> Result<()> {
let device = create_cuda_device()?;
let shape = Shape::from_dims(&[2, 3]);

// Create a simple capture function that performs a basic operation
let graph = CudaGraphHandle::capture(&device, &shape, || {
// In a real scenario, this would contain actual CUDA operations
// For now, we just test that the capture mechanism works
Ok(())
})?;

// Verify the shape was captured correctly
assert_eq!(graph.shape(), &shape);
Ok(())
}

#[test]
fn test_cuda_graph_launch() -> Result<()> {
let device = create_cuda_device()?;
let shape = Shape::from_dims(&[4, 4]);

// Capture a graph
let mut graph = CudaGraphHandle::capture(&device, &shape, || {
// Simulate some work during capture
Ok(())
})?;

// Test that we can launch the graph without errors
graph.launch()?;
Ok(())
}

#[test]
fn test_cuda_graph_multiple_launches() -> Result<()> {
let device = create_cuda_device()?;
let shape = Shape::from_dims(&[8, 8]);

let mut graph = CudaGraphHandle::capture(&device, &shape, || {
Ok(())
})?;

// Test multiple launches of the same graph
for _ in 0..5 {
graph.launch()?;
}
Ok(())
}

#[test]
fn test_cuda_graph_error_handling() -> Result<()> {
let device = create_cuda_device()?;
let shape = Shape::from_dims(&[5, 5]);

// Test that errors in capture function are properly propagated
let result = CudaGraphHandle::capture(&device, &shape, || {
Err(crate::Error::Msg("Test error during capture".into()))
});

assert!(result.is_err());
Ok(())
}

#[test]
fn test_cuda_graph_with_multiple_operations() -> Result<()> {
let device = create_cuda_device()?;
let cuda_device = Device::Cuda(device.clone());
let shape = Shape::from_dims(&[16, 16]);

// Create multiple tensors for complex operations
let a = Tensor::ones((16, 16), DType::F32, &cuda_device)?;
let b = Tensor::full(2.0f32, (16, 16), &cuda_device)?;

// Capture a graph with multiple chained operations
let mut graph = CudaGraphHandle::capture(&device, &shape, || {
// Chain multiple operations together
let step1 = a.add(&b)?; // 1 + 2 = 3
let step2 = step1.sum_keepdim(1)?; // Sum along axis 1
let _final_result = step2.relu()?; // Apply ReLU
Ok(())
})?;

// Verify graph creation and launch it
assert_eq!(graph.shape(), &shape);

// Test multiple launches to ensure stability
for _ in 0..5 {
graph.launch()?;
device.synchronize()?;
}

Ok(())
}
}
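
For context, here is a minimal usage sketch of the new handle outside the test module. It is not part of this diff: the import path `candle_core::cuda_backend::graph::CudaGraphHandle`, the warm-up run before capture, and the device-0 assumption are illustrative guesses layered on top of the tensor ops the tests above already use.

use candle_core::backend::BackendDevice;
use candle_core::cuda_backend::graph::CudaGraphHandle; // assumed export path
use candle_core::{CudaDevice, DType, Device, Result, Shape, Tensor};

fn main() -> Result<()> {
    // Assumes device 0 exists and the crate is built with the CUDA feature.
    let cuda = CudaDevice::new_with_stream(0)?;
    let device = Device::Cuda(cuda.clone());

    let shape = Shape::from_dims(&[16, 16]);
    let a = Tensor::ones((16, 16), DType::F32, &device)?;
    let b = Tensor::full(2.0f32, (16, 16), &device)?;

    // Warm-up outside the capture so one-off allocations are not recorded
    // (an assumption about backend behaviour, not something this PR enforces).
    let _ = a.add(&b)?.relu()?;

    // Record the op sequence once...
    let mut graph = CudaGraphHandle::capture(&cuda, &shape, || {
        let sum = a.add(&b)?;
        let _activated = sum.relu()?;
        Ok(())
    })?;

    // ...then replay it without re-recording the kernels.
    for _ in 0..10 {
        graph.launch()?;
        cuda.synchronize()?;
    }
    Ok(())
}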