Skip to content

Commit 262c7cf

Browse files
authored
Add basic convolution functionalities for forward and backward passes (#29)
1 parent 437ad90 commit 262c7cf

20 files changed

+1545
-441
lines changed

crates/cudnn/src/context.rs

Lines changed: 735 additions & 68 deletions
Large diffs are not rendered by default.

crates/cudnn/src/convolution/convolution_algo.rs

Lines changed: 695 additions & 0 deletions
Large diffs are not rendered by default.

crates/cudnn/src/convolution_descriptor.rs renamed to crates/cudnn/src/convolution/convolution_descriptor.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
use crate::{
2-
convolution_mode::ConvolutionMode,
2+
convolution::ConvolutionMode,
33
data_type::DataType,
44
error::{CudnnError, IntoResult},
55
math_type::MathType,
66
sys,
7-
tensor_format::TensorFormat,
7+
tensor::TensorFormat,
88
};
99

1010
use std::{marker::PhantomData, mem::MaybeUninit};
1111

1212
/// A generic description of an n-dimensional convolution.
1313
///
14-
/// *Do note* that N can be either 2 or 3, respectively for a 2-d or a 3-d convolution.
15-
#[derive(Debug, Clone, PartialEq, Hash)]
14+
/// **Do note** that N can be either 2 or 3, respectively for a 2-d or a 3-d convolution, and that
15+
/// the same convolution descriptor can be reused in the backward path provided it corresponds to
16+
/// the same layer.
17+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1618
pub struct ConvolutionDescriptor<T: DataType, const N: usize> {
1719
pub(crate) raw: sys::cudnnConvolutionDescriptor_t,
1820
comp_type: PhantomData<T>,
@@ -131,6 +133,12 @@ impl<T: DataType, const N: usize> ConvolutionDescriptor<T, N> {
131133
}
132134

133135
/// Sets the `MathType` for this convolution descriptor instance.
136+
///
137+
/// # Arguments
138+
///
139+
/// `math_type` - the provided math type.
140+
///
141+
/// **Do note** that tensor core operations may not be available on all device architectures.
134142
pub fn set_math_type(&mut self, math_type: MathType) -> Result<(), CudnnError> {
135143
unsafe { sys::cudnnSetConvolutionMathType(self.raw, math_type.into()).into_result() }
136144
}

crates/cudnn/src/convolution_mode.rs renamed to crates/cudnn/src/convolution/convolution_mode.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
use crate::sys;
22

3-
/// Enum used to configure a convolution descriptor. The filter used for the convolution can be
4-
/// applied in two different ways, corresponding mathematically to a convolution or to a
5-
/// cross-correlation. A cross-correlation is equivalent to a convolution with its filter
6-
/// rotated by 180 degrees.
3+
/// Enum used to configure a convolution descriptor.
4+
///
5+
/// The filter used for the convolution can be applied in two different ways, corresponding
6+
/// mathematically to a convolution or to a cross-correlation.
7+
///
8+
/// A cross-correlation is equivalent to a convolution with its filter rotated by 180 degrees.
79
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
810
pub enum ConvolutionMode {
911
/// Convolution operation.
@@ -22,7 +24,6 @@ impl From<sys::cudnnConvolutionMode_t> for ConvolutionMode {
2224
}
2325

2426
impl From<ConvolutionMode> for sys::cudnnConvolutionMode_t {
25-
/// Returns the raw cuDNN type associated to the given variant.
2627
fn from(convolution_mode: ConvolutionMode) -> sys::cudnnConvolutionMode_t {
2728
match convolution_mode {
2829
ConvolutionMode::Convolution => sys::cudnnConvolutionMode_t::CUDNN_CONVOLUTION,

crates/cudnn/src/filter_descriptor.rs renamed to crates/cudnn/src/convolution/filter/filter_descriptor.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::{
22
data_type::DataType,
33
error::{CudnnError, IntoResult},
44
sys,
5-
tensor_format::{SupportedType, TensorFormat},
5+
tensor::{SupportedType, TensorFormat},
66
};
77
use std::{
88
marker::PhantomData,

crates/cudnn/src/filter.rs renamed to crates/cudnn/src/convolution/filter/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
mod filter_descriptor;
2+
pub use filter_descriptor::*;
3+
14
use crate::{
25
data_type::DataType,
36
error::CudnnError,
4-
filter_descriptor::FilterDescriptor,
5-
tensor_format::{SupportedType, TensorFormat},
7+
tensor::{SupportedType, TensorFormat},
68
};
79
use cust::memory::{DeviceBuffer, DeviceCopy, GpuBox, GpuBuffer, UnifiedBuffer};
810

crates/cudnn/src/convolution/mod.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
mod convolution_algo;
2+
mod convolution_descriptor;
3+
mod convolution_mode;
4+
mod filter;
5+
6+
pub use convolution_algo::*;
7+
pub use convolution_descriptor::*;
8+
pub use convolution_mode::*;
9+
pub use filter::*;

crates/cudnn/src/convolution_fwd_algo.rs

Lines changed: 0 additions & 301 deletions
This file was deleted.

0 commit comments

Comments
 (0)