diff --git a/PialNN/pialnn/1.0.0/config.py b/PialNN/pialnn/1.0.0/config.py
new file mode 100644
index 00000000..fc9a4c83
--- /dev/null
+++ b/PialNN/pialnn/1.0.0/config.py
@@ -0,0 +1,33 @@
+import argparse
+
+def load_config():
+
+    # args
+    parser = argparse.ArgumentParser(description="PialNN")
+
+    # data
+    parser.add_argument('--data_path', default="./data/train/", type=str, help="path of the dataset")
+    parser.add_argument('--hemisphere', default="lh", type=str, help="left or right hemisphere (lh or rh)")
+    # model file
+    parser.add_argument('--model', help="path to best model")
+    # model
+    parser.add_argument('--nc', default=128, type=int, help="num of channels")
+    parser.add_argument('--K', default=5, type=int, help="kernel size")
+    parser.add_argument('--n_scale', default=3, type=int, help="num of scales for image pyramid")
+    parser.add_argument('--n_smooth', default=1, type=int, help="num of Laplacian smoothing layers")
+    parser.add_argument('--lambd', default=1.0, type=float, help="Laplacian smoothing weight")
+    # training
+    parser.add_argument('--train_data_ratio', default=0.8, type=float, help="fraction of data used for training")
+    parser.add_argument('--lr', default=1e-4, type=float, help="learning rate")
+    parser.add_argument('--n_epoch', default=200, type=int, help="total training epochs")
+    parser.add_argument('--ckpts_interval', default=10, type=int, help="save checkpoints every n epochs")
+    parser.add_argument('--report_training_loss', default=True, type=bool, help="whether to report training loss")
+    parser.add_argument('--save_model', default=True, type=bool, help="whether to save trained models")
+    parser.add_argument('--save_mesh_train', default=False, type=bool, help="whether to save meshes during training")
+    # evaluation
+    parser.add_argument('--save_mesh_eval', default=False, type=bool, help="whether to save meshes during evaluation")
+    parser.add_argument('--n_test_pts', default=150000, type=int, help="num of points sampled for evaluation")
+
+    config = parser.parse_args()
+
+    return config
diff --git a/PialNN/pialnn/1.0.0/data/dataload.py b/PialNN/pialnn/1.0.0/data/dataload.py
new file mode 100644
index 00000000..53c0cfc4
--- /dev/null
+++ b/PialNN/pialnn/1.0.0/data/dataload.py
@@ -0,0 +1,105 @@
+import os
+import numpy as np
+import torch
+from tqdm import tqdm
+import nibabel as nib
+from torch.utils.data import Dataset
+
+
+"""
+volume: brain MRI volume
+v_in: vertices of the input white matter surface
+f_in: faces of the input white matter surface
+v_gt: vertices of the ground truth pial surface
+f_gt: faces of the ground truth pial surface
+"""
+
+class BrainData():
+    def __init__(self, volume, v_in, v_gt, f_in, f_gt):
+        self.v_in = torch.Tensor(v_in)
+        self.v_gt = torch.Tensor(v_gt)
+        self.f_in = torch.LongTensor(f_in)
+        self.f_gt = torch.LongTensor(f_gt)
+        self.volume = torch.Tensor(volume).unsqueeze(0)
+
+
+class BrainDataset(Dataset):
+    def __init__(self, data):
+        self.data = data
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, i):
+        brain = self.data[i]
+        return brain.volume, brain.v_gt, \
+               brain.f_gt, brain.v_in, brain.f_in
+
+
+def load_mri(path):
+
+    brain = nib.load(path)
+    brain_arr = brain.get_fdata()
+    brain_arr = brain_arr / 255.
+ + # ====== change to your own transformation ====== + # transpose and clip the data to [192,224,192] + brain_arr = brain_arr.transpose(1,2,0) + brain_arr = brain_arr[::-1,:,:] + brain_arr = brain_arr[:,:,::-1] + brain_arr = brain_arr[32:-32, 16:-16, 32:-32] + #================================================ + + return brain_arr.copy() + + +def load_surf(path): + v, f = nib.freesurfer.io.read_geometry(path) + + # ====== change to your own transformation ====== + # transpose and clip the data to [192,224,192] + v = v[:,[0,2,1]] + v[:,0] = v[:,0] - 32 + v[:,1] = - v[:,1] - 15 + v[:,2] = v[:,2] - 32 + + # normalize to [-1, 1] + v = v + 128 + v = (v - [96, 112, 96]) / 112 + f = f.astype(np.int32) + #================================================ + + return v, f + + +def load_data(data_path, hemisphere): + """ + data path: path of dataset + """ + + subject_lists = sorted(os.listdir(data_path)) + + dataset = [] + + for i in tqdm(range(len(subject_lists))): + + subid = subject_lists[i] + + # load brain MRI + volume = load_mri(data_path + subid + '/mri/orig.mgz') + + # load ground truth pial surface + v_gt, f_gt = load_surf(data_path + subid + '/surf/' + hemisphere + '.pial') +# v_gt, f_gt = load_surf(data_path + subid + '/surf/' + hemisphere + '.pial.deformed') + + # load input white matter surface + v_in, f_in = load_surf(data_path + subid + '/surf/' + hemisphere + '.white') +# v_in, f_in = load_surf(data_path + subid + '/surf/' + hemisphere + '.white.deformed') + + braindata = BrainData(volume=volume, v_gt=v_gt, f_gt=f_gt, + v_in=v_in, f_in=f_in) + dataset.append(braindata) + + return dataset + + diff --git a/PialNN/pialnn/1.0.0/data/test/example/mri/orig.mgz b/PialNN/pialnn/1.0.0/data/test/example/mri/orig.mgz new file mode 100644 index 00000000..973eaadb Binary files /dev/null and b/PialNN/pialnn/1.0.0/data/test/example/mri/orig.mgz differ diff --git a/PialNN/pialnn/1.0.0/data/test/example/surf/lh.pial b/PialNN/pialnn/1.0.0/data/test/example/surf/lh.pial new file mode 100644 index 00000000..b6bbff39 Binary files /dev/null and b/PialNN/pialnn/1.0.0/data/test/example/surf/lh.pial differ diff --git a/PialNN/pialnn/1.0.0/data/test/example/surf/lh.white b/PialNN/pialnn/1.0.0/data/test/example/surf/lh.white new file mode 100644 index 00000000..2d5a8a2d Binary files /dev/null and b/PialNN/pialnn/1.0.0/data/test/example/surf/lh.white differ diff --git a/PialNN/pialnn/1.0.0/data/train/SUBJECT_NAME/mri/mri_data_here.txt b/PialNN/pialnn/1.0.0/data/train/SUBJECT_NAME/mri/mri_data_here.txt new file mode 100644 index 00000000..fa7cc9cb --- /dev/null +++ b/PialNN/pialnn/1.0.0/data/train/SUBJECT_NAME/mri/mri_data_here.txt @@ -0,0 +1 @@ +### \ No newline at end of file diff --git a/PialNN/pialnn/1.0.0/data/train/SUBJECT_NAME/surf/surface_data_here.txt b/PialNN/pialnn/1.0.0/data/train/SUBJECT_NAME/surf/surface_data_here.txt new file mode 100644 index 00000000..fa7cc9cb --- /dev/null +++ b/PialNN/pialnn/1.0.0/data/train/SUBJECT_NAME/surf/surface_data_here.txt @@ -0,0 +1 @@ +### \ No newline at end of file diff --git a/PialNN/pialnn/1.0.0/docker/Dockerfile b/PialNN/pialnn/1.0.0/docker/Dockerfile new file mode 100644 index 00000000..0b4df319 --- /dev/null +++ b/PialNN/pialnn/1.0.0/docker/Dockerfile @@ -0,0 +1,51 @@ +FROM nvidia/cuda:12.2.0-devel-ubuntu20.04 +ENV LANG=C.UTF-8 +ENV LC_ALL=C.UTF-8 +ENV PATH=/opt/miniconda3/bin:$PATH +ENV PYTHONDONTWRITEBYTECODE=1 +ENV PYTHONUNBUFFERED=1 +ENV PYTHONIOENCODING=UTF-8 +ENV PIPENV_VENV_IN_PROJECT=1 +ENV JCC_JDK=/usr/lib/jvm/java-8-openjdk-amd64 +RUN USE_CUDA=1 +RUN 
CUDA_VERSION=12.2.0 +RUN CUDNN_VERSION=8 +RUN LINUX_DISTRO=ubuntu +RUN DISTRO_VERSION=20.04 +RUN TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6" +RUN rm -f /etc/apt/apt.conf.d/docker-clean; \ +echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' \ +> /etc/apt/apt.conf.d/keep-cache +RUN apt-get update && DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get -y install tzdata && apt-get install -y --no-install-recommends \ +build-essential \ +ca-certificates \ +ccache \ +curl \ +git \ +wget \ +cmake \ +gfortran \ +libspatialindex-dev +RUN rm -rf /var/lib/apt/lists/* +ENV PYTHON_VERSION=3.7 +ENV CONDA_URL=https://repo.anaconda.com/miniconda/Miniconda3-py37_4.10.3-Linux-x86_64.sh +RUN curl -fsSL -v -o ~/miniconda.sh -O ${CONDA_URL} && \ +chmod +x ~/miniconda.sh && \ +~/miniconda.sh -b -p /opt/miniconda3 + +WORKDIR /app +COPY pialnn.requirements.txt . +COPY environment.yml . +RUN conda env update -f environment.yml +SHELL ["conda", "run", "-n", "base", "/bin/bash", "-c"] +#RUN pip cache purge +RUN pip install torch-scatter==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.7.0.html +RUN pip install torch-sparse==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.7.0.html +RUN pip install torch-cluster==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.7.0.html +RUN pip install torch-spline-conv==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.7.0.html +RUN pip install torch-geometric +RUN pip install rtree +RUN pip install -r pialnn.requirements.txt +RUN conda install -c conda-forge libspatialindex=1.9.3 +RUN conda clean -a +ENTRYPOINT ["/bin/bash", "-l", "-c"] diff --git a/PialNN/pialnn/1.0.0/docker/environment.yml b/PialNN/pialnn/1.0.0/docker/environment.yml new file mode 100644 index 00000000..322a1440 --- /dev/null +++ b/PialNN/pialnn/1.0.0/docker/environment.yml @@ -0,0 +1,93 @@ +name: base +channels: + - pytorch3d + - pytorch + - bottler + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - blas=1.0=mkl + - ca-certificates=2022.9.24=ha878542_0 + - certifi=2022.9.24=pyhd8ed1ab_0 + - colorama=0.4.6=pyhd8ed1ab_0 + - cudatoolkit=10.2.89=hfd86e86_1 + - dataclasses=0.8=pyhc8e2a94_3 + - freetype=2.12.1=h4a9f257_0 + - fvcore=0.1.5.post20220512=pyhd8ed1ab_0 + - giflib=5.2.1=h7b6447c_0 + - intel-openmp=2021.4.0=h06a4308_3561 + - iopath=0.1.9=pyhd8ed1ab_0 + - jpeg=9e=h7f8727e_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=3.0=h295c915_0 + - libdeflate=1.8=h7f8727e_5 + - libffi=3.3=he6710b0_2 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libpng=1.6.37=hbc83047_0 + - libspatialindex=1.9.3=h9c3ff4c_4 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtiff=4.4.0=hecacb30_0 + - libuv=1.40.0=h7b6447c_0 + - libwebp=1.2.4=h11a3e52_0 + - libwebp-base=1.2.4=h5eee18b_0 + - lz4-c=1.9.3=h295c915_1 + - mkl=2021.4.0=h06a4308_640 + - mkl-service=2.4.0=py37h7f8727e_0 + - mkl_fft=1.3.1=py37hd3c417c_0 + - mkl_random=1.2.2=py37h51133e4_0 + - ncurses=6.3=h5eee18b_3 + - ninja=1.10.2=h06a4308_5 + - ninja-base=1.10.2=hd09550d_5 + - numpy=1.21.5=py37h6c91a56_3 + - numpy-base=1.21.5=py37ha15fc14_3 + - nvidiacub=1.10.0=0 + - openssl=1.1.1s=h7f8727e_0 + - pillow=9.2.0=py37hace64e9_1 + - portalocker=2.6.0=py37h89c1867_0 + - python=3.7.13=haa1d7c7_1 + - python_abi=3.7=2_cp37m + - pytorch=1.7.0=py3.7_cuda10.2.89_cudnn7.6.5_0 + - pytorch3d=0.6.2=py37_cu102_pyt170 + - pyyaml=6.0=py37h540881e_4 + - readline=8.2=h5eee18b_0 + - setuptools=65.4.0=py37h06a4308_0 + - six=1.16.0=pyhd3eb1b0_1 + - sqlite=3.39.3=h5082296_0 + - 
tabulate=0.9.0=pyhd8ed1ab_1 + - termcolor=2.0.1=pyhd8ed1ab_1 + - tk=8.6.12=h1ccaba5_0 + - torchvision=0.8.1=py37_cu102 + - tqdm=4.64.1=pyhd8ed1ab_0 + - typing_extensions=4.3.0=py37h06a4308_0 + - wheel=0.37.1=pyhd3eb1b0_0 + - xz=5.2.6=h5eee18b_0 + - yacs=0.1.8=pyhd8ed1ab_0 + - yaml=0.2.5=h7f98852_2 + - zlib=1.2.13=h5eee18b_0 + - zstd=1.5.2=ha4553b6_0 + - pip: + - charset-normalizer==2.1.1 + - idna==3.4 + - jinja2==3.1.2 + - joblib==1.2.0 + - markupsafe==2.1.1 + - nvidia-ml-py3==7.352.0 + - packaging==21.3 + - pip==18.0 + - pyparsing==3.0.9 + - requests==2.28.1 + - rtree==1.0.1 + - scikit-learn==1.0.2 + - scipy==1.7.3 + - threadpoolctl==3.1.0 + - torch-cluster==1.5.8 + - torch-geometric==2.1.0.post1 + - torch-scatter==2.0.5 + - torch-sparse==0.6.8 + - torch-spline-conv==1.2.0 + - trimesh==3.15.8 + - urllib3==1.26.12 diff --git a/PialNN/pialnn/1.0.0/docker/pialnn.requirements.txt b/PialNN/pialnn/1.0.0/docker/pialnn.requirements.txt new file mode 100644 index 00000000..e6cb514b --- /dev/null +++ b/PialNN/pialnn/1.0.0/docker/pialnn.requirements.txt @@ -0,0 +1,47 @@ +certifi==2022.9.24 +charset-normalizer==2.1.1 +colorama==0.4.6 +dataclasses==0.8 +freesurfer-surface==2.0.0 +freetype-py==2.3.0 +fvcore==0.1.5.post20220512 +idna==3.3 +iopath==0.1.9 +Jinja2==3.1.2 +joblib==1.2.0 +MarkupSafe==2.1.1 +mkl-fft==1.3.1 +mkl-random==1.2.2 +mkl-service==2.4.0 +nibabel==3.2.1 +nilearn==0.8.1 +numpy==1.21.5 +nvidia-ml-py3==7.352.0 +packaging==21.3 +Pillow==9.2.0 +portalocker==2.6.0 +PyOpenGL==3.1.0 +pyparsing==3.0.9 +pyrender==0.1.45 +pytorch3d==0.6.2 +PyYAML==6.0 +requests==2.28.1 +Rtree==1.0.1 +scikit-learn==1.0.2 +scipy==1.7.3 +six==1.16.0 +tabulate==0.9.0 +termcolor==2.0.1 +threadpoolctl==3.1.0 +torch==1.7.0 +torch-cluster==1.5.8 +torch-geometric==2.1.0.post1 +torch-scatter==2.0.5 +torch-sparse==0.6.8 +torch-spline-conv==1.2.0 +torchvision==0.8.1 +tqdm==4.64.1 +trimesh==3.15.8 +typing-extensions==4.3.0 +urllib3==1.26.12 +yacs==0.1.8 diff --git a/PialNN/pialnn/1.0.0/example-data/sample.zip b/PialNN/pialnn/1.0.0/example-data/sample.zip new file mode 120000 index 00000000..0b17fd67 --- /dev/null +++ b/PialNN/pialnn/1.0.0/example-data/sample.zip @@ -0,0 +1 @@ +../../../../.git/annex/objects/XG/kw/MD5E-s9777720--5d02bc50297786ffde23197fe0db6e1b.zip/MD5E-s9777720--5d02bc50297786ffde23197fe0db6e1b.zip \ No newline at end of file diff --git a/PialNN/pialnn/1.0.0/model/pialnn.py b/PialNN/pialnn/1.0.0/model/pialnn.py new file mode 100644 index 00000000..505e8bd8 --- /dev/null +++ b/PialNN/pialnn/1.0.0/model/pialnn.py @@ -0,0 +1,201 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from utils import compute_normal + + +""" +Deformation Block +nc: number of channels +K: kernal size for local conv operation +n_scale: num of layers of image pyramid +""" + +class DeformBlock(nn.Module): + def __init__(self, nc=128, K=5, n_scale=3): + super(DeformBlock, self).__init__() + + # mlp layers + self.fc1 = nn.Linear(6, nc) + self.fc2 = nn.Linear(nc*2, nc*4) + self.fc3 = nn.Linear(nc*4, nc*2) + self.fc4 = nn.Linear(nc*2, 3) + + # for local convolution operation + self.localconv = nn.Conv3d(n_scale, nc, (K, K, K)) + self.localfc = nn.Linear(nc, nc) + + self.n_scale = n_scale + self.nc = nc + self.K = K + + def forward(self, v, f, volume): + + coord = v.clone() + normal = compute_normal(v, f) # compute normal + + # point feature + x = torch.cat([v, normal], 2) + x = F.leaky_relu(self.fc1(x), 0.15) + + # local feature + cubes = self.cube_sampling(v, volume) # extract K^3 cubes + x_local 
= self.localconv(cubes) + x_local = x_local.view(1, v.shape[1], self.nc) + x_local = self.localfc(x_local) + + # fusion + x = torch.cat([x, x_local], 2) + x = F.leaky_relu(self.fc2(x), 0.15) + x = F.leaky_relu(self.fc3(x), 0.15) + x = torch.tanh(self.fc4(x)) * 0.1 # threshold the displacement + + return coord + x # v=v+dv + + def initialize(self, L, W, H, device=None): + """initialize necessary constants""" + + LWHmax = max([L,W,H]) + self.LWHmax = LWHmax + # rescale to [-1, 1] + self.rescale = torch.Tensor([L/LWHmax, W/LWHmax, H/LWHmax]).to(device) + + # shape of mulit-scale image pyramid + self.pyramid_shape = torch.zeros([self.n_scale, 3]).to(device) + for i in range(self.n_scale): + self.pyramid_shape[i] = torch.Tensor([L/(2**i), + W/(2**i), + H/(2**i)]).to(device) + # for threshold + self.lower_bound = torch.tensor([(self.K-1)//2, + (self.K-1)//2, + (self.K-1)//2]).to(device) + # for storage of sampled cubes + self.cubes_holder = torch.zeros([1, self.n_scale, + self.K, self.K, self.K]).to(device) + + def cube_sampling(self, v, volume): + + # for storage of sampled cubes + cubes = self.cubes_holder.repeat(v.shape[1],1,1,1,1) + + # 3D MRI volume + vol_ = volume.clone() + for n in range(self.n_scale): # multi scales + if n > 0: + vol_ = F.avg_pool3d(vol_, 2) # down sampling + vol = vol_[0,0] + + # find corresponding position + indices = (v[0] + self.rescale) * self.LWHmax / (2**(n+1)) + indices = torch.round(indices).long() + indices = torch.max(torch.min(indices, self.pyramid_shape[n]-3), + self.lower_bound).long() + + # sample values of each cube + for i in [-2,-1,0,1,2]: + for j in [-2,-1,0,1,2]: + for k in [-2,-1,0,1,2]: + cubes[:,n,2+i,2+j,2+k] = vol[indices[:,2]+i, + indices[:,1]+j, + indices[:,0]+k] + return cubes + + + +""" +PialNN with 3 deformation blocks + 1 Laplacian smoothing layer +""" + +class PialNN(nn.Module): + def __init__(self, nc=128, K=5, n_scale=3): + super(PialNN, self).__init__() + self.block1 = DeformBlock(nc, K, n_scale) + self.block2 = DeformBlock(nc, K, n_scale) + self.block3 = DeformBlock(nc, K, n_scale) + self.smooth = LaplacianSmooth(3, 3, aggr='mean') + + def forward(self, v, f, volume, n_smooth=1, lambd=1.0): + + x = self.block1(v, f, volume) + x = self.block2(x, f, volume) + x = self.block3(x, f, volume) + edge_list = torch.cat([f[0,:,[0,1]], + f[0,:,[1,2]], + f[0,:,[2,0]]], dim=0).transpose(1,0) + + for i in range(n_smooth): + x = self.smooth(x, edge_list, lambd=lambd) + + return x + + def initialize(self, L=256, W=256, H=256, device=None): + self.block1.initialize(L,W,H,device) + self.block2.initialize(L,W,H,device) + self.block3.initialize(L,W,H,device) + + + +""" +LaplacianSmooth() is a differentiable Laplacian smoothing layer. +The code is implemented based on the torch_geometric.nn.conv.GraphConv. +For original GraphConv implementation, please see +https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/nn/conv/graph_conv.py + + +x: the coordinates of the vertices, (|V|, 3). +edge_index: the list of edges, (2, |E|), e.g. [[0,1],[1,3],...]. +lambd: weight for Laplacian smoothing, between [0,1]. +out: the smoothed vertices, (|V|, 3). 
+""" + +from typing import Union, Tuple +from torch_geometric.typing import OptTensor, OptPairTensor, Adj, Size +from torch import Tensor +from torch_sparse import SparseTensor, matmul +from torch_geometric.nn.conv import MessagePassing + + +class LaplacianSmooth(MessagePassing): + + def __init__(self, in_channels: Union[int, Tuple[int, + int]], out_channels: int, + aggr: str = 'add', bias: bool = True, **kwargs): + super(LaplacianSmooth, self).__init__(aggr=aggr, **kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + + + def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, + edge_weight: OptTensor = None, size: Size = None, lambd=0.5) -> Tensor: + + if isinstance(x, Tensor): + x: OptPairTensor = (x, x) + + # propagate_type: (x: OptPairTensor, edge_weight: OptTensor) + out = self.propagate(edge_index, x=x, edge_weight=edge_weight, + size=size) + out = lambd * out + x_r = x[1] + if x_r is not None: + out += (1-lambd) * x_r + + return out + + + def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor: + return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j + + def message_and_aggregate(self, adj_t: SparseTensor, + x: OptPairTensor) -> Tensor: + return matmul(adj_t, x[0], reduce=self.aggr) + + def __repr__(self): + return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, + self.out_channels) diff --git a/PialNN/pialnn/1.0.0/model_card.yaml b/PialNN/pialnn/1.0.0/model_card.yaml new file mode 100644 index 00000000..c70597d8 --- /dev/null +++ b/PialNN/pialnn/1.0.0/model_card.yaml @@ -0,0 +1,33 @@ +Model_details: + Organization: DeepCSR + Model_date: t + Model_version: 1.0 + Model_type: t + More_information: t + Citation_details: t + Contact_info: t +Intended_use: + Primary_intended_uses: t + Primary_intended_users: t + Out_of_scope_use_cases: t +Factors: + Relevant_factors: t + Evaluation_factors: t + Model_performance_measures: t +Metrics: + Model Performance Measures: t + Decision Thresholds: t + Variation Approaches: t +Evaluation Data: + Datasets: t + Motivation: t + Preprocessing: t +Training Data: + Datasets: t + Motivation: t + Preprocessing: t +Quantitative Analyses: + Unitary Results: t + Intersectional Results: t +Ethical Considerations: t +Caveats and Recommendations: t diff --git a/PialNN/pialnn/1.0.0/predict.py b/PialNN/pialnn/1.0.0/predict.py new file mode 100644 index 00000000..6b34b99c --- /dev/null +++ b/PialNN/pialnn/1.0.0/predict.py @@ -0,0 +1,106 @@ +import numpy as np +from tqdm import tqdm + +import torch +from torch.utils.data import DataLoader + +from config import load_config +from data.dataload import load_data, BrainDataset +from model.pialnn import PialNN +from utils import compute_normal, save_mesh_obj, compute_distance + + +if __name__ == '__main__': + + """set device""" + if torch.cuda.is_available(): + device_name = "cuda:0" + print('selected gpu') + else: + device_name = "cpu" + device = torch.device(device_name) + + + """load configuration""" + config = load_config() + + """load dataset""" + print("----------------------------") + print("Start loading dataset ...") + test_data = load_data(data_path = config.data_path, + hemisphere = config.hemisphere) + n_data = len(test_data) + L,W,H = test_data[0].volume[0].shape # shape of MRI + LWHmax = max([L,W,H]) + + test_set = BrainDataset(test_data) + testloader = DataLoader(test_set, batch_size=1, shuffle=True) + print("Finish loading dataset. 
There are total {} subjects.".format(n_data)) + print("----------------------------") + + + """load model""" + print("Start loading model ...") + model = PialNN(config.nc, config.K, config.n_scale).to(device) + model.load_state_dict(torch.load(config.model, + map_location=device)) + model.initialize(L, W, H, device) + print("Finish loading model") + print("----------------------------") + + + """evaluation""" + print("Start evaluation ...") + with torch.no_grad(): + #CD = [] + #AD = [] + #HD = [] + for idx, data in tqdm(enumerate(testloader)): + volume_in, v_gt, f_gt, v_in, f_in = data + + sub_id = idx + + volume_in = volume_in.to(device) + v_gt = v_gt.to(device) + f_gt = f_gt.to(device) + v_in = v_in.to(device) + f_in = f_in.to(device) + + # set n_smooth > 1 if the mesh quality is not good + v_pred = model(v=v_in, f=f_in, volume=volume_in, + n_smooth=config.n_smooth, lambd=config.lambd) + + v_pred_eval = v_pred[0].cpu().numpy() * LWHmax/2 + [L/2,W/2,H/2] + f_pred_eval = f_in[0].cpu().numpy() + v_gt_eval = v_gt[0].cpu().numpy() * LWHmax/2 + [L/2,W/2,H/2] + f_gt_eval = f_gt[0].cpu().numpy() + + # compute distance-based metrics + #cd, assd, hd = compute_distance(v_pred_eval, v_gt_eval, + # f_pred_eval, f_gt_eval, config.n_test_pts) + + #CD.append(cd) + #AD.append(assd) + #HD.append(hd) + print('sub_id',sub_id) + if config.save_mesh_eval: + path_save_mesh = "./pialnn_mesh_eval_"\ + +config.hemisphere+"_subject_"+str(sub_id)+".obj" + + normal = compute_normal(v_pred, f_in) + n_pred_eval = normal[0].cpu().numpy() + save_mesh_obj(v_pred_eval, f_pred_eval, n_pred_eval, path_save_mesh) + + ################ + path_save_mesh = "./pialnn_mesh_eval_"\ + +config.hemisphere+"_subject_"+str(sub_id)+"_gt.obj" + + normal = compute_normal(v_gt, f_gt) + n_gt_eval = normal[0].cpu().numpy() + save_mesh_obj(v_gt_eval, f_gt_eval, n_gt_eval, path_save_mesh) + + # print("CD: Mean={}, Std={}".format(np.mean(CD), np.std(CD))) + # print("AD: Mean={}, Std={}".format(np.mean(AD), np.std(AD))) + # print("HD: Mean={}, Std={}".format(np.mean(HD), np.std(HD))) + print("Finish evaluation.") + print("----------------------------") diff --git a/PialNN/pialnn/1.0.0/spec.yaml b/PialNN/pialnn/1.0.0/spec.yaml new file mode 100644 index 00000000..4e53730a --- /dev/null +++ b/PialNN/pialnn/1.0.0/spec.yaml @@ -0,0 +1,47 @@ +image: + docker: neuronets/deepcsr + singularity: nobrainer-zoo_deepcsr.sif +repository: + repo_url: None + committish: None + repo_download: 'False' + repo_download_location: None +inference: + prediction_script: trained-models/DeepCSR/deepcsr/1.0/predict.py + command: f"python {MODELS_PATH}/{model}/predict.py --conf_path {conf} --model_checkpoint + {infile[0]} --dataset {infile[1]}" + data_spec: + infile: + n_files: 1 + outfile: + n_files: 1 +training_data_info: + data_number: + total: 1 + train: 1 + evaluate: 1 + test: 1 + biological_sex: + male: null + female: null + age_histogram: '1' + race: '1' + imaging_contrast_info: '1' + dataset_sources: '1' + data_sites: + number_of_sites: 1 + sites: '1' + scanner_models: '1' + hardware: '1' + training_parameters: + input_shape: '1' + block_shape: '1' + n_classes: 1 + lr: '1' + n_epochs: 1 + total_batch_size: 1 + number_of_gpus: 1 + loss_function: '1' + metrics: '1' + data_preprocessing: '1' + data_augmentation: '1' diff --git a/PialNN/pialnn/1.0.0/utils.py b/PialNN/pialnn/1.0.0/utils.py new file mode 100644 index 00000000..90a83d10 --- /dev/null +++ b/PialNN/pialnn/1.0.0/utils.py @@ -0,0 +1,61 @@ +import numpy as np +import pytorch3d +from pytorch3d.structures import 
Meshes +import trimesh +from trimesh.exchange.obj import export_obj +from scipy.spatial import cKDTree + + +def compute_normal(v, f): + """v, f: Tensors""" + normal = Meshes(verts=list(v), + faces=list(f)).verts_normals_list()[0] + return normal.unsqueeze(0) + + +def save_mesh_obj(v, f, n, path): + mesh_save = trimesh.Trimesh(vertices=v, + faces=f, + vertex_normals=n) + obj_save = export_obj(mesh_save, include_normals=True) + with open(path, "w") as file: + print(obj_save, file=file) + + +def compute_distance(v_pred, v_gt, f_pred, f_gt, n_samples=150000): + """ + The results are evaluated based on three distances: + 1. Chamfer Distance (CD) + 2. Average Absolute Distance (AD) + 3. Hausdorff Distance (HD) + + Please see DeepCSR paper in details: + https://arxiv.org/abs/2010.11423 + + For original code, please see: + https://bitbucket.csiro.au/projects/CRCPMAX/repos/deepcsr/browse/eval.py + """ + + # chamfer distance + cd = 0 + kdtree = cKDTree(v_pred) + cd += kdtree.query(v_gt)[0].mean()/2 + kdtree = cKDTree(v_gt) + cd += kdtree.query(v_pred)[0].mean()/2 + + # AD & HD + mesh_pred = trimesh.Trimesh(vertices=v_pred, faces=f_pred) + pts_pred = mesh_pred.sample(n_samples) + mesh_gt = trimesh.Trimesh(vertices=v_gt, faces=f_gt) + pts_gt = mesh_gt.sample(n_samples) + + _, P2G_dist, _ = trimesh.proximity.closest_point(mesh_pred, pts_gt) + _, G2P_dist, _ = trimesh.proximity.closest_point(mesh_gt, pts_pred) + + # average absolute distance + assd = ((P2G_dist.sum() + G2P_dist.sum()) / float(P2G_dist.size + G2P_dist.size)) + + # Hausdorff distance + hd = max(np.percentile(P2G_dist, 90), np.percentile(G2P_dist, 90)) + + return cd, assd, hd \ No newline at end of file diff --git a/PialNN/pialnn/1.0.0/weights/pialnn_model_lh_200epochs.pt b/PialNN/pialnn/1.0.0/weights/pialnn_model_lh_200epochs.pt new file mode 120000 index 00000000..24fbebb2 --- /dev/null +++ b/PialNN/pialnn/1.0.0/weights/pialnn_model_lh_200epochs.pt @@ -0,0 +1 @@ +../../../../.git/annex/objects/Gk/wQ/MD5E-s3963107--e00fa887f4a16a3e726a614bfa62202d.pt/MD5E-s3963107--e00fa887f4a16a3e726a614bfa62202d.pt \ No newline at end of file
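
Note: the supported entry point for inference is predict.py together with the flags defined in config.py, e.g. `python predict.py --data_path ./data/test/ --hemisphere lh --model ./weights/pialnn_model_lh_200epochs.pt`. The snippet below is only an illustrative, minimal reduction of that script: a sketch of how dataload.py, pialnn.py and utils.py fit together, assuming it is run from the 1.0.0 model directory with the bundled example subject under data/test/ and the left-hemisphere checkpoint under weights/; the output filename is a placeholder.

# Minimal inference sketch (illustration only; see predict.py for the full script).
import torch

from data.dataload import load_data
from model.pialnn import PialNN
from utils import compute_normal, save_mesh_obj

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# load_data expects <data_path>/<subject>/mri/orig.mgz and <data_path>/<subject>/surf/lh.{white,pial}
data = load_data(data_path="./data/test/", hemisphere="lh")
L, W, H = data[0].volume[0].shape  # MRI shape after the cropping in load_mri, e.g. (192, 224, 192)

model = PialNN(nc=128, K=5, n_scale=3).to(device)
model.load_state_dict(torch.load("./weights/pialnn_model_lh_200epochs.pt", map_location=device))
model.initialize(L, W, H, device)  # precompute the rescaling constants and cube holders
model.eval()

brain = data[0]
with torch.no_grad():
    # the model expects batched inputs: v (1,N,3), f (1,F,3), volume (1,1,L,W,H)
    v_pred = model(v=brain.v_in.unsqueeze(0).to(device),
                   f=brain.f_in.unsqueeze(0).to(device),
                   volume=brain.volume.unsqueeze(0).to(device),
                   n_smooth=1, lambd=1.0)

# map vertices from the normalized [-1, 1] space back to voxel coordinates, as in predict.py
v_out = v_pred[0].cpu().numpy() * max(L, W, H) / 2 + [L / 2, W / 2, H / 2]
n_out = compute_normal(v_pred, brain.f_in.unsqueeze(0).to(device))[0].cpu().numpy()
save_mesh_obj(v_out, brain.f_in.numpy(), n_out, "lh_pial_pred.obj")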