Skip to content

Commit 92f16ea

Browse files
authored
Merge pull request #21 from TheMesocarp/publish
change integration tests and publish
2 parents 6b3bd78 + c39cbcb commit 92f16ea

File tree

3 files changed

+81
-9
lines changed

3 files changed

+81
-9
lines changed

.github/workflows/rust.yaml

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
- uses: actions/checkout@v4
2121

2222
- name: Install Rust
23-
uses: dtolnay/rust-toolchain@nightly
23+
uses: dtolnay/rust-toolchain@stable
2424
with:
2525
components: rustfmt
2626

@@ -39,7 +39,7 @@ jobs:
3939
- uses: actions/checkout@v4
4040

4141
- name: Install Rust
42-
uses: dtolnay/rust-toolchain@nightly
42+
uses: dtolnay/rust-toolchain@stable
4343

4444
- name: Rust Cache
4545
uses: Swatinem/rust-cache@v2
@@ -56,7 +56,7 @@ jobs:
5656
- uses: actions/checkout@v4
5757

5858
- name: Install Rust
59-
uses: dtolnay/rust-toolchain@nightly
59+
uses: dtolnay/rust-toolchain@stable
6060
with:
6161
components: clippy
6262

@@ -78,12 +78,61 @@ jobs:
7878
- uses: actions/checkout@v4
7979

8080
- name: Install Rust
81-
uses: dtolnay/rust-toolchain@nightly
81+
uses: dtolnay/rust-toolchain@stable
8282

8383
- name: Rust Cache
8484
uses: Swatinem/rust-cache@v2
8585
with:
8686
key: rust/test
8787

8888
- name: Run tests
89-
run: cargo test --verbose --workspace
89+
run: cargo test --verbose --workspace
90+
91+
toml-fmt:
92+
name: taplo
93+
runs-on: ubuntu-latest
94+
steps:
95+
- uses: actions/checkout@v4
96+
97+
- name: Install Rust
98+
uses: dtolnay/rust-toolchain@stable
99+
100+
- name: Install taplo
101+
uses: taiki-e/install-action@cargo-binstall
102+
with:
103+
tool: taplo-cli
104+
105+
- name: Rust Cache
106+
uses: Swatinem/rust-cache@v2
107+
with:
108+
key: rust/taplo
109+
110+
- name: Run TOML fmt
111+
run: taplo fmt --check
112+
113+
# semver:
114+
# name: semver
115+
# runs-on: ubuntu-latest
116+
# continue-on-error: true
117+
# steps:
118+
# - uses: actions/checkout@v4
119+
# with:
120+
# fetch-depth: 0
121+
122+
# - name: Install Rust
123+
# uses: dtolnay/rust-toolchain@master
124+
# with:
125+
# toolchain: nightly-2025-05-25
126+
127+
# - name: Rust Cache
128+
# uses: Swatinem/rust-cache@v2
129+
# with:
130+
# key: rust/semver
131+
132+
# - name: Install cargo-semver-checks
133+
# uses: taiki-e/install-action@cargo-binstall
134+
# with:
135+
# tool: cargo-semver-checks
136+
137+
# - name: Check semver
138+
# run: cargo semver-checks check-release

Cargo.toml

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
[package]
2+
3+
authors = ["Mesocarp"]
4+
description = "A deep learning model for spectral diffusion over k-cells in arbitrary cell complexes, built on `candle`"
5+
edition = "2021"
6+
homepage = "https://github.com/TheMesocarp/koho"
7+
keywords = ["ml", "deep_learning", "spectral_diffusion", "sheaves", "TDA"]
8+
license = "AGPL-3.0"
29
name = "koho"
10+
readme = "README.md"
11+
repository = "https://github.com/TheMesocarp/koho"
312
version = "0.1.0"
4-
edition = "2021"
13+
514

615
[dependencies]
716
thiserror = "2.0.12"
8-
num-complex = "0.4.6"
917
rand = "0.9.0"
1018
candle-core = "0.9.1"
11-
candle-optimisers = { git = "https://github.com/KGrewal1/optimisers.git", version = "0.9.0" }
1219
candle-nn = "0.9.1"

src/nn/optim/mod.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
use candle_core::{backprop::GradStore, Result as CandleResult, Var};
2-
use candle_optimisers::{Decay, Momentum};
32

43
pub use crate::nn::optim::adam::Adam;
54
pub use crate::nn::optim::rmsprop::RMSprop;
@@ -9,6 +8,23 @@ mod adam;
98
mod rmsprop;
109
mod sgd;
1110

11+
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
12+
pub enum Decay {
13+
/// Weight decay regularisation to penalise large weights
14+
WeightDecay(f64),
15+
/// Decoupled weight decay as described in [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101)
16+
DecoupledWeightDecay(f64),
17+
}
18+
19+
/// Type of momentum to use
20+
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd)]
21+
pub enum Momentum {
22+
/// classical momentum
23+
Classical(f64),
24+
/// nesterov momentum
25+
Nesterov(f64),
26+
}
27+
1228
/// Trait for optimizers that work with mutable references to variables
1329
pub trait Optimizer {
1430
/// Perform one optimization step using the gradients

0 commit comments

Comments
 (0)