Skip to content

Commit c08e62b

Browse files
Add Docker
1 parent 3db99fd commit c08e62b

File tree

12 files changed

+830
-669
lines changed

12 files changed

+830
-669
lines changed

.dockerignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
.venv/

Dockerfile

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
FROM python:3.12-slim-trixie
2+
3+
RUN apt-get update && apt-get install -y git
4+
5+
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin
6+
7+
WORKDIR /app
8+
9+
RUN --mount=type=cache,target=/root/.cache/uv \
10+
--mount=type=bind,source=uv.lock,target=uv.lock \
11+
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
12+
uv sync --locked --no-install-project
13+
14+
ADD . /app
15+
16+
RUN --mount=type=cache,target=/root/.cache/uv \
17+
uv sync --locked
18+
19+
# Need to add hsd as a package, but original GitHub does not have pyproject.toml
20+
# Use this as a workaround
21+
RUN mkdir -p /opt/hsd/src/hsd/
22+
RUN git clone https://github.com/AxelElaldi/fast-equivariant-deconv /opt/hsd/src/hsd/
23+
RUN touch /opt/hsd/src/hsd/__init__.py /opt/hsd/src/hsd/utils/__init__.py /opt/hsd/src/hsd/model/__init__.py
24+
RUN cat > /opt/hsd/pyproject.toml <<'PYPROJECT'
25+
[project]
26+
name = "hsd"
27+
version = "0.1.0"
28+
description = "https://github.com/AxelElaldi/fast-equivariant-deconv"
29+
readme = "README.md"
30+
dependencies = [
31+
"h5py==3.11.0",
32+
"healpy==1.16.6",
33+
"joblib==1.4.2",
34+
"matplotlib==3.9.0",
35+
"nibabel==5.2.1",
36+
"numpy<2",
37+
"pandas==2.2.2",
38+
"pygsp==0.5.1",
39+
"pyyaml==6.0.1",
40+
"scipy==1.13.0",
41+
"tensorboard==2.16.2",
42+
"torch==2.3.0",
43+
"torchaudio==2.3.0",
44+
"torchvision==0.18.0",
45+
]
46+
47+
[build-system]
48+
requires = ["setuptools>=61.0"]
49+
build-backend = "setuptools.build_meta"
50+
PYPROJECT
51+
52+
RUN --mount=type=cache,target=/root/.cache/uv \
53+
uv pip install /opt/hsd
54+
55+
ENV PATH="/app/.venv/bin:${PATH}"

README.md

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,59 @@
1-
# DeepFixel
2-
Deep learning-based identification of crossing fiber bundle elements
1+
# DeepFixel: Crossing white matter fiber identification through spherical convolutional neural networks
2+
[![arXiv](https://img.shields.io/badge/arXiv-2511.03893-b31b1b.svg)](https://arxiv.org/abs/2511.03893)
33

4-
## Training and testing the model
4+
DeepFixel is a deep learning method for identification of crossing fiber bundle elements from diffusion MRI.
5+
6+
> Adam M. Saunders, Lucas W. Remedios, Elyssa M. McMaster, Jongyeon Yoon, Gaurav Rudravaram, Adam Sadriddinov, Praitayini Kanakaraj, Bennett A, Landman, and Adam W. Anderson. DeepFixel: Crossing white matter fiber identification through spherical convolutional neural networks. SPIE Medical Imaging: Clinical and Biomedical Imaging, 2026. [https://arxiv.org/abs/2511.03893](https://arxiv.org/abs/2511.03893)
7+
8+
## Installation
59
You can set up an environment using [`uv`](https://github.com/astral-sh/uv) by running the following command:
610
```bash
711
uv sync
812
```
913

10-
To run the model, download the weights and testing dataset from the following link: [https://zenodo.org/records/16587458](https://zenodo.org/records/16587458).
14+
Alternatively, you can use Docker or Apptainer (see instructions below).
1115

16+
## Usage
17+
To run the model, download the weights and testing dataset from the following link: [https://zenodo.org/records/16587458](https://zenodo.org/records/16587458).
1218
- Unzip and copy the testing data to `./test_data`
1319
- Put the weights in `./models/pretrained`
1420

15-
See `run_pretrained_deep_fixel.py` to test the pretrained model and `run_deep_fixel.py` to train and test a new model.
1621

17-
## Using the model
18-
If you wish to apply the model to your own dataset, you can use `fissile.test_mesh_model()` as a basis for your code. You can also use `fissile.dataset.GeneratedMeshNIFTIDataset()` if your data is stored as spherical harmonic coefficients in a NIFTI file.
22+
To train the model:
23+
```bash
24+
python train_deep_fixel.py --config config/example_scnn.yaml
25+
```
26+
27+
To test the model on the provided testing dataset:
28+
```bash
29+
python test_deep_fixel.py --config config/example_scnn.yaml
30+
```
31+
32+
33+
## Usage (Docker)
34+
To build the Docker image, clone the repository and run the following command in the root directory:
35+
```bash
36+
sudo docker build -t spherical_deep_fixel:v1.0.0 .
37+
```
38+
39+
Then run the Docker container with the following command (note you will likely need to bind in local directories with `-v`):
40+
```bash
41+
sudo docker run --rm -it --gpus all -v $(pwd):$(pwd) $spherical_deep_fixel:v1.0.0 python train_deep_fixel.py --config /path/to/config/example_scnn.yaml
42+
sudo docker run --rm -it --gpus all -v $(pwd):$(pwd) $spherical_deep_fixel:v1.0.0 python test_deep_fixel.py --config /path/to/config/example_scnn.yaml
43+
```
44+
45+
## Usage (Apptainer)
46+
A pre-built Apptainer image is available on Zenodo:
47+
48+
```bash
49+
apptainer run -C -B $(pwd):$(pwd) --nv https://zenodo.org/records/16587458/files/spherical_deep_fixel_v1.0.0.sif python /app/train_deep_fixel.py --config /path/to/config/example_scnn.yaml
50+
apptainer run -C -B $(pwd):$(pwd) --nv https://zenodo.org/records/16587458/files/spherical_deep_fixel_v1.0.0.sif python /app/test_deep_fixel.py --config /path/to/config/example_scnn.yaml
51+
```
52+
53+
## Applying the model to your own data
54+
If you wish to apply the model to your own dataset, you can use `fissile.test_mesh_model()` as a basis for your code. You can also use `fissile.dataset.GeneratedMeshNIFTIDataset()` if your data is stored as spherical harmonic coefficients in a NIFTI file.
55+
56+
## Citation
57+
If you use this code in your research, please cite the following paper:
58+
59+
> Adam M. Saunders, Lucas W. Remedios, Elyssa M. McMaster, Jongyeon Yoon, Gaurav Rudravaram, Adam Sadriddinov, Praitayini Kanakaraj, Bennett A, Landman, and Adam W. Anderson. DeepFixel: Crossing white matter fiber identification through spherical convolutional neural networks. SPIE Medical Imaging: Clinical and Biomedical Imaging, 2026. [https://arxiv.org/abs/2511.03893](https://arxiv.org/abs/2511.03893)

config/example_mlp.yaml

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Example config for MLP model training and testing
2+
3+
# wandb project name
4+
project_name: "deepfixel"
5+
6+
# Learning rate
7+
lr: 1e-3
8+
9+
# Batch size for training and testing
10+
batch_size: 512
11+
12+
# Loss function, can be "MSE" or "L1"
13+
loss: "MSE"
14+
15+
# Model type, can be "mesh_mlp" or "mesh_scnn"
16+
model: "mesh_mlp"
17+
18+
# GPU for training
19+
gpu_id: 0
20+
21+
# Random seed for reproducibility
22+
seed: 42
23+
24+
# Keep this at 1 for correct sampling of sphere
25+
mesh_subdivide: 1
26+
27+
# Keep this true for correct sampling of sphere
28+
healpix: true
29+
30+
# Spread parameter for Bingham distribution of simulated data
31+
kappa: 100
32+
33+
# Minimum separation angle between fibers in degrees for simulated data
34+
min_separation_angle: 0
35+
36+
# Number of fibers per voxel, either 1, 2 or 'both'
37+
n_fibers: 'both'
38+
39+
# Directories for testing data and testing outputs
40+
test_dir: "./test_data"
41+
output_dir: './outputs/pretrained_mlp'
42+
43+
# Threshold for amplitude of FOD peaks (below this, peaks are ignored)
44+
amp_threshold: 0.1
45+
46+
# Path to pretrained model
47+
pretrained_model_path: "./models/pretrained/best_model_mlp.pth"

config/example_scnn.yaml

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Example config for spherical CNN model training and testing
2+
3+
# wandb project name
4+
project_name: "deepfixel"
5+
6+
# Learning rate
7+
lr: 1e-3
8+
9+
# Batch size for training and testing
10+
batch_size: 512
11+
12+
# Loss function, can be "MSE" or "L1"
13+
loss: "MSE"
14+
15+
# Model type, can be "mesh_mlp" or "mesh_scnn"
16+
model: "mesh_scnn"
17+
18+
# GPU for training
19+
gpu_id: 0
20+
21+
# Random seed for reproducibility
22+
seed: 42
23+
24+
# Keep this at 1 for correct sampling of sphere
25+
mesh_subdivide: 1
26+
27+
# Keep this true for correct sampling of sphere
28+
healpix: true
29+
30+
# Spread parameter for Bingham distribution of simulated data
31+
kappa: 100
32+
33+
# Minimum separation angle between fibers in degrees for simulated data
34+
min_separation_angle: 0
35+
36+
# Number of fibers per voxel, either 1, 2 or 'both'
37+
n_fibers: 'both'
38+
39+
### Testing parameters ###
40+
41+
# Directories for testing data and testing outputs
42+
test_dir: "./test_data"
43+
output_dir: './outputs/pretrained_scnn'
44+
45+
# Threshold for amplitude of FOD peaks (below this, peaks are ignored)
46+
amp_threshold: 0.1
47+
48+
# Path to pretrained model
49+
pretrained_model_path: "./models/pretrained/best_model_scnn.pth"

run_deep_fixel.py

Lines changed: 0 additions & 59 deletions
This file was deleted.

run_pretrained_deep_fixel.py

Lines changed: 0 additions & 68 deletions
This file was deleted.

src/deep_fixel/dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def generate_odf(self, seed=None):
126126
if self.n_fibers == "both":
127127
n_fibers = self.rng.choice([2, 3])
128128
else:
129-
n_fibers = self.n_fibers
129+
n_fibers = int(self.n_fibers)
130130

131131
# Generate random volume fraction using Dirichlet distribution
132132
vol = self.rng.dirichlet(np.ones(n_fibers))
@@ -317,7 +317,7 @@ def generate_odf(self, seed=None):
317317
if self.n_fibers == "both":
318318
n_fibers = self.rng.choice([2, 3])
319319
else:
320-
n_fibers = self.n_fibers
320+
n_fibers = int(self.n_fibers)
321321

322322
# Generate random volume fraction using Dirichlet distribution
323323
vol = self.rng.dirichlet(np.ones(n_fibers))

src/deep_fixel/run.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
def train_mesh_model(
1515
run_name,
16+
project_name="deepfixel",
1617
lr=1e-3,
1718
batch_size=512,
1819
n_fibers=2,
@@ -56,7 +57,7 @@ def train_mesh_model(
5657

5758
# Set up Weights and Biases
5859
wandb.login()
59-
run = wandb.init(project="deepfixel", name=run_name, config=config)
60+
run = wandb.init(project=project_name, name=run_name, config=config)
6061

6162
# Set up datasets
6263
train_dataset = RandomMeshDataset(n_fibers=n_fibers, l_max=6, seed=seed, subdivide=mesh_subdivide, kappa=kappa, healpix=healpix, csd=csd, snr=snr)

0 commit comments

Comments
 (0)