Skip to content

Commit 749413f

Browse files
Update Dockerfile for inference
1 parent 45cf7f2 commit 749413f

File tree

6 files changed

+138
-40
lines changed

6 files changed

+138
-40
lines changed

Dockerfile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,8 @@ PYPROJECT
5252
RUN --mount=type=cache,target=/root/.cache/uv \
5353
uv pip install /opt/hsd
5454

55+
RUN mkdir -p /app/models
56+
ADD https://zenodo.org/records/17859792/files/best_model_mlp.pth?download=1 /app/models/best_model_mlp.pth
57+
ADD https://zenodo.org/records/17859792/files/best_model_scnn.pth?download=1 /app/models/best_model_scnn.pth
58+
5559
ENV PATH="/app/.venv/bin:${PATH}"

README.md

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# DeepFixel: Crossing white matter fiber identification through spherical convolutional neural networks
2-
[![arXiv](https://img.shields.io/badge/arXiv-2511.03893-b31b1b.svg)](https://arxiv.org/abs/2511.03893) [![Zenodo](https://zenodo.org/badge/DOI/10.5281/zenodo.13121149.svg)](https://doi.org/10.5281/zenodo.13121149)
2+
[![arXiv](https://img.shields.io/badge/arXiv-2511.03893-b31b1b.svg)](https://arxiv.org/abs/2511.03893) [![Zenodo](https://zenodo.org/badge/DOI/10.5281/zenodo.17834289.svg)](https://doi.org/10.5281/zenodo.17834289)
33

44
DeepFixel is a deep learning method for identification of crossing fiber bundle elements from diffusion MRI.
55

@@ -14,11 +14,28 @@ uv sync
1414
Alternatively, you can use Docker or Apptainer (see instructions below).
1515

1616
## Usage
17-
To run the model, download the weights and testing dataset from the following link: [https://doi.org/10.5281/zenodo.13121149](https://doi.org/10.5281/zenodo.13121149).
17+
You can use DeepFixel to split multi-fiber ODFs into the underlying single-fiber ODFs. The pretrained weights are available on Zenodo: [https://doi.org/10.5281/zenodo.17834289](https://doi.org/10.5281/zenodo.17834289).
18+
19+
```bash
20+
# For pretrained models: --lmax 6, --subdivide 1
21+
deepfixel /path/to/input/fod.nii.gz \
22+
/path/to/output_dir \
23+
/path/to/best_model_scnn.pth \
24+
--mask /path/to/mask.nii.gz \
25+
--maxnum 2 \
26+
--lmax 6 \
27+
--subdivide 1 \
28+
--amp_threshold 0.1 \
29+
--model mesh_scnn \
30+
--batch_size 512 \
31+
--gpu_id 1
32+
```
33+
34+
## Training and testing the model
35+
To run the model, download the weights and testing dataset from the following link: [https://doi.org/10.5281/zenodo.17834289](https://doi.org/10.5281/zenodo.17834289).
1836
- Unzip and copy the testing data to `./test_data`
1937
- Put the weights in `./models/pretrained`
2038

21-
2239
To train the model:
2340
```bash
2441
python train_deep_fixel.py --config config/example_scnn.yaml
@@ -30,28 +47,58 @@ python test_deep_fixel.py --config config/example_scnn.yaml
3047
```
3148

3249

33-
## Usage (Docker)
50+
### Docker
3451
To build the Docker image, clone the repository and run the following command in the root directory:
3552
```bash
36-
sudo docker build -t spherical_deep_fixel:v1.0.0 .
53+
sudo docker build -t spherical_deep_fixel:v1.2.0 .
3754
```
3855

3956
Then run the Docker container with the following command (note you will likely need to bind in local directories with `-v`):
57+
4058
```bash
41-
sudo docker run --rm -it --gpus all -v $(pwd):$(pwd) $spherical_deep_fixel:v1.0.0 python train_deep_fixel.py --config /path/to/config/example_scnn.yaml
42-
sudo docker run --rm -it --gpus all -v $(pwd):$(pwd) $spherical_deep_fixel:v1.0.0 python test_deep_fixel.py --config /path/to/config/example_scnn.yaml
59+
sudo docker run --rm -it --gpus all \
60+
/path/to/input/fod.nii.gz \
61+
/path/to/output_dir \
62+
/app/models/best_model_scnn.pth \
63+
--mask /path/to/mask.nii.gz \
64+
--maxnum 2 \
65+
--lmax 6 \
66+
--subdivide 1 \
67+
--amp_threshold 0.1 \
68+
--model mesh_scnn \
69+
--batch_size 512 \
70+
--gpu_id 0
4371
```
4472

45-
## Usage (Apptainer)
46-
A pre-built Apptainer image is available on Zenodo ([https://doi.org/10.5281/zenodo.13121149](https://doi.org/10.5281/zenodo.13121149)):
73+
For training and testing:
74+
```bash
75+
sudo docker run --rm -it --gpus all $spherical_deep_fixel:v1.2.0 python train_deep_fixel.py --config /path/to/config/example_scnn.yaml
76+
sudo docker run --rm -it --gpus all $spherical_deep_fixel:v1.2.0 python test_deep_fixel.py --config /path/to/config/example_scnn.yaml
77+
```
78+
79+
### Apptainer
80+
A pre-built Apptainer image is available on Zenodo ([https://doi.org/10.5281/zenodo.17834289](https://doi.org/10.5281/zenodo.17834289)). (Note you will likely need to bind in local directories with `-B`):
4781

4882
```bash
49-
apptainer run -C -B $(pwd):$(pwd) --nv https://zenodo.org/records/17834290/files/spherical_deep_fixel_v1.0.0.sif python /app/train_deep_fixel.py --config /path/to/config/example_scnn.yaml
50-
apptainer run -C -B $(pwd):$(pwd) --nv https://zenodo.org/records/17834290/files/spherical_deep_fixel_v1.0.0.sif python /app/test_deep_fixel.py --config /path/to/config/example_scnn.yaml
83+
apptainer run -C --nv spherical_deep_fixel_v1.2.0.sif \
84+
/path/to/input/fod.nii.gz \
85+
/path/to/output_dir \
86+
/app/models/best_model_scnn.pth \
87+
--mask /path/to/mask.nii.gz \
88+
--maxnum 2 \
89+
--lmax 6 \
90+
--subdivide 1 \
91+
--amp_threshold 0.1 \
92+
--model mesh_scnn \
93+
--batch_size 512 \
94+
--gpu_id 0
5195
```
5296

53-
## Applying the model to your own data
54-
If you wish to apply the model to your own dataset, you can use `fissile.test_mesh_model()` as a basis for your code. You can also use `fissile.dataset.GeneratedMeshNIFTIDataset()` if your data is stored as spherical harmonic coefficients in a NIFTI file.
97+
For training and testing:
98+
```bash
99+
apptainer run -C --nv spherical_deep_fixel_v1.2.0.sif python /app/train_deep_fixel.py --config /path/to/config/example_scnn.yaml
100+
apptainer run -C --nv spherical_deep_fixel_v1.2.0.sif python /app/test_deep_fixel.py --config /path/to/config/example_scnn.yaml
101+
```
55102

56103
## Citation
57104
If you use this code in your research, please cite the following paper:

pyproject.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "deep_fixel"
7-
version = "1.0.0"
7+
version = "1.2.0"
88
authors = [
99
{name = "Adam Saunders", email = "adam.m.saunders@vanderbilt.edu"}
1010
]
@@ -41,5 +41,8 @@ dependencies = [
4141
"wandb>=0.19.7",
4242
]
4343

44+
[project.scripts]
45+
deepfixel = "deep_fixel.scripts.inference:main"
46+
4447
[project.urls]
45-
Homepage = "https://github.com/MASILab/deep_fixel"
48+
Homepage = "https://github.com/MASILab/spherical_deep_fixel"

src/deep_fixel/dataset.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from .utils import load_fissile_mat, fiber_response
2828

29+
2930
class RandomODFDataset(IterableDataset):
3031
def __init__(self, n_fibers, l_max=6, seed=None, size=None, deterministic=False):
3132
"""Generate ODFs at random angles and volume fractions (using Tournier07/mrtrix convention)
@@ -182,7 +183,8 @@ def __iter__(self):
182183
else:
183184
while True:
184185
yield self.generate_odf()
185-
186+
187+
186188
class RandomMeshDataset(IterableDataset):
187189
def __init__(
188190
self,
@@ -382,7 +384,7 @@ def generate_odf(self, seed=None):
382384
sh_order_max=self.l_max,
383385
basis_type="tournier07",
384386
)
385-
387+
386388
pdf = [
387389
v * vonmises_fisher(mu, self.kappa).pdf(self.icosphere.vertices)
388390
for v, mu in zip(vol, xyz.T)
@@ -614,37 +616,36 @@ def __getitem__(self, idx):
614616
class GeneratedMeshNIFTIDataset(Dataset):
615617
def __init__(
616618
self,
617-
n_fibers,
618619
nifti_path,
620+
mask=None,
621+
lmax=6,
619622
subdivide=3,
620-
kappa=100,
621623
healpix=False,
622624
):
623625
"""Load ODFs from a directory of .mat files from FISSILE outputs.
624626
625627
Parameters
626628
----------
627-
n_fibers : int or str
628-
Number of fibers in each ODF. If 'both', will use 2 and 3.
629629
nifti_path : str
630-
Path to nifti file
630+
Path to nifti file containing ODFs in spherical harmonic coefficients in tournier07/mrtrix convention
631+
mask : str
632+
Path to nifti file containing mask, by default None
633+
lmax : int, optional
634+
Maximum spherical harmonic order, by default 6
631635
subdivide : int, optional
632636
Number of times to subdivide the ico-hemisphere if healpix=False,
633637
otherwise corresponds to depth of Healpix sampling (where smaller is more vertices), by default 3
634-
kappa : float, optional
635-
Concentration parameter for von Mises-Fisher distribution, by default 100
636638
healpix : bool, optional
637639
If True, sample on healpix instead of icosphere, by default False
638640
"""
639-
self.n_fibers = n_fibers
640641
self.nifti_path = nifti_path
641-
self.l_max = 6
642+
self.l_max = lmax
642643

643644
if healpix:
644645
n_side = 8
645646
depth = subdivide
646647
patch_size = 1
647-
sh_degree = 6
648+
sh_degree = lmax
648649
pooling_mode = "average"
649650
pooling_name = "mixed"
650651
use_hemisphere = True
@@ -666,13 +667,19 @@ def __init__(
666667
self.n_mesh = len(self.icosphere.vertices)
667668
self.sphere = self.icosphere
668669

669-
self.kappa = kappa
670-
671670
# Load NIFTI
672671
nifti = nib.load(nifti_path)
672+
self.shape = nifti.shape
673+
self.affine = nifti.affine
673674

674675
# Flatten all except last axis
675676
nifti_data = nifti.get_fdata().squeeze()
677+
678+
if mask is not None:
679+
mask_nifti = nib.load(mask)
680+
mask_data = mask_nifti.get_fdata().squeeze().astype(bool)
681+
nifti_data = nifti_data[mask_data]
682+
676683
nifti_data = nifti_data.reshape(-1, nifti_data.shape[-1])
677684

678685
# Append each ODF in file to list
@@ -689,4 +696,4 @@ def __len__(self):
689696
return len(self.total_odf_meshes)
690697

691698
def __getitem__(self, idx):
692-
return self.total_odf_meshes[idx]
699+
return self.total_odf_meshes[idx]

src/deep_fixel/utils.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,13 @@
2020
from scipy.signal import argrelmax
2121
from scipy.spatial.transform import Rotation as R
2222
from scipy.io import loadmat
23-
import cmcrameri
23+
import cmcrameri
2424
from line_profiler import profile
2525

26-
def plot_odf(odf, ax=None, color="blue", basis="tournier", alpha=1, linewidth=0.1, sphere=None):
26+
27+
def plot_odf(
28+
odf, ax=None, color="blue", basis="tournier", alpha=1, linewidth=0.1, sphere=None
29+
):
2730
"""Plot a spherical orientation distribution function represented by spherical harmonic coefficients.
2831
2932
Parameters
@@ -413,7 +416,10 @@ def match_odfs(true_odfs, est_odfs):
413416

414417
return matched_est_odfs, index_array
415418

416-
def pdf2odfs(mesh, sphere, amp_threshold=0.5, use_dipy=False, **kwargs):
419+
420+
def pdf2odfs(
421+
mesh, sphere, amp_threshold=0.5, use_dipy=False, lmax=6, max_num=None, **kwargs
422+
):
417423
"""
418424
Estimate ODFs from spherical PDF
419425
@@ -426,7 +432,11 @@ def pdf2odfs(mesh, sphere, amp_threshold=0.5, use_dipy=False, **kwargs):
426432
amp_threshold : float, optional
427433
Amplitude threshold for maxima, by default 0.5
428434
use_dipy : bool, optional
429-
Whether to use Dipy's peak finding, by default True.
435+
Whether to use Dipy's peak finding, by default False.
436+
lmax : int, optional
437+
Maximum spherical harmonic degree, by default 6
438+
max_num : int, optional
439+
Maximum number of peaks to find, by default None
430440
**kwargs : dict, optional
431441
Additional keyword arguments for peak finding if using Dipy.
432442
@@ -439,6 +449,8 @@ def pdf2odfs(mesh, sphere, amp_threshold=0.5, use_dipy=False, **kwargs):
439449
vol_fracs : NumPy array
440450
Estimated volume fractions
441451
"""
452+
m_list, l_list = sph_harm_ind_list(lmax)
453+
442454
if use_dipy == "nl":
443455
points = np.array([sphere.theta, sphere.phi]).T
444456
pdf_mesh_interp = CloughTocher2DInterpolator(points, mesh, fill_value=0)
@@ -449,13 +461,19 @@ def pdf2odfs(mesh, sphere, amp_threshold=0.5, use_dipy=False, **kwargs):
449461
pdf_mesh_interp_sphere,
450462
relative_peak_threshold=amp_threshold,
451463
sphere=sphere,
452-
**kwargs
464+
**kwargs,
453465
)
466+
if max_num is not None:
467+
# Keep only top max_num peaks
468+
if vals.shape[0] > max_num:
469+
top_indices = np.argsort(vals)[-max_num:]
470+
xyz = xyz[top_indices]
471+
vals = vals[top_indices]
472+
454473
vol_fracs = vals / np.sum(vals) # Normalize volume fractions
455474
r, theta, phi = cart2sphere(xyz[:, 0], xyz[:, 1], xyz[:, 2])
456475
dirs = np.array([theta, phi]).T
457476
odfs = []
458-
m_list, l_list = sph_harm_ind_list(6)
459477
for theta, phi, vol_frac in zip(dirs[:, 0], dirs[:, 1], vol_fracs):
460478
odfs.append(
461479
convert_sh_descoteaux_tournier(
@@ -473,11 +491,21 @@ def pdf2odfs(mesh, sphere, amp_threshold=0.5, use_dipy=False, **kwargs):
473491
**kwargs,
474492
)
475493

494+
if len(vals) == 0:
495+
return np.zeros((1, len(m_list))), np.zeros((1, 2)), np.zeros((1,))
496+
497+
if max_num is not None:
498+
# Keep only top max_num peaks
499+
if vals.shape[0] > max_num:
500+
top_indices = np.argsort(vals)[-max_num:]
501+
xyz = xyz[top_indices]
502+
vals = vals[top_indices]
503+
476504
vol_fracs = vals / np.sum(vals) # Normalize volume fractions
477505
r, theta, phi = cart2sphere(xyz[:, 0], xyz[:, 1], xyz[:, 2])
478506
dirs = np.array([theta, phi]).T
479507
odfs = []
480-
m_list, l_list = sph_harm_ind_list(6)
508+
m_list, l_list = sph_harm_ind_list(lmax)
481509
for theta, phi, vol_frac in zip(dirs[:, 0], dirs[:, 1], vol_fracs):
482510
odfs.append(
483511
convert_sh_descoteaux_tournier(
@@ -527,11 +555,18 @@ def pdf2odfs(mesh, sphere, amp_threshold=0.5, use_dipy=False, **kwargs):
527555
minima = minima[minima_vals > amp_threshold]
528556
minima_vals = minima_vals[minima_vals > amp_threshold]
529557

558+
if max_num is not None:
559+
# Keep only top max_num peaks
560+
if minima_vals.shape[0] > max_num:
561+
top_indices = np.argsort(minima_vals)[-max_num:]
562+
minima = minima[top_indices]
563+
minima_vals = minima_vals[top_indices]
564+
530565
# Estimate volume fraction using ratio of amplitude at minima
531566
minima_vals = minima_vals / np.sum(minima_vals)
532567

533568
# Now estimate ODF at these points with these volume fractions
534-
m_list, l_list = sph_harm_ind_list(6)
569+
m_list, l_list = sph_harm_ind_list(lmax)
535570
odfs = []
536571
for min, min_val in zip(minima, minima_vals):
537572
odfs.append(
@@ -570,6 +605,7 @@ def angular_separation(angle1, angle2):
570605
delta = np.arccos(cos_delta)
571606
return delta
572607

608+
573609
def load_fissile_mat(path):
574610
"""Load output from FISSILE and match estimated fibers to true fibers.
575611
@@ -671,6 +707,7 @@ def load_fissile_mat(path):
671707

672708
return data_dict
673709

710+
674711
def fiber_response(
675712
sphere,
676713
theta=0,
@@ -707,4 +744,4 @@ def fiber_response(
707744
response_amp = np.exp(-bval * lambda_perp) * np.exp(
708745
-3 * bval * (lambda_mean - lambda_perp) * (cos_theta**2)
709746
)
710-
return response_amp
747+
return response_amp

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)