Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions src/deep_neurographs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ class GraphConfig:
----------
anisotropy : list[float], optional
Scaling factors applied to xyz coordinates to account for anisotropy
of microscope. The default is [1.0, 1.0, 1.0].
of microscope. Note this instance of "anisotropy" is only used while
reading fragments (i.e. swcs). The default is [1.0, 1.0, 1.0].
complex_bool : bool
Indication of whether to generate complex proposals, meaning proposals
between leaf and non-leaf nodes. The default is False.
Expand Down Expand Up @@ -74,12 +75,15 @@ class MLConfig:

Attributes
----------
anisotropy : list[float], optional
Scaling factors applied to xyz coordinates to account for anisotropy
of microscope. Note this instance of "anisotropy" is only used while
generating features. The default is [1.0, 1.0, 1.0].
batch_size : int
The number of samples processed in one batch during training or
inference. Default is 1000.
downsample_factor : int
Downsampling factor that accounts for which level in the image pyramid
the voxel coordinates must index into. The default is 0.
multiscale : int
Level in the image pyramid that voxel coordinates must index into.
high_threshold : float
A threshold value used for classification, above which predictions are
considered to be high-confidence. Default is 0.9.
Expand All @@ -89,14 +93,14 @@ class MLConfig:
Type of machine learning model to use. Default is "GraphNeuralNet".

"""

anisotropy: List[float] = field(default_factory=list)
batch_size: int = 2000
downsample_factor: int = 1
high_threshold: float = 0.9
lr: float = 1e-3
threshold: float = 0.6
model_type: str = "GraphNeuralNet"
multiscale: int = 1
n_epochs: int = 1000
threshold: float = 0.6
validation_split: float = 0.15
weight_decay: float = 1e-3

Expand Down
25 changes: 14 additions & 11 deletions src/deep_neurographs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(
config,
device="cpu",
is_multimodal=False,
label_path=None,
labels_path=None,
log_runtimes=True,
save_to_s3_bool=False,
s3_dict=None,
Expand Down Expand Up @@ -98,7 +98,7 @@ def __init__(
...
is_multimodal : bool, optional
...
label_path : str, optional
labels_path : str, optional
Path to the segmentation assumed to be stored on a GCS bucket. The
default is None.
log_runtimes : bool, optional
Expand Down Expand Up @@ -132,11 +132,12 @@ def __init__(
self.model_path,
self.ml_config.model_type,
self.graph_config.search_radius,
anisotropy=self.ml_config.anisotropy,
batch_size=self.ml_config.batch_size,
confidence_threshold=self.ml_config.threshold,
device=device,
downsample_factor=self.ml_config.downsample_factor,
label_path=label_path,
multiscale=self.ml_config.multiscale,
labels_path=labels_path,
is_multimodal=is_multimodal,
)

Expand Down Expand Up @@ -474,11 +475,12 @@ def __init__(
model_path,
model_type,
radius,
anisotropy=[1.0, 1.0, 1.0],
batch_size=BATCH_SIZE,
confidence_threshold=CONFIDENCE_THRESHOLD,
device=None,
downsample_factor=1,
label_path=None,
multiscale=1,
labels_path=None,
is_multimodal=False
):
"""
Expand All @@ -501,9 +503,9 @@ def __init__(
confidence_threshold : float, optional
Threshold on acceptance probability for proposals. The default is
the global variable "CONFIDENCE_THRESHOLD".
downsample_factor : int, optional
Downsampling factor that accounts for which level in the image
pyramid the voxel coordinates must index into. The default is 0.
multiscale : int, optional
Level in the image pyramid that voxel coordinates must index into.
The default is 1.

Returns
-------
Expand All @@ -520,8 +522,9 @@ def __init__(
# Features
self.feature_generator = FeatureGenerator(
img_path,
downsample_factor,
label_path=label_path,
multiscale,
anisotropy=anisotropy,
labels_path=labels_path,
is_multimodal=is_multimodal
)

Expand Down
Loading
Loading