Skip to content

Commit 56db7b4

Browse files
authored
Merge pull request #3 from TheMesocarp/feat/diffusion!
diffusion!
2 parents 3e2b1cc + bf439e6 commit 56db7b4

File tree

9 files changed

+200
-139
lines changed

9 files changed

+200
-139
lines changed

Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@ edition = "2021"
66
[dependencies]
77
num-complex = "0.4.6"
88
rand = "0.9.0"
9-
candle-core = "0.9.0"
9+
candle-core = "0.9.1"
10+
candle-optimisers = "0.9.0"
11+
candle-nn = "0.9.1"

src/cw.rs

Lines changed: 0 additions & 116 deletions
This file was deleted.

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
pub mod error;
22
pub mod math;
3-
pub mod sheaf;
3+
pub mod nn;

src/math/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
pub mod cell;
2+
pub mod sheaf;
23
pub mod tensors;

src/sheaf.rs renamed to src/math/sheaf.rs

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::{
1212
#[derive(PartialEq, Eq, Hash, Clone)]
1313
pub struct Point<T: Eq + std::hash::Hash + Clone + Sized>(T);
1414

15-
pub struct Section(Vector);
15+
pub struct Section(pub Vector);
1616

1717
impl Section {
1818
pub fn new<T: WithDType>(
@@ -30,8 +30,8 @@ pub struct CellularSheaf<O: OpenSet> {
3030
pub restrictions: HashMap<(usize, usize, usize, usize), Matrix>,
3131
pub interlinked: HashMap<(usize, usize, usize, usize), i8>,
3232
pub global_sections: Vec<Section>,
33-
device: Device,
34-
dtype: DType,
33+
pub device: Device,
34+
pub dtype: DType,
3535
}
3636

3737
impl<O: OpenSet> CellularSheaf<O> {
@@ -94,7 +94,7 @@ impl<O: OpenSet> CellularSheaf<O> {
9494
Ok(())
9595
}
9696

97-
pub fn k_coboundary(
97+
fn k_coboundary(
9898
&self,
9999
k: usize,
100100
k_cochain: Vec<Vector>,
@@ -151,7 +151,7 @@ impl<O: OpenSet> CellularSheaf<O> {
151151
}
152152

153153
/// Computes the adjoint of the k-th coboundary operator ((delta^k)*).
154-
pub fn k_adjoint_coboundary(
154+
fn k_adjoint_coboundary(
155155
&self,
156156
k: usize,
157157
k_coboundary_output: Vec<Vector>,
@@ -215,34 +215,38 @@ impl<O: OpenSet> CellularSheaf<O> {
215215
}
216216

217217
/// Retrieves the cochain (Vec<Vector>) for a given dimension k.
218-
pub fn get_k_cochain(&self, k: usize) -> Result<Vec<Vector>, MathError> {
218+
pub fn get_k_cochain(&self, k: usize) -> Result<Matrix, MathError> {
219219
if k >= self.section_spaces.len() {
220220
return Err(MathError::DimensionMismatch);
221221
}
222222
let k_sections = &self.section_spaces[k];
223223
let k_cochain: Vec<Vector> = k_sections.iter().map(|section| section.0.clone()).collect();
224-
225-
Ok(k_cochain)
224+
Matrix::from_vecs(k_cochain).map_err(MathError::Candle)
226225
}
227226

228227
pub fn k_hodge_laplacian(
229228
&self,
230229
k: usize,
231230
k_cochain: Matrix,
232-
k_stalk_dim: usize,
233-
k_plus_stalk_dim: usize,
234-
k_minus_stalk_dim: usize,
231+
down_included: bool,
235232
) -> Result<Matrix, MathError> {
236233
let vecs = k_cochain.to_vectors().map_err(MathError::Candle)?;
234+
235+
let k_plus_stalk_dim = self.section_spaces[k + 1][0].0.dimension();
236+
let k_stalk_dim = self.section_spaces[k][0].0.dimension();
237+
let k_minus_stalk_dim = self.section_spaces[k - 1][0].0.dimension();
238+
237239
let up_a = self.k_coboundary(k, vecs.clone(), k_plus_stalk_dim)?;
238240
let up_b = self.k_adjoint_coboundary(k, up_a, k_stalk_dim)?;
239-
240-
let down_a = self.k_adjoint_coboundary(k, vecs, k_minus_stalk_dim)?;
241-
let down_b = self.k_coboundary(k, down_a, k_stalk_dim)?;
242-
let out = Matrix::from_vecs(up_b)
243-
.map_err(MathError::Candle)?
244-
.add(&Matrix::from_vecs(down_b).map_err(MathError::Candle)?)
245-
.map_err(MathError::Candle)?;
246-
Ok(out)
241+
if down_included {
242+
let down_a = self.k_adjoint_coboundary(k, vecs, k_minus_stalk_dim)?;
243+
let down_b = self.k_coboundary(k, down_a, k_stalk_dim)?;
244+
let out = Matrix::from_vecs(up_b)
245+
.map_err(MathError::Candle)?
246+
.add(&Matrix::from_vecs(down_b).map_err(MathError::Candle)?)
247+
.map_err(MathError::Candle)?;
248+
return Ok(out);
249+
}
250+
Matrix::from_vecs(up_b).map_err(MathError::Candle)
247251
}
248252
}

src/math/tensors.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ impl Vector {
8282
#[derive(Debug, Clone)]
8383
pub struct Matrix {
8484
pub tensor: Tensor,
85-
device: Device,
86-
dtype: DType,
85+
pub device: Device,
86+
pub dtype: DType,
8787
}
8888

8989
impl Matrix {
@@ -263,6 +263,12 @@ impl Matrix {
263263

264264
Ok(cols_vectors)
265265
}
266+
267+
/// Generates a new random matrix with elements sampled from a standard normal distribution (mean 0, std dev 1).
268+
pub fn rand(rows: usize, cols: usize, device: Device, dtype: DType) -> Result<Self> {
269+
let tensor = Tensor::randn(0.0f32, 1.0f32, (rows, cols), &device)?.to_dtype(dtype)?;
270+
Self::new(tensor, device, dtype)
271+
}
266272
}
267273

268274
// Example Usage (requires a candle_core setup)

src/nn/activations.rs

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
use candle_core::{Error, Tensor};
2+
3+
use crate::math::tensors::Matrix;
4+
5+
pub enum Activations {
6+
Step,
7+
Linear,
8+
Sigmoid,
9+
Tanh,
10+
ReLU,
11+
Softmax,
12+
Swish,
13+
GeLU,
14+
Sinc,
15+
SeLU,
16+
}
17+
18+
impl Activations {
19+
pub fn activate(&self, input: Matrix) -> Result<Matrix, Error> {
20+
let tensor = input.inner();
21+
let device = input.device.clone();
22+
let dtype = input.dtype;
23+
24+
let activated = match self {
25+
Activations::Step => {
26+
// Heaviside step function: 1.0 where x >= 0.0, else 0.0
27+
let zeros = Tensor::zeros_like(tensor)?;
28+
let ones = Tensor::ones_like(tensor)?;
29+
tensor.ge(&zeros)?.where_cond(&ones, &zeros)?
30+
}
31+
Activations::Linear => tensor.clone(),
32+
Activations::Tanh => tensor.tanh()?,
33+
Activations::ReLU => tensor.relu()?,
34+
Activations::Sinc => {
35+
// sinc(x) = sin(x) / x, define 1 at x=0
36+
// Using a small epsilon to handle division by zero.
37+
// If x is near zero, output 1, else sin(x)/x
38+
let eps_val = 1e-7f64;
39+
let eps = Tensor::full(eps_val, tensor.dims(), &device)?.to_dtype(dtype)?;
40+
let near_zero = tensor.abs()?.le(&eps)?;
41+
42+
let numerator = tensor.sin()?;
43+
let denominator = tensor.clone(); // Clone to avoid consuming tensor
44+
let value = numerator.div(&denominator)?;
45+
46+
near_zero.where_cond(&Tensor::ones_like(tensor)?, &value)?
47+
}
48+
Activations::Sigmoid => {
49+
// Sigmoid(x) = 1 / (1 + exp(-x))
50+
let neg_x = tensor.neg()?;
51+
let exp_neg_x = neg_x.exp()?;
52+
let one = Tensor::ones_like(&exp_neg_x)?;
53+
let one_plus_exp_neg_x = one.add(&exp_neg_x)?;
54+
one_plus_exp_neg_x.recip()? // 1 / (1 + exp(-x))
55+
}
56+
Activations::Softmax => {
57+
// Softmax(x_i) = exp(x_i) / sum(exp(x_j)) along the last dimension
58+
// For a Matrix (rank 2), apply along dim 1 (columns) for each row.
59+
let exp_x = tensor.exp()?;
60+
// Sum along the last dimension, keeping the dimension for broadcasting
61+
let sum_exp_x = exp_x.sum_keepdim(1)?;
62+
exp_x.broadcast_div(&sum_exp_x)?
63+
}
64+
Activations::Swish => {
65+
// Swish(x) = x * Sigmoid(x)
66+
let neg_x = tensor.neg()?;
67+
let exp_neg_x = neg_x.exp()?;
68+
let one = Tensor::ones_like(&exp_neg_x)?;
69+
let one_plus_exp_neg_x = one.add(&exp_neg_x)?;
70+
let sigmoid_x = one_plus_exp_neg_x.recip()?;
71+
tensor.mul(&sigmoid_x)?
72+
}
73+
Activations::GeLU => {
74+
// GeLU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
75+
let sqrt_two_val = 2.0f64.sqrt();
76+
let sqrt_two =
77+
Tensor::full(sqrt_two_val, tensor.dims(), &device)?.to_dtype(dtype)?;
78+
79+
let x_div_sqrt_two = tensor.div(&sqrt_two)?;
80+
let erf_val = x_div_sqrt_two.erf()?;
81+
let one = Tensor::ones_like(&erf_val)?;
82+
let one_plus_erf = one.add(&erf_val)?;
83+
84+
let half_val = 0.5f64;
85+
let half = Tensor::full(half_val, tensor.dims(), &device)?.to_dtype(dtype)?;
86+
87+
tensor.mul(&half)?.mul(&one_plus_erf)?
88+
}
89+
Activations::SeLU => {
90+
// SeLU(x) = lambda * (x if x > 0 else alpha * (exp(x) - 1))
91+
// Standard constants for SeLU
92+
let alpha_val = 1.673_263_242_354_377_2_f64;
93+
let lambda_val = 1.050_700_987_355_480_5_f64;
94+
95+
let alpha = Tensor::full(alpha_val, tensor.dims(), &device)?.to_dtype(dtype)?;
96+
let lambda = Tensor::full(lambda_val, tensor.dims(), &device)?.to_dtype(dtype)?;
97+
let zero = Tensor::zeros_like(tensor)?;
98+
99+
// Condition: x > 0
100+
let cond_gt_zero = tensor.gt(&zero)?;
101+
102+
// Case for x > 0: just x
103+
let case_gt_zero = tensor.clone();
104+
105+
// Case for x <= 0: alpha * (exp(x) - 1)
106+
let exp_x = tensor.exp()?;
107+
let one_for_sub = Tensor::ones_like(&exp_x)?;
108+
let exp_x_minus_one = exp_x.sub(&one_for_sub)?;
109+
let case_le_zero = alpha.mul(&exp_x_minus_one)?;
110+
111+
let result = cond_gt_zero.where_cond(&case_gt_zero, &case_le_zero)?;
112+
lambda.mul(&result)?
113+
}
114+
};
115+
116+
Matrix::new(activated, device, dtype)
117+
}
118+
}

0 commit comments

Comments
 (0)