From 948ffac0cc8f5d7c0b78ea937545d85cba197463 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Mon, 23 Dec 2024 21:38:55 +0000 Subject: [PATCH 1/5] refactor: anisotropy not hardcoded --- src/deep_neurographs/fragments_graph.py | 33 ++++++---- src/deep_neurographs/inference.py | 16 ++--- .../machine_learning/feature_generation.py | 62 ++++++++++--------- .../machine_learning/heterograph_datasets.py | 10 ++- src/deep_neurographs/utils/graph_util.py | 4 +- src/deep_neurographs/utils/img_util.py | 27 ++++---- src/deep_neurographs/utils/swc_util.py | 5 +- 7 files changed, 87 insertions(+), 70 deletions(-) diff --git a/src/deep_neurographs/fragments_graph.py b/src/deep_neurographs/fragments_graph.py index 28e58e52..4305a0b7 100644 --- a/src/deep_neurographs/fragments_graph.py +++ b/src/deep_neurographs/fragments_graph.py @@ -4,9 +4,9 @@ @author: Anna Grim @email: anna.grim@alleninstitute.org -Implementation of subclass of Networkx.Graph called "FragmentsGraph". - -NOTE: SAVE LABEL UPDATES --- THERE IS A BUG IN FEATURE GENERATION +Implementation of subclass of Networkx.Graph called "FragmentsGraph" which is +a graph that is initialized by loading swc files (i.e. fragments) from a +predicted segmentation. """ import zipfile @@ -30,17 +30,23 @@ class FragmentsGraph(nx.Graph): """ - def __init__(self, img_bbox=None, node_spacing=1): + def __init__( + self, anisotropy=[1.0, 1.0, 1.0], img_bbox=None, node_spacing=1 + ): """ Initializes an instance of NeuroGraph. Parameters ---------- + anisotropy : ArrayLike, optional + Image to physical coordinates scaling factors to account for the + anisotropy of the microscope. The default is [1.0, 1.0, 1.0]. img_bbox : dict or None, optional Dictionary with the keys "min" and "max" which specify a bounding box in an image. The default is None. node_spacing : int, optional - Spacing (in microns) between nodes. The default is 1. + Physical spacing (in microns) between nodes in swcs. The default + is 1. 
Returns ------- @@ -49,6 +55,7 @@ def __init__(self, img_bbox=None, node_spacing=1): """ super(FragmentsGraph, self).__init__() # General class attributes + self.anisotropy = anisotropy self.leaf_kdtree = None self.node_cnt = 0 self.node_spacing = node_spacing @@ -908,8 +915,8 @@ def oriented_edge(self, edge, i, key="xyz"): def is_contained(self, node_or_xyz, buffer=0): if self.bbox: - coord = self.to_voxels(node_or_xyz) - return util.is_contained(self.bbox, coord, buffer=buffer) + voxel = self.to_voxels(node_or_xyz, self.anisotropy) + return util.is_contained(self.bbox, voxel, buffer=buffer) else: return True @@ -921,13 +928,17 @@ def branch_contained(self, xyz_list): else: return True - def to_voxels(self, node_or_xyz, shift=False): + def to_voxels(self, node_or_xyz, shift=np.array([0, 0, 0])): + # Get xyz coordinate shift = self.origin if shift else np.zeros((3)) if type(node_or_xyz) is int: - coord = img_util.to_voxels(self.nodes[node_or_xyz]["xyz"]) + xyz = self.nodes[node_or_xyz]["xyz"] else: - coord = img_util.to_voxels(node_or_xyz) - return coord - shift + xyz = node_or_xyz + + # Coordinate conversion + voxel = img_util.to_voxels(xyz, self.anisotropy) + return voxel - shift def is_leaf(self, i): """ diff --git a/src/deep_neurographs/inference.py b/src/deep_neurographs/inference.py index 280c9c27..243d2dfd 100644 --- a/src/deep_neurographs/inference.py +++ b/src/deep_neurographs/inference.py @@ -164,7 +164,7 @@ def run(self, fragments_pointer): """ # Initializations - self.report_experiment() + self.log_experiment() self.write_metadata() t0 = time() @@ -181,22 +181,16 @@ def run(self, fragments_pointer): t, unit = util.time_writer(time() - t0) self.report(f"Total Runtime: {round(t, 4)} {unit}\n") - def run_schedule( - self, fragments_pointer, radius_schedule, save_all_rounds=False - ): + def run_schedule(self, fragments_pointer, radius_schedule): t0 = time() - self.report_experiment() + self.log_experiment() self.build_graph(fragments_pointer) for round_id, 
radius in enumerate(radius_schedule): self.report(f"--- Round {round_id + 1}: Radius = {radius} ---") round_id += 1 self.generate_proposals(radius) self.run_inference() - if save_all_rounds: - self.save_results(round_id=round_id) - - if not save_all_rounds: - self.save_results(round_id=round_id) + self.save_results(round_id=round_id) t, unit = util.time_writer(time() - t0) self.report(f"Total Runtime: {round(t, 4)} {unit}\n") @@ -433,7 +427,7 @@ def report(self, txt): self.log_handle.write(txt) self.log_handle.write("\n") - def report_experiment(self): + def log_experiment(self): self.report("\nExperiment Overview") self.report("-------------------------------------------------------") self.report(f"Sample_ID: {self.sample_id}") diff --git a/src/deep_neurographs/machine_learning/feature_generation.py b/src/deep_neurographs/machine_learning/feature_generation.py index dfcd4657..55e777e6 100644 --- a/src/deep_neurographs/machine_learning/feature_generation.py +++ b/src/deep_neurographs/machine_learning/feature_generation.py @@ -36,7 +36,8 @@ class FeatureGenerator: def __init__( self, img_path, - downsample_factor, + multiscale, + anisotropy=[1.0, 1.0, 1.0], label_path=None, is_multimodal=False, ): @@ -47,9 +48,11 @@ def __init__( ---------- img_path : str Path to the raw image assumed to be stored in a GCS bucket. - downsample_factor : int - Downsampling factor that accounts for which level in the image - pyramid the voxel coordinates must index into. + multiscale : int + Level in the image pyramid that voxel coordinates must index into. + anisotropy : ArrayLike, optional + Image to physical coordinates scaling factors to account for the + anisotropy of the microscope. The default is [1.0, 1.0, 1.0]. label_path : str, optional Path to the segmentation assumed to be stored on a GCS bucket. The default is None. 
@@ -62,11 +65,12 @@ def __init__( None """ - # Initialize instance attributes - self.downsample_factor = downsample_factor + # General instance attributes + self.anisotropy = anisotropy + self.multiscale = multiscale self.is_multimodal = is_multimodal - # Initialize image-based attributes + # Open images driver = "n5" if ".n5" in img_path else "zarr" self.img = img_util.open_tensorstore(img_path, driver=driver) if label_path: @@ -75,7 +79,7 @@ def __init__( self.labels = None # Set chunk shapes - self.img_patch_shape = self.set_patch_shape(downsample_factor) + self.img_patch_shape = self.set_patch_shape(multiscale) self.label_patch_shape = self.set_patch_shape(0) # Validate embedding requirements @@ -83,16 +87,14 @@ def __init__( raise("Must provide labels to generate image embeddings") @classmethod - def set_patch_shape(cls, downsample_factor): + def set_patch_shape(cls, multiscale): """ Adjusts the chunk shape by downsampling each dimension by a specified factor. Parameters ---------- - downsample_factor : int - The factor by which to downsample each dimension of the current - chunk shape. + None Returns ------- @@ -101,7 +103,7 @@ def set_patch_shape(cls, downsample_factor): factor. """ - return [s // 2 ** downsample_factor for s in cls.patch_shape] + return [s // 2 ** multiscale for s in cls.patch_shape] @classmethod def get_n_profile_points(cls): @@ -114,7 +116,7 @@ def run(self, neurograph, proposals_dict, radius): Parameters ---------- - neurograph : NeuroGraph + neurograph : FragmentsGraph Graph that "proposals" belong to. proposals_dict : dict Dictionary that contains the items (1) "proposals" which are the @@ -154,7 +156,7 @@ def run_on_nodes(self, neurograph, computation_graph): Parameters ---------- - neurograph : NeuroGraph + neurograph : FragmentsGraph NeuroGraph generated from a predicted segmentation. computation_graph : networkx.Graph Graph used by GNN to classify proposals. 
@@ -173,7 +175,7 @@ def run_on_branches(self, neurograph, computation_graph): Parameters ---------- - neurograph : NeuroGraph + neurograph : FragmentsGraph NeuroGraph generated from a predicted segmentation. computation_graph : networkx.Graph Graph used by GNN to classify proposals. @@ -192,7 +194,7 @@ def run_on_proposals(self, neurograph, proposals, radius): Parameters ---------- - neurograph : NeuroGraph + neurograph : FragmentsGraph NeuroGraph generated from a predicted segmentation. proposals : list[frozenset] List of proposals for which features will be generated. @@ -219,7 +221,7 @@ def node_skeletal(self, neurograph, computation_graph): Parameters ---------- - neurograph : NeuroGraph + neurograph : FragmentsGraph NeuroGraph generated from a predicted segmentation. computation_graph : networkx.Graph Graph used by GNN to classify proposals. @@ -248,8 +250,8 @@ def branch_skeletal(self, neurograph, computation_graph): Parameters ---------- - neurograph : NeuroGraph - NeuroGraph generated from a predicted segmentation. + neurograph : FragmentsGraph + Fragments graph that features are to be generated from. computation_graph : networkx.Graph Graph used by GNN to classify proposals. @@ -275,7 +277,7 @@ def proposal_skeletal(self, neurograph, proposals, radius): Parameters ---------- - neurograph : NeuroGraph + neurograph : FragmentsGraph NeuroGraph generated from a predicted segmentation. proposals : list[frozenset] List of proposals for which features will be generated. @@ -311,7 +313,7 @@ def node_profiles(self, neurograph, computation_graph): Parameters ---------- - neurograph : NeuroGraph + neurograph : FragmentsGraph NeuroGraph generated from a predicted segmentation. computation_graph : networkx.Graph Graph used by GNN to classify proposals. @@ -349,7 +351,7 @@ def proposal_profiles(self, neurograph, proposals): Parameters ---------- - neurograph : NeuroGraph + neurograph : FragmentsGraph Graph that "proposals" belong to. 
proposals : list[frozenset] List of proposals for which features will be generated. @@ -382,7 +384,7 @@ def proposal_patches(self, neurograph, proposals): Parameters ---------- - neurograph : NeuroGraph + neurograph : FragmentsGraph Graph that "proposals" belong to. proposals : list[frozenset] List of proposals for which features will be generated. @@ -471,7 +473,7 @@ def transform_path(self, xyz_path): """ voxels = np.zeros((len(xyz_path), 3), dtype=int) for i, xyz in enumerate(xyz_path): - voxels[i] = img_util.to_voxels(xyz, self.downsample_factor) + voxels[i] = img_util.to_voxels(xyz, self.anisotropy, self.multiscale) return voxels def get_bbox(self, voxels, is_img=True): @@ -486,7 +488,7 @@ def get_bbox(self, voxels, is_img=True): def get_patch(self, labels, xyz_path, proposal): # Initializations center = np.mean(xyz_path, axis=0) - voxels = [img_util.to_voxels(xyz) for xyz in xyz_path] + voxels = [img_util.to_voxels(xyz, self.anisotropy) for xyz in xyz_path] # Read patches img_patch = self.read_img_patch(center) @@ -494,7 +496,7 @@ def get_patch(self, labels, xyz_path, proposal): return {proposal: np.stack([img_patch, label_patch], axis=0)} def read_img_patch(self, xyz_centroid): - center = img_util.to_voxels(xyz_centroid, self.downsample_factor) + center = img_util.to_voxels(xyz_centroid, self.anisotropy, self.multiscale) img_patch = img_util.read_tensorstore( self.img, center, self.img_patch_shape ) @@ -509,7 +511,7 @@ def read_label_patch(self, voxels, labels): def relabel(self, label_patch, voxels, labels): # Initializations n_points = self.get_n_profile_points() - scaling_factor = 2 ** self.downsample_factor + scaling_factor = 2 ** self.multiscale label_patch = zoom(label_patch, 1.0 / scaling_factor, order=0) for i, voxel in enumerate(voxels): voxels[i] = [v // scaling_factor for v in voxel] @@ -529,7 +531,7 @@ def get_leaf_path(neurograph, i): Parameters ---------- - neurograph : NeuroGraph + neurograph : FragmentsGraph NeuroGraph generated from a 
predicted segmentation. i : int Leaf node in "neurograph". @@ -551,7 +553,7 @@ def get_branching_path(neurograph, i): Parameters ---------- - neurograph : NeuroGraph + neurograph : FragmentsGraph NeuroGraph generated from a predicted segmentation. i : int branching node in "neurograph". diff --git a/src/deep_neurographs/machine_learning/heterograph_datasets.py b/src/deep_neurographs/machine_learning/heterograph_datasets.py index 361a4361..d2d86ddd 100644 --- a/src/deep_neurographs/machine_learning/heterograph_datasets.py +++ b/src/deep_neurographs/machine_learning/heterograph_datasets.py @@ -384,9 +384,17 @@ def set_edge_attrs(self, x_nodes, edge_type, idx_map): e1, e2 = self.data[edge_type].edge_index[:, i] v = node_intersection(idx_map, e1, e2) if v < 0: - attrs.append(torch.zeros(self.n_branch_features() + 1)) + attrs.append(np.zeros(self.n_branch_features() + 1)) else: attrs.append(x_nodes[v]) + + #print(edge_type, attrs[0].size()) + try: + np.array(attrs) + #print(edge_type, v, attrs) + except: + print(edge_type, v, attrs) + stop arrs = torch.tensor(np.array(attrs), dtype=DTYPE) self.data[edge_type].edge_attr = arrs diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index 5ea7057e..5e56a29c 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -61,7 +61,7 @@ def __init__( Parameters ---------- anisotropy : list[float], optional - Scaling factors applied to xyz coordinates to account for + Scaling factors applied to xyz coordinates to account for the anisotropy of microscope. The default is [1.0, 1.0, 1.0]. 
min_size : float, optional Minimum path length of swc files which are stored as connected @@ -227,7 +227,7 @@ def clip_branches(self, graph, swc_id): if self.img_bbox: delete_nodes = set() for i in graph.nodes: - xyz = img_util.to_voxels(graph.nodes[i]["xyz"]) + xyz = img_util.to_voxels(graph.nodes[i]["xyz"], self.to_anisotropy) if not util.is_contained(self.img_bbox, xyz): delete_nodes.add(i) graph.remove_nodes_from(delete_nodes) diff --git a/src/deep_neurographs/utils/img_util.py b/src/deep_neurographs/utils/img_util.py index 1d8ed8c7..5d939f07 100644 --- a/src/deep_neurographs/utils/img_util.py +++ b/src/deep_neurographs/utils/img_util.py @@ -16,7 +16,6 @@ from deep_neurographs.utils import util -ANISOTROPY = [0.748, 0.748, 1.0] SUPPORTED_DRIVERS = ["neuroglancer_precomputed", "n5", "zarr"] @@ -290,15 +289,19 @@ def get_profile(img, spec, profile_id): # --- coordinate conversions --- -def to_world(voxel, shift=[0, 0, 0]): +def to_physical(voxel, anisotropy, shift=[0, 0, 0]): """ - Converts coordinates from voxels to world. + Converts a voxel coordinate to a physical coordinate by applying the + anisotropy scaling factors. Parameters ---------- - coord : numpy.ndarray + coord : ArrayLike Coordinate to be converted. - shift : list, optional + anisotropy : ArrayLike + Image to physical coordinates scaling factors to account for the + anisotropy of the microscope. + shift : ArrayLike, optional Shift to be applied to "coord". The default is [0, 0, 0]. Returns @@ -307,20 +310,20 @@ def to_world(voxel, shift=[0, 0, 0]): Converted coordinates. """ - return tuple([voxel[i] * ANISOTROPY[i] - shift[i] for i in range(3)]) + return tuple([voxel[i] * anisotropy[i] - shift[i] for i in range(3)]) -def to_voxels(xyz, downsample_factor=0): +def to_voxels(xyz, anisotropy, downsample_factor=0): """ Converts coordinates from world to voxel. Parameters ---------- - xyz : numpy.ndarray + xyz : ArrayLike xyz coordinate to be converted to voxels. 
- anisotropy : list, optional - Anisotropy to be applied to values of interest. The default is - [1.0, 1.0, 1.0]. + anisotropy : ArrayLike + Image to physical coordinates scaling factors to account for the + anisotropy of the microscope. downsample_factor : int, optional Downsampling factor that accounts for which level in the image pyramid the voxel coordinates must index into. The default is 0. @@ -332,7 +335,7 @@ def to_voxels(xyz, downsample_factor=0): """ downsample_factor = 1.0 / 2 ** downsample_factor - voxel = downsample_factor * (xyz / np.array(ANISOTROPY)) + voxel = downsample_factor * (xyz / np.array(anisotropy)) return np.round(voxel).astype(int) diff --git a/src/deep_neurographs/utils/swc_util.py b/src/deep_neurographs/utils/swc_util.py index 8003d21a..c759d54e 100644 --- a/src/deep_neurographs/utils/swc_util.py +++ b/src/deep_neurographs/utils/swc_util.py @@ -42,9 +42,8 @@ def __init__(self, anisotropy=[1.0, 1.0, 1.0], min_size=0): Parameters ---------- anisotropy : List[float], optional - Image to world scaling factors applied to xyz coordinates to - account for anisotropy of the microscope. The default is - [1.0, 1.0, 1.0]. + Image to physical coordinates scaling factors to account for the + anisotropy of the microscope. The default is [1.0, 1.0, 1.0]. min_size : int, optional Threshold on the number of nodes in swc file. Only swc files with more than "min_size" nodes are stored in "xyz_coords". 
The default From 4aaa853d0f63af92e69a966ce845f9eb90724d8c Mon Sep 17 00:00:00 2001 From: anna-grim Date: Thu, 9 Jan 2025 22:58:58 +0000 Subject: [PATCH 2/5] refactor: anisotropy not built-in --- src/deep_neurographs/generate_proposals.py | 2 +- src/deep_neurographs/utils/graph_util.py | 88 ++++++++-------------- src/deep_neurographs/utils/img_util.py | 26 ++++--- src/deep_neurographs/utils/util.py | 3 +- 4 files changed, 48 insertions(+), 71 deletions(-) diff --git a/src/deep_neurographs/generate_proposals.py b/src/deep_neurographs/generate_proposals.py index eb531ff1..75fd5d1c 100644 --- a/src/deep_neurographs/generate_proposals.py +++ b/src/deep_neurographs/generate_proposals.py @@ -4,7 +4,7 @@ @author: Anna Grim @email: anna.grim@alleninstitute.org -Module used to generate edge proposals. +Module used to generate edge proposals for a fragments graph. """ diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index 5e56a29c..d514cb5f 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -5,8 +5,20 @@ @email: anna.grim@alleninstitute.org -Routines for loading fragments and building a fragments_graph. +Overview +-------- +Code that reads and preprocesses neuron fragments stored as swc files, then +constructs a custom graph object called a "FragmentsGraph". + Graph Construction Algorithm: + 1. Read Neuron Fragments + to do... + + 2. Preprocess Fragments and Extract Irreducibles + to do... + + 3. Build FragmentsGraph + to do... 
Terminology ------------ @@ -31,12 +43,7 @@ from tqdm import tqdm from deep_neurographs import geometry -from deep_neurographs.utils import img_util, swc_util, util - -MIN_SIZE = 30 -NODE_SPACING = 1 -SMOOTH_BOOL = True -PRUNE_DEPTH = 20 +from deep_neurographs.utils import swc_util, util class GraphLoader: @@ -48,11 +55,11 @@ class GraphLoader: def __init__( self, anisotropy=[1.0, 1.0, 1.0], - min_size=MIN_SIZE, - node_spacing=NODE_SPACING, + min_size=30, + node_spacing=1, progress_bar=False, - prune_depth=PRUNE_DEPTH, - smooth_bool=SMOOTH_BOOL, + prune_depth=20, + smooth_bool=True, ): """ Builds a FragmentsGraph by reading swc files stored either on the @@ -60,24 +67,23 @@ def __init__( Parameters ---------- - anisotropy : list[float], optional - Scaling factors applied to xyz coordinates to account for the - anisotropy of microscope. The default is [1.0, 1.0, 1.0]. + anisotropy : List[float], optional + Image to physical coordinates scaling factors to account for the + anisotropy of the microscope. The default is [1.0, 1.0, 1.0]. min_size : float, optional Minimum path length of swc files which are stored as connected - components in the FragmentsGraph. The default is 30ums. + components in the FragmentsGraph. The default is 30.0 (microns). node_spacing : int, optional - Spacing (in microns) between nodes. The default is the global - variable "NODE_SPACING". + Spacing (in microns) between nodes. The default is 1. progress_bar : bool, optional Indication of whether to print out a progress bar while building graph. The default is True. prune_depth : int, optional Branches less than "prune_depth" microns are pruned if "prune" is - True. The default is the global variable "PRUNE_DEPTH". + True. The default is 20.0 (microns). smooth_bool : bool, optional Indication of whether to smooth branches from swc files. The - default is the global variable "SMOOTH". + default is True. 
Returns ------- @@ -90,12 +96,9 @@ def __init__( self.progress_bar = progress_bar self.prune_depth = prune_depth self.smooth_bool = smooth_bool - self.reader = swc_util.Reader(anisotropy, min_size) - def run( - self, fragments_pointer, img_patch_origin=None, img_patch_shape=None - ): + def run(self, fragments_pointer): """ Builds a FragmentsGraph by reading swc files stored either on the cloud or local machine, then extracting the irreducible components. @@ -105,12 +108,6 @@ def run( fragments_pointer : dict, list, str Pointer to swc files used to build an instance of FragmentsGraph, see "swc_util.Reader" for further documentation. - img_patch_origin : list[int], optional - An xyz coordinate which is the upper, left, front corner of the - image patch that contains the swc files. The default is None. - img_patch_shape : list[int], optional - Shape of the image patch which contains the swc files. The default - is None. Returns ------- @@ -120,12 +117,13 @@ def run( """ from deep_neurographs.fragments_graph import FragmentsGraph - # Load fragments and extract irreducibles - self.img_bbox = img_util.init_bbox(img_patch_origin, img_patch_shape) + # Step 1: Read Neuron Fragments swc_dicts = self.reader.load(fragments_pointer) + + # Step: Preprocess Fragments and Extract Irreducibles irreducibles = self.schedule_processes(swc_dicts) - # Build FragmentsGraph + # Step 3: Build FragmentsGraph fragments_graph = FragmentsGraph(node_spacing=self.node_spacing) while len(irreducibles): irreducible_set = irreducibles.pop() @@ -186,7 +184,7 @@ def get_irreducibles(self, swc_dict): Returns ------- - list + List[dict] List of dictionaries such that each is the set of irreducibles in a connected component of the graph corresponding to "swc_dict". 
@@ -194,7 +192,6 @@ def get_irreducibles(self, swc_dict): # Build dense graph swc_dict["idx"] = dict(zip(swc_dict["id"], range(len(swc_dict["id"])))) graph, _ = swc_util.to_graph(swc_dict, set_attrs=True) - self.clip_branches(graph, swc_dict["swc_id"]) self.prune_branches(graph) # Extract irreducibles @@ -210,28 +207,6 @@ def get_irreducibles(self, swc_dict): irreducibles.append(result) return irreducibles - def clip_branches(self, graph, swc_id): - """ - Deletes all nodes from "graph" that are not contained in "img_bbox". - - Parameters - ---------- - graph : networkx.Graph - Graph to be searched - - Returns - ------- - None - - """ - if self.img_bbox: - delete_nodes = set() - for i in graph.nodes: - xyz = img_util.to_voxels(graph.nodes[i]["xyz"], self.to_anisotropy) - if not util.is_contained(self.img_bbox, xyz): - delete_nodes.add(i) - graph.remove_nodes_from(delete_nodes) - def prune_branches(self, graph): """ Prunes all short branches from "graph". A short branch is a path @@ -316,7 +291,6 @@ def get_component_irreducibles(self, graph, swc_dict): # Visit j attrs = upd_edge_attrs(swc_dict, attrs, j) if j in leafs or j in branchings: - # Check whether to smooth attrs["length"] = branch_length attrs = to_numpy(attrs) if self.smooth_bool: diff --git a/src/deep_neurographs/utils/img_util.py b/src/deep_neurographs/utils/img_util.py index 5d939f07..ea0f0954 100644 --- a/src/deep_neurographs/utils/img_util.py +++ b/src/deep_neurographs/utils/img_util.py @@ -4,7 +4,8 @@ @author: Anna Grim @email: anna.grim@alleninstitute.org -Helper routines for working with images. + +Helper routines for processing images. """ @@ -313,29 +314,29 @@ def to_physical(voxel, anisotropy, shift=[0, 0, 0]): return tuple([voxel[i] * anisotropy[i] - shift[i] for i in range(3)]) -def to_voxels(xyz, anisotropy, downsample_factor=0): +def to_voxels(xyz, anisotropy, multiscale=0): """ Converts coordinates from world to voxel. 
Parameters ---------- xyz : ArrayLike - xyz coordinate to be converted to voxels. + Physical coordinate to be converted to a voxel coordinate. anisotropy : ArrayLike Image to physical coordinates scaling factors to account for the anisotropy of the microscope. - downsample_factor : int, optional - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. The default is 0. + multiscale : int, optional + Level in the image pyramid that the voxel coordinate must index into. + The default is 0. Returns ------- numpy.ndarray - Coordinates converted to voxels. + Voxel coordinate of the input. """ - downsample_factor = 1.0 / 2 ** downsample_factor - voxel = downsample_factor * (xyz / np.array(anisotropy)) + scaling_factor = 1.0 / 2 ** multiscale + voxel = scaling_factor * xyz / np.array(anisotropy) return np.round(voxel).astype(int) @@ -348,9 +349,10 @@ def init_bbox(origin, shape): Parameters ---------- origin : tuple[int] - Origin of bounding box which is assumed to be top, front, left corner. - shape : tuple[int] - Shape of bounding box. + Voxel coordinate of the origin of the bounding box, which is assumed + to be top-front-left corner. + shape : Tuple[int] + Shape of the bounding box. Returns ------- diff --git a/src/deep_neurographs/utils/util.py b/src/deep_neurographs/utils/util.py index 28fb3e72..9dd62863 100644 --- a/src/deep_neurographs/utils/util.py +++ b/src/deep_neurographs/utils/util.py @@ -4,7 +4,8 @@ @author: Anna Grim @email: anna.grim@alleninstitute.org -General helper routines for various tasks. + +Miscellaneous helper routines. 
""" From 570b85e276d189b0e941e21b64333e409c267624 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Fri, 10 Jan 2025 04:30:02 +0000 Subject: [PATCH 3/5] refactor: simplified graph loader --- src/deep_neurographs/utils/graph_util.py | 458 ++++++----------------- src/deep_neurographs/utils/swc_util.py | 48 ++- 2 files changed, 149 insertions(+), 357 deletions(-) diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index d514cb5f..8f37e25f 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -8,33 +8,20 @@ Overview -------- Code that reads and preprocesses neuron fragments stored as swc files, then -constructs a custom graph object called a "FragmentsGraph". +constructs a custom graph object called a "FragmentsGraph" from them. Graph Construction Algorithm: 1. Read Neuron Fragments to do... - 2. Preprocess Fragments and Extract Irreducibles + 2. Extract Irreducibles to do... 3. Build FragmentsGraph to do... -Terminology ------------- - -Leaf: a node with degree 1. - -Branching: a node with degree > 2. - -Irreducibles: the irreducibles of a graph consists of 1) leaf nodes, -2) branching nodes, and 3) edges connecting (1) and (2). - -Branch: a sequence of nodes between two irreducible nodes. 
- """ -from collections import defaultdict from concurrent.futures import ProcessPoolExecutor, as_completed from random import sample @@ -118,10 +105,10 @@ def run(self, fragments_pointer): from deep_neurographs.fragments_graph import FragmentsGraph # Step 1: Read Neuron Fragments - swc_dicts = self.reader.load(fragments_pointer) + graph_list = self.reader.load(fragments_pointer) - # Step: Preprocess Fragments and Extract Irreducibles - irreducibles = self.schedule_processes(swc_dicts) + # Step: Extract Irreducibles + irreducibles = self.process_graphs(graph_list) # Step 3: Build FragmentsGraph fragments_graph = FragmentsGraph(node_spacing=self.node_spacing) @@ -130,81 +117,97 @@ def run(self, fragments_pointer): fragments_graph.add_component(irreducible_set) return fragments_graph - # --- Graph structure extraction --- - def schedule_processes(self, swc_dicts): + def process_graphs(self, graphs_list): """ - Gets irreducible components of each graph stored in "swc_dicts" by - setting up a parellelization scheme that sends each swc_dict to a CPU - and calls the subroutine "get_irreducibles". + Processes a list of graphs in parallel and extracts irreducible + subgraphs from each graph. Parameters ---------- - swc_dicts : list[dict] - List of dictionaries such that each entry contains the conents of - an swc file. + graphs_list : List[network.Graph] + List of graphs to be processed. Each graph is passed to the + "process_graph" method, which extracts the irreducible subgraphs + from each graph. Returns ------- - list[dict] - List of dictionaries such that each is the set of irreducibles in - a connected component of the graph corresponding to "swc_dicts". + List[dict] + List of irreducible subgraphs extracted from the input graphs. 
""" # Initializations if self.progress_bar: - pbar = tqdm(total=len(swc_dicts), desc="Extract Graphs") + pbar = tqdm(total=len(graphs_list), desc="Process Graphs") # Main - with ProcessPoolExecutor() as executor: + with ProcessPoolExecutor(max_workers=1) as executor: # Assign Processes - i = 0 - processes = [None] * len(swc_dicts) - while swc_dicts: - swc_dict = swc_dicts.pop() - processes[i] = executor.submit(self.get_irreducibles, swc_dict) - i += 1 + processes = list() + while graphs_list: + graph = graphs_list.pop() + processes.append( + executor.submit(self.extract_irreducibles, graph) + ) # Store results irreducibles = list() for process in as_completed(processes): - irreducibles.extend(process.result()) + result = process.result() + if result is not None: + irreducibles.append(result) if self.progress_bar: pbar.update(1) return irreducibles - def get_irreducibles(self, swc_dict): + def extract_irreducibles(self, graph): """ - Gets the irreducible components of graph stored in "swc_dict". This - routine also calls routines prunes short paths. + Gets the irreducible subgraph from the input graph. Parameters ---------- - swc_dict : dict - Contents of an swc file. + graph : dict + Graph that irreducible subgraph is to be extracted from. Returns ------- List[dict] - List of dictionaries such that each is the set of irreducibles in - a connected component of the graph corresponding to "swc_dict". + List of dictionaries such that each is the set of irreducibles + from the input graph. 
""" - # Build dense graph - swc_dict["idx"] = dict(zip(swc_dict["id"], range(len(swc_dict["id"])))) - graph, _ = swc_util.to_graph(swc_dict, set_attrs=True) + irreducibles = None self.prune_branches(graph) - - # Extract irreducibles - irreducibles = list() - path_length = compute_path_length(graph) - if path_length > self.min_size and graph.number_of_nodes() > 1: - for nodes in nx.connected_components(graph): - if len(nodes) > 1: - result = self.get_component_irreducibles( - graph.subgraph(nodes), swc_dict - ) - if result: - irreducibles.append(result) + if compute_path_length(graph) > self.min_size: + # Extract irreducible nodes + leafs, branchings = get_irreducible_nodes(graph) + assert len(leafs) > 0, "No leaf nodes!" + + # Extract irreducible edges + edges = dict() + root = None + for (i, j) in nx.dfs_edges(graph, source=util.sample_once(leafs)): + # Check for start of irreducible edge + if root is None: + root = i + path = [i] + xyz_list = [graph.nodes[i]["xyz"]] + + # Check for end of irreducible edge + path.append(j) + xyz_list.append(graph.nodes[j]["xyz"]) + if j in leafs or j in branchings: + edges[(root, j)] = path + if self.smooth_bool: + graph = smooth_path(graph, path, xyz_list) + root = None + + # Set irreducible attributes + irreducibles = { + "leaf": set_node_attrs(graph, leafs), + "branching": set_node_attrs(graph, branchings), + "edge": set_edge_attrs(graph, edges), + "swc_id": graph.graph["swc_id"], + } return irreducibles def prune_branches(self, graph): @@ -250,72 +253,8 @@ def prune_branches(self, graph): graph.remove_nodes_from(branch[0:k]) break - def get_component_irreducibles(self, graph, swc_dict): - """ - Gets the irreducible components of "graph". - - Parameters - ---------- - graph : networkx.Graph - Graph to be searched. - swc_dict : dict - Dictionary used to build "graph". - Returns - ------- - dict - Dictionary containing irreducible components of "graph". 
- - """ - # Extract nodes - leafs, branchings = get_irreducible_nodes(graph) - assert len(leafs) > 0, "No leaf nodes!" - - # Extract edges - edges = dict() - nbs = defaultdict(list) - root = None - branch_length = 0 - for (i, j) in nx.dfs_edges(graph, source=util.sample_once(leafs)): - # Check if starting new or continuing current path - if root is None: - root = i - branch_length = 0 - attrs = init_edge_attrs(swc_dict, root) - - # Vist i - xyz_i = swc_dict["xyz"][swc_dict["idx"][i]] - xyz_j = swc_dict["xyz"][swc_dict["idx"][j]] - branch_length += geometry.dist(xyz_i, xyz_j) - - # Visit j - attrs = upd_edge_attrs(swc_dict, attrs, j) - if j in leafs or j in branchings: - attrs["length"] = branch_length - attrs = to_numpy(attrs) - if self.smooth_bool: - swc_dict, edges = smooth_branch( - swc_dict, attrs, edges, nbs, root, j - ) - else: - edges[(root, j)] = attrs - - # Finish - nbs[root].append(j) - nbs[j].append(root) - root = None - - # Output - irreducibles = { - "leaf": set_node_attrs(swc_dict, leafs), - "branching": set_node_attrs(swc_dict, branchings), - "edge": edges, - "swc_id": swc_dict["swc_id"], - } - return irreducibles - - -# --- Utils --- +# --- Extract Irreducibles --- def get_irreducible_nodes(graph): """ Gets irreducible nodes (i.e. leafs and branchings) of a graph. @@ -341,202 +280,7 @@ def get_irreducible_nodes(graph): return leafs, branchings -def smooth_branch(swc_dict, attrs, edges, nbs, root, j): - """ - Smoothes a branch then updates "swc_dict" and "edges" with the new xyz - coordinates of the branch end points. Note that this branch is an edge - in the irreducible graph being built. - - Parameters - ---------- - swc_dict : dict - Contents of an swc file. - attrs : dict - Attributes (from "swc_dict") of edge being smoothed. - edges : dict - Dictionary where the keys are edges in irreducible graph and values - are the corresponding attributes. - nbs : dict - Dictionary where the keys are nodes and values are the neighbors. 
- root : int - End point of branch to be smoothed. - j : int - End point of branch to be smoothed. - - Returns - ------- - dict, dict - Dictionaries that have been updated with respect to smoothed edges. - - """ - attrs["xyz"] = geometry.smooth_branch(attrs["xyz"], s=2) - swc_dict, edges = upd_xyz(swc_dict, attrs, edges, nbs, root, 0) - swc_dict, edges = upd_xyz(swc_dict, attrs, edges, nbs, j, -1) - edges[(root, j)] = attrs - return swc_dict, edges - - -def upd_xyz(swc_dict, attrs, edges, nbs, i, endpoint): - """ - Updates "swc_dict" and "edges" with the new xyz coordinates of the branch - end points. - - Parameters - ---------- - swc_dict : dict - Contents of an swc file. - attrs : dict - Attributes (from "swc_dict") of edge being smoothed. - edges : dict - Dictionary where the keys are edges in irreducible graph and values - are the corresponding attributes. - nbs : dict - Dictionary where the keys are nodes and values are the neighbors. - endpoint : int - End point of branch to be smoothed. - - Returns - ------- - dict - Updated with new xyz coordinates. - dict - Updated with new xyz coordinates. - - """ - idx = swc_dict["idx"][i] - if i in nbs.keys(): - for j in nbs[i]: - key = (i, j) if (i, j) in edges.keys() else (j, i) - edges = upd_endpoint_xyz( - edges, key, swc_dict["xyz"][idx], attrs["xyz"][endpoint] - ) - swc_dict["xyz"][idx] = attrs["xyz"][endpoint] - return swc_dict, edges - - -def upd_endpoint_xyz(edges, key, old_xyz, new_xyz): - """ - Updates "edges" with the new xyz coordinates of the branch - end points. - - Parameters - ---------- - edges : dict - Dictionary where the keys are edges in irreducible graph and values - are the corresponding attributes. - key : tuple - The edge id of the entry in "edges" which needs to be updated. - old_xyz : numpy.ndarray - Current xyz coordinate of end point. - new_xyz : numpy.ndarray - New xyz coordinate of end point. - - Returns - ------- - dict - Updated with new xyz coordinates. 
- - """ - if all(edges[key]["xyz"][0] == old_xyz): - edges[key]["xyz"][0] = new_xyz - elif all(edges[key]["xyz"][-1] == old_xyz): - edges[key]["xyz"][-1] = new_xyz - return edges - - -def init_edge_attrs(swc_dict, i): - """ - Initializes edge attribute dictionary with attributes from node "i" which - is an end point of the edge. Note: the following assertion error may be - useful: assert i in swc_dict["idx"].keys(), f"{swc_dict["swc_id"]} - {i}" - - Parameters - ---------- - swc_dict : dict - Contents of an swc file. - i : int - End point of edge and the swc attributes of this node are used to - initialize the edge attriubte dictionary. - - Returns - ------- - dict - Edge attribute dictionary. - - """ - j = swc_dict["idx"][i] - return {"radius": [swc_dict["radius"][j]], "xyz": [swc_dict["xyz"][j]]} - - -def upd_edge_attrs(swc_dict, attrs, i): - """ - Updates an edge attribute dictionary with attributes of node i. - - Parameters - ---------- - swc_dict : dict - Contents of an swc file. - attrs : dict - Attributes (from "swc_dict") of edge being updated. - i : int - Node of edge whose attributes will be added to "attrs". - - Returns - ------- - dict - Edge attribute dictionary. - - """ - swc_id = swc_dict["swc_id"] - assert i != -1, f"{swc_id} - {i}" - j = swc_dict["idx"][i] - attrs["radius"].append(swc_dict["radius"][j]) - attrs["xyz"].append(swc_dict["xyz"][j]) - return attrs - - -def get_edge_attr(graph, edge, attr): - """ - Gets the attribute "attr" of "edge". - - Parameters - ---------- - graph : networkx.Graph - Graph which "edge" belongs to. - edge : tuple - Edge to be queried for its attributes. - attr : str - Attribute to be queried. - - Returns - ------- - Attribute "attr" of "edge" - - """ - return graph.edges[edge][attr] - - -def to_numpy(attrs): - """ - Converts edge attributes from a list to NumPy array. - - Parameters - ---------- - attrs : dict - Dictionary containing attributes of some edge. 
- - Returns - ------- - dict - Updated edge attribute dictionary. - - """ - attrs["xyz"] = np.array(attrs["xyz"], dtype=np.float32) - attrs["radius"] = np.array(attrs["radius"], dtype=np.float16) - return attrs - - -def set_node_attrs(swc_dict, nodes): +def set_node_attrs(graph, nodes): """ Set node attributes by extracting values from "swc_dict". @@ -545,7 +289,7 @@ def set_node_attrs(swc_dict, nodes): swc_dict : dict Contents of an swc file. nodes : list - List of nodes to set attributes. + List of node ids to set attributes. Returns ------- @@ -554,47 +298,59 @@ def set_node_attrs(swc_dict, nodes): attributes extracted from "swc_dict". """ - attrs = dict() + node_attrs = dict() for i in nodes: - j = swc_dict["idx"][i] - attrs[i] = {"radius": swc_dict["radius"][j], "xyz": swc_dict["xyz"][j]} - return attrs + node_attrs[i] = { + "radius": graph.nodes[i]["radius"], "xyz": graph.nodes[i]["xyz"] + } + return node_attrs + + +def set_edge_attrs(graph, edges): + edge_attrs = dict() + for edge, path in edges.items(): + # Extract attributes + radius_list, xyz_list = list(), list() + for i in path: + radius_list.append(graph.nodes[i]["radius"]) + xyz_list.append(graph.nodes[i]["xyz"]) + + # Set attributes + edge_attrs[edge] = { + "length": 1000, + "radius": np.array(radius_list), + "xyz": np.array(xyz_list) + } + return edge_attrs -def upd_node_attrs(swc_dict, leafs, branchings, i): +# --- Miscellaneous --- +def smooth_path(graph, path, xyz_list): """ - Updates node attributes by extracting values from "swc_dict". + Smooths a given path on a graph by applying smoothing to the coordinates + of the nodes along the path and updating the graph with the smoothed + coordinates. Parameters ---------- - swc_dict : dict - Contents of an swc file that contains the smoothed xyz coordinates of - corresponding to "leafs" and "branchings". Note xyz coordinates are - smoothed during edge extraction. 
- leafs : dict - Dictionary where keys are leaf node ids and values are attribute - dictionaries. - branchings : dict - Dictionary where keys are branching node ids and values are attribute - dictionaries. - i : int - Node to be updated. + graph : networkx.Graph + Graph containing path to be smoothed. + path : List[int] + List of node indices representing the path in the graph. + xyz_list : List[Tuple[float]] + List of xyz coordinates of path in the graph to be smoothed. Returns ------- - dict - Updated dictionary if "i" was contained in "leafs.keys()". - dict - Updated dictionary if "i" was contained in "branchings.keys()". + networkx.Graph + Input graph with updated "xyz" attributes for the nodes from the input + path. """ - j = swc_dict["idx"][i] - upd_attrs = {"radius": swc_dict["radius"][j], "xyz": swc_dict["xyz"][j]} - if i in leafs: - leafs[i] = upd_attrs - else: - branchings[i] = upd_attrs - return leafs, branchings + smoothed_xyz_list = geometry.smooth_branch(np.array(xyz_list), s=2) + for i, xyz in zip(path, smoothed_xyz_list): + graph.nodes[i]["xyz"] = xyz + return graph def compute_path_length(graph): diff --git a/src/deep_neurographs/utils/swc_util.py b/src/deep_neurographs/utils/swc_util.py index c759d54e..ef020c1a 100644 --- a/src/deep_neurographs/utils/swc_util.py +++ b/src/deep_neurographs/utils/swc_util.py @@ -141,9 +141,9 @@ def load_from_local_path(self, path): """ content = util.read_txt(path) if len(content) > self.min_size - 10: - result = self.parse(content) - result["swc_id"] = util.get_swc_id(path) - return result + graph = self.parse(content) + graph.graph["swc_id"] = util.get_swc_id(path) + return graph else: return False @@ -268,14 +268,45 @@ def load_from_zipped_file(self, zip_file, path): """ content = util.read_zip(zip_file, path).splitlines() if len(content) > self.min_size - 10: - result = self.parse(content) - result["swc_id"] = util.get_swc_id(path) - return result + graph = self.parse(content) + graph.graph["swc_id"] = 
util.get_swc_id(path) + return graph else: return False # --- Process swc content --- def parse(self, content): + """ + Reads an swc file and builds an undirected graph from it. + + Parameters + ---------- + path : str + Path to swc file to be read. + + Returns + ------- + networkx.Graph + Graph built from an swc file. + + """ + graph = nx.Graph() + content, offset = self.process_content(content) + for line in content: + # Extract node info + parts = line.split() + child = int(parts[0]) + parent = int(parts[-1]) + radius = read_radius(parts[-2]) + xyz = self.read_xyz(parts[2:5], offset=offset) + + # Add node + graph.add_node(child, radius=radius, xyz=xyz) + if parent != -1: + graph.add_edge(parent, child) + return graph + + def parse_old(self, content): """ Parses an swc file to extract the content which is stored in a dict. Note that node_ids from swc are refactored to index from 0 to n-1 @@ -618,6 +649,11 @@ def set_radius(graph, i): # --- Miscellaneous --- +def read_radius(radius_str): + radius = float(radius_str) + return radius / 1000 if radius > 100 else radius + + def to_graph(swc_dict, swc_id=None, set_attrs=False): """ Converts an dictionary containing swc attributes to a graph. 
From ce416537e61257eb1d9304a013f19680b7588ad8 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Fri, 10 Jan 2025 05:38:26 +0000 Subject: [PATCH 4/5] refactor: simplified swc_util --- src/deep_neurographs/utils/graph_util.py | 22 +- src/deep_neurographs/utils/swc_util.py | 329 +++++++---------------- 2 files changed, 112 insertions(+), 239 deletions(-) diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index 28e6efa4..7bf40c8a 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -178,11 +178,10 @@ def extract_irreducibles(self, graph): irreducibles = None self.prune_branches(graph) if compute_path_length(graph) > self.min_size: - # Extract irreducible nodes + # Irreducible nodes leafs, branchings = get_irreducible_nodes(graph) - assert len(leafs) > 0, "No leaf nodes!" - # Extract irreducible edges + # Irreducible edges edges = dict() root = None for (i, j) in nx.dfs_edges(graph, source=util.sample_once(leafs)): @@ -282,20 +281,20 @@ def get_irreducible_nodes(graph): def set_node_attrs(graph, nodes): """ - Set node attributes by extracting values from "swc_dict". + Set node attributes by extracting information from "graph". Parameters ---------- - swc_dict : dict - Contents of an swc file. + graph : networkx.Graph + Graph that contains "nodes". nodes : list List of node ids to set attributes. Returns ------- dict - Dictionary in which keys are node ids and values are a dictionary of - attributes extracted from "swc_dict". + Dictionary where keys are node ids and values are a dictionary of + attributes extracted from the input graph. 
""" node_attrs = dict() @@ -310,14 +309,17 @@ def set_edge_attrs(graph, edges): edge_attrs = dict() for edge, path in edges.items(): # Extract attributes + length = 0 radius_list, xyz_list = list(), list() - for i in path: + for idx, i in enumerate(path): radius_list.append(graph.nodes[i]["radius"]) xyz_list.append(graph.nodes[i]["xyz"]) + if idx > 0: + length += compute_dist(graph, path[idx], path[idx - 1]) # Set attributes edge_attrs[edge] = { - "length": 1000, + "length": length, "radius": np.array(radius_list), "xyz": np.array(xyz_list) } diff --git a/src/deep_neurographs/utils/swc_util.py b/src/deep_neurographs/utils/swc_util.py index ef020c1a..242b5884 100644 --- a/src/deep_neurographs/utils/swc_util.py +++ b/src/deep_neurographs/utils/swc_util.py @@ -5,7 +5,7 @@ @email: anna.grim@alleninstitute.org -Routines for working with swc files. +Routines for reading and writing swc files. """ @@ -46,8 +46,7 @@ def __init__(self, anisotropy=[1.0, 1.0, 1.0], min_size=0): anisotropy of the microscope. The default is [1.0, 1.0, 1.0]. min_size : int, optional Threshold on the number of nodes in swc file. Only swc files with - more than "min_size" nodes are stored in "xyz_coords". The default - is 0. + more than "min_size" nodes are processed. The default is 0. Returns ------- @@ -59,19 +58,19 @@ def __init__(self, anisotropy=[1.0, 1.0, 1.0], min_size=0): def load(self, swc_pointer): """ - Load data based on the type and format of the provided "swc_pointer". + Loads swc files specififed by "swc_pointer" and builds an attributed + graphs from them. Parameters ---------- swc_pointer : dict, list, str - Object that points to swcs to be read, see class documentation for - details. + Object that points to swc files to be read, see class documentation + for details. Returns ------- - List[dict] - List of dictionaries whose keys and values are the attribute name - and values from an swc file. + List[networkx.Graph] or networkx.Graph + Attributed graphs. 
""" if type(swc_pointer) is dict: @@ -88,55 +87,52 @@ def load(self, swc_pointer): return self.load_from_local_paths(paths) raise Exception("SWC Pointer is not Valid!") - # --- Load subroutines --- - def load_from_local_paths(self, swc_paths): + def load_from_local_paths(self, path_list): """ - Reads swc files from local machine, then returns either the xyz - coordinates or graphs. + Reads swc files from local machine and builds an attributed graph + from them. Paramters --------- - swc_paths : list - List of paths to swc files stored on the local machine. + path_list : List[str] + Paths to swc files on the local machine. Returns ------- - List[dict] - List of dictionaries whose keys and values are the attribute name - and values from an swc file. + List[networkx.Graph] + Attributed graphs. """ with ProcessPoolExecutor(max_workers=1) as executor: # Assign processes processes = list() - for path in swc_paths: + for path in path_list: processes.append( executor.submit(self.load_from_local_path, path) ) # Store results - swc_dicts = list() + graphs = list() for process in as_completed(processes): result = process.result() if result: - swc_dicts.append(result) - return swc_dicts + graphs.append(result) + return graphs def load_from_local_path(self, path): """ - Reads a single swc file from local machine, then returns either the - xyz coordinates or graphs. + Reads a single swc file on local machine and builds an attributed + graph from it. Paramters --------- path : str - Path to swc file stored on the local machine. + Path to swc file on the local machine. Returns ------- - List[dict] - List of dictionaries whose keys and values are the attribute name - and values from an swc file. + networkx.Graph + Attributed graph. 
""" content = util.read_txt(path) @@ -145,112 +141,105 @@ def load_from_local_path(self, path): graph.graph["swc_id"] = util.get_swc_id(path) return graph else: - return False + return None def load_from_local_zip(self, zip_path): """ - Reads swc files from zip on the local machine, then returns either the - xyz coordinates or graph. Note this routine is hard coded for computing - projected run length. + Reads swc files from a zip file and builds attributed graphs from + them. Paramters --------- - swc_paths : Container - If swc files are on local machine, list of paths to swc files where - each file corresponds to a neuron in the prediction. If swc files - are on cloud, then dict with keys "bucket_name" and "path". + zip_path : str + Path to zip file to be read. Returns ------- - dict - Dictionary that maps an swc_id to the the xyz coordinates read from - that swc file. + List[networkx.Graph] + Attributed graphs. """ with ZipFile(zip_path, "r") as zip_file: - swc_dicts = list() + graphs = list() swc_files = [f for f in zip_file.namelist() if f.endswith(".swc")] for f in tqdm(swc_files, desc="Loading Fragments"): - result = self.load_from_zipped_file(zip_file, f) + result = self.load_from_zip(zip_file, f) if result: - swc_dicts.append(result) - return swc_dicts + graphs.append(result) + return graphs def load_from_gcs(self, gcs_dict): """ - Reads swc files from zips on a GCS bucket. + Reads swc files from zips on a GCS bucket and builds attributed + graphs from them. Parameters ---------- gcs_dict : dict - Dictionary where keys are "bucket_name" and "path". + Dictionary with the keys "bucket_name" and "path" used to read + swcs from GCS bucket Returns ------- - dict - Dictionary that maps an swc_id to the the xyz coordinates read from - that swc file. + List[networkx.Graph] + Attributed graphs. 
""" - # Initializations bucket = storage.Client().bucket(gcs_dict["bucket_name"]) zip_paths = util.list_gcs_filenames(bucket, gcs_dict["path"], ".zip") - - # Main with ProcessPoolExecutor() as executor: # Assign processes processes = list() for path in tqdm(zip_paths, desc="Download SWCs"): - zip_content = bucket.blob(path).download_as_bytes() + zip_bytes = bucket.blob(path).download_as_bytes() processes.append( - executor.submit(self.load_from_cloud_zip, zip_content) + executor.submit(self.load_from_cloud_zip, zip_bytes) ) # Store results - swc_dicts = list() + graphs = list() for process in as_completed(processes): - swc_dicts.extend(process.result()) - return swc_dicts + graphs.extend(process.result()) + return graphs - def load_from_cloud_zip(self, zip_content): + def load_from_cloud_zip(self, zip_bytes): """ - Reads swc files from a zip that has been downloaded from a cloud - bucket. + Reads swc files from a zip and builds attributed graphs from them. Parameters ---------- - zip_content : ... - content of a zip file. + zip_bytes : bytes + Contents of a zip file in byte format. Returns ------- - dict - Dictionary that maps an swc_id to the the xyz coordinates read from - that swc file. + List[networkx.Graph] + Attributed graphs. 
""" - with ZipFile(BytesIO(zip_content)) as zip_file: + with ZipFile(BytesIO(zip_bytes)) as zip_file: with ThreadPoolExecutor() as executor: # Assign threads threads = list() - for f in util.list_files_in_zip(zip_content): + for f in util.list_files_in_zip(zip_bytes): threads.append( executor.submit( - self.load_from_zipped_file, zip_file, f + self.load_from_zip, zip_file, f ) ) # Process results - swc_dicts = list() + graphs = list() for thread in as_completed(threads): result = thread.result() if result: - swc_dicts.append(result) - return swc_dicts + graphs.append(result) + return graphs - def load_from_zipped_file(self, zip_file, path): + def load_from_zip(self, zip_file, path): """ - Reads swc file stored at "path" which points to a file in a zip. + Reads swc files at in a zip file at "path" and builds attributed + graphs from them. Parameters ---------- @@ -261,9 +250,8 @@ def load_from_zipped_file(self, zip_file, path): Returns ------- - dict - Dictionary that maps an swc_id to the the xyz coordinates or graph - read from that swc file. + networkx.Graph + Attributed graph. """ content = util.read_zip(zip_file, path).splitlines() @@ -274,10 +262,10 @@ def load_from_zipped_file(self, zip_file, path): else: return False - # --- Process swc content --- + # --- Process SWC Contents --- def parse(self, content): """ - Reads an swc file and builds an undirected graph from it. + Reads an swc file and builds an attributed graphs from it. Parameters ---------- @@ -297,7 +285,7 @@ def parse(self, content): parts = line.split() child = int(parts[0]) parent = int(parts[-1]) - radius = read_radius(parts[-2]) + radius = self.read_radius(parts[-2]) xyz = self.read_xyz(parts[2:5], offset=offset) # Add node @@ -306,54 +294,16 @@ def parse(self, content): graph.add_edge(parent, child) return graph - def parse_old(self, content): - """ - Parses an swc file to extract the content which is stored in a dict. 
- Note that node_ids from swc are refactored to index from 0 to n-1 - where n is the number of entries in the swc file. - - Parameters - ---------- - content : List[str] - List of entries from an swc file. - - Returns - ------- - dict - Dictionaries whose keys and values are the attribute name - and values from an swc file. - - """ - # Parse swc content - content, offset = self.process_content(content) - swc_dict = { - "id": np.zeros((len(content)), dtype=np.int32), - "radius": np.zeros((len(content)), dtype=np.float32), - "pid": np.zeros((len(content)), dtype=np.int32), - "xyz": np.zeros((len(content), 3), dtype=np.float32), - } - for i, line in enumerate(content): - parts = line.split() - swc_dict["id"][i] = parts[0] - swc_dict["radius"][i] = float(parts[-2]) - swc_dict["pid"][i] = parts[-1] - swc_dict["xyz"][i] = self.read_xyz(parts[2:5], offset) - - # Check whether radius is in nanometers - if swc_dict["radius"][0] > 100: - swc_dict["radius"] /= 1000 - return swc_dict - def process_content(self, content): """ - Processes lines of text from a content source, extracting an offset - value and returning the remaining content starting from the line - immediately after the last commented line. + Processes lines of text from an swc file by iterating over commented + lines to extract offset (if present) and finds the line after the last + commented line. Parameters ---------- content : List[str] - List of strings where each string represents a line of text. + List of strings that represent a line of a text file. Returns ------- @@ -393,19 +343,38 @@ def read_xyz(self, xyz_str, offset=[0.0, 0.0, 0.0]): xyz[i] = self.anisotropy[i] * (float(xyz_str[i]) + offset[i]) return xyz + def read_radius(self, radius_str): + """ + Converts a radius string to a float and adjusts it if the value is in + nanometers. + + Parameters + ---------- + radius_str : str + A string representing the radius value. + + Returns + ------- + float + Radius. 
+ + """ + radius = float(radius_str) + return radius / 1000 if radius > 100 else radius + # --- Write --- def write(path, content, color=None): """ - Write content to a specified file in a format based on the type o - f content. + Writes an swc from the given "content" which is either a list of entries + or a graph. Parameters ---------- path : str - File path where the content will be written. - content : list, dict, nx.Graph - The content to be written. + Path where the content is to be written. + content : List[str] or networkx.Graph + Content of swc file to be written. color : str, optional Color of swc to be written. The default is None. @@ -416,8 +385,6 @@ def write(path, content, color=None): """ if type(content) is list: write_list(path, content, color=color) - elif type(content) is dict: - write_dict(path, content, color=color) elif type(content) is nx.Graph: write_graph(path, content, color=color) else: @@ -432,8 +399,8 @@ def write_list(path, entry_list, color=None): ---------- path : str Path that swc will be written to. - entry_list : list[str] - List of entries that will be written to an swc file. + entry_list : List[str] + List of entries to be written to an swc file. color : str, optional Color of swc to be written. The default is None. @@ -443,7 +410,7 @@ def write_list(path, entry_list, color=None): """ with open(path, "w") as f: - # Preamble + # Comments if color is not None: f.write("# COLOR " + color) else: @@ -454,33 +421,10 @@ def write_list(path, entry_list, color=None): f.write("\n" + entry) -def write_dict(path, swc_dict, color=None): - """ - Writes the dictionary to an swc file. - - Parameters - ---------- - path : str - Path that swc will be written to. - swc_dict : dict - Dictionaries whose keys and values are the attribute name and values - from an swc file. - color : str, optional - Color of swc to be written. The default is None. 
- - Returns - ------- - None - - """ - graph, _ = to_graph(swc_dict, set_attrs=True) - write_graph(path, graph, color=color) - - def write_graph(path, graph, color=None): """ - Makes a list of entries to be written in an swc file. This routine assumes - that "graph" has a single connected components. + Writes a graph to an swc file. This routine assumes that "graph" has a + single connected component. Parameters ---------- @@ -491,8 +435,7 @@ def write_graph(path, graph, color=None): Returns ------- - List[str] - List of swc file entries to be written. + None """ node_to_idx = {-1: -1} @@ -646,75 +589,3 @@ def set_radius(graph, i): except ValueError: radius = 1.0 return radius - - -# --- Miscellaneous --- -def read_radius(radius_str): - radius = float(radius_str) - return radius / 1000 if radius > 100 else radius - - -def to_graph(swc_dict, swc_id=None, set_attrs=False): - """ - Converts an dictionary containing swc attributes to a graph. - - Parameters - ---------- - swc_dict : dict - Dictionaries whose keys and values are the attribute name and values - from an swc file. - swc_id : str, optional - Identifier that dictionary was generated from. The default is None. - set_attrs : bool, optional - Indication of whether to set attributes. The default is False. - - Returns - ------- - networkx.Graph - Graph generated from "swc_dict". - - """ - graph = nx.Graph(graph_id=swc_id) - graph.add_edges_from(zip(swc_dict["id"][1:], swc_dict["pid"][1:])) - if set_attrs: - xyz = swc_dict["xyz"] - if type(swc_dict["xyz"]) is np.ndarray: - xyz = util.numpy_to_hashable(swc_dict["xyz"]) - graph = __add_attributes(swc_dict, graph) - xyz_to_node = dict(zip(xyz, swc_dict["id"])) - return graph, xyz_to_node - return graph - - -def __add_attributes(swc_dict, graph): - """ - Adds node attributes to a NetworkX graph based on information from - "swc_dict". - - Parameters: - ---------- - swc_dict : dict - A dictionary containing SWC data. 
It must have the following keys: - - "id": A list of node identifiers (unique for each node). - - "xyz": A list of 3D coordinates (x, y, z) for each node. - - "radius": A list of radii for each node. - - graph : networkx.Graph - A NetworkX graph object to which the attributes will be added. - The graph must contain nodes that correspond to the IDs in - "swc_dict["id"]". - - Returns: - ------- - networkx.Graph - The modified graph with added node attributes for each node. - - """ - attrs = dict() - for idx, node in enumerate(swc_dict["id"]): - attrs[node] = { - "xyz": swc_dict["xyz"][idx], - "radius": swc_dict["radius"][idx], - } - nx.set_node_attributes(graph, attrs) - return graph From 5f34c41f1249660d42431866f8cf5ffb29139817 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Thu, 16 Jan 2025 07:36:22 +0000 Subject: [PATCH 5/5] refactor: image reader class --- src/deep_neurographs/config.py | 18 +- src/deep_neurographs/inference.py | 25 +- .../machine_learning/feature_generation.py | 231 ++++++----- src/deep_neurographs/utils/graph_util.py | 10 +- src/deep_neurographs/utils/img_util.py | 374 ++++++++++-------- 5 files changed, 365 insertions(+), 293 deletions(-) diff --git a/src/deep_neurographs/config.py b/src/deep_neurographs/config.py index e7a8e827..df78462b 100644 --- a/src/deep_neurographs/config.py +++ b/src/deep_neurographs/config.py @@ -23,7 +23,8 @@ class GraphConfig: ---------- anisotropy : list[float], optional Scaling factors applied to xyz coordinates to account for anisotropy - of microscope. The default is [1.0, 1.0, 1.0]. + of microscope. Note this instance of "anisotropy" is only used while + reading fragments (i.e. swcs). The default is [1.0, 1.0, 1.0]. complex_bool : bool Indication of whether to generate complex proposals, meaning proposals between leaf and non-leaf nodes. The default is False. 
@@ -74,12 +75,15 @@ class MLConfig: Attributes ---------- + anisotropy : list[float], optional + Scaling factors applied to xyz coordinates to account for anisotropy + of microscope. Note this instance of "anisotropy" is only used while + generating features. The default is [1.0, 1.0, 1.0]. batch_size : int The number of samples processed in one batch during training or inference. Default is 1000. - downsample_factor : int - Downsampling factor that accounts for which level in the image pyramid - the voxel coordinates must index into. The default is 0. + multiscale : int + Level in the image pyramid that voxel coordinates must index into. high_threshold : float A threshold value used for classification, above which predictions are considered to be high-confidence. Default is 0.9. @@ -89,14 +93,14 @@ class MLConfig: Type of machine learning model to use. Default is "GraphNeuralNet". """ - + anisotropy: List[float] = field(default_factory=list) batch_size: int = 2000 - downsample_factor: int = 1 high_threshold: float = 0.9 lr: float = 1e-3 - threshold: float = 0.6 model_type: str = "GraphNeuralNet" + multiscale: int = 1 n_epochs: int = 1000 + threshold: float = 0.6 validation_split: float = 0.15 weight_decay: float = 1e-3 diff --git a/src/deep_neurographs/inference.py b/src/deep_neurographs/inference.py index 243d2dfd..1e559a82 100644 --- a/src/deep_neurographs/inference.py +++ b/src/deep_neurographs/inference.py @@ -68,7 +68,7 @@ def __init__( config, device="cpu", is_multimodal=False, - label_path=None, + labels_path=None, log_runtimes=True, save_to_s3_bool=False, s3_dict=None, @@ -98,7 +98,7 @@ def __init__( ... is_multimodal : bool, optional ... - label_path : str, optional + labels_path : str, optional Path to the segmentation assumed to be stored on a GCS bucket. The default is None. 
log_runtimes : bool, optional @@ -132,11 +132,12 @@ def __init__( self.model_path, self.ml_config.model_type, self.graph_config.search_radius, + anisotropy=self.ml_config.anisotropy, batch_size=self.ml_config.batch_size, confidence_threshold=self.ml_config.threshold, device=device, - downsample_factor=self.ml_config.downsample_factor, - label_path=label_path, + multiscale=self.ml_config.multiscale, + labels_path=labels_path, is_multimodal=is_multimodal, ) @@ -474,11 +475,12 @@ def __init__( model_path, model_type, radius, + anisotropy=[1.0, 1.0, 1.0], batch_size=BATCH_SIZE, confidence_threshold=CONFIDENCE_THRESHOLD, device=None, - downsample_factor=1, - label_path=None, + multiscale=1, + labels_path=None, is_multimodal=False ): """ @@ -501,9 +503,9 @@ def __init__( confidence_threshold : float, optional Threshold on acceptance probability for proposals. The default is the global variable "CONFIDENCE_THRESHOLD". - downsample_factor : int, optional - Downsampling factor that accounts for which level in the image - pyramid the voxel coordinates must index into. The default is 0. + multiscale : int, optional + Level in the image pyramid that voxel coordinates must index into. + The default is 1. Returns ------- @@ -520,8 +522,9 @@ def __init__( # Features self.feature_generator = FeatureGenerator( img_path, - downsample_factor, - label_path=label_path, + multiscale, + anisotropy=anisotropy, + labels_path=labels_path, is_multimodal=is_multimodal ) diff --git a/src/deep_neurographs/machine_learning/feature_generation.py b/src/deep_neurographs/machine_learning/feature_generation.py index 55e777e6..a4bcec27 100644 --- a/src/deep_neurographs/machine_learning/feature_generation.py +++ b/src/deep_neurographs/machine_learning/feature_generation.py @@ -8,8 +8,8 @@ inference. Conventions: - (1) "xyz" refers to a real world coordinate such as those from an swc file - (2) "voxel" refers to an voxel coordinate in a whole exaspim image. 
+ (1) "xyz" refers to a physical coordinate such as those from an swc file + (2) "voxel" refers to an voxel coordinate in a whole-brain image. """ @@ -21,6 +21,7 @@ from deep_neurographs import geometry from deep_neurographs.utils import img_util, util +from deep_neurographs.utils.img_util import TensorStoreReader, ZarrReader class FeatureGenerator: @@ -38,7 +39,7 @@ def __init__( img_path, multiscale, anisotropy=[1.0, 1.0, 1.0], - label_path=None, + labels_path=None, is_multimodal=False, ): """ @@ -53,7 +54,7 @@ def __init__( anisotropy : ArrayLike, optional Image to physical coordinates scaling factors to account for the anisotropy of the microscope. The default is [1.0, 1.0, 1.0]. - label_path : str, optional + labels_path : str, optional Path to the segmentation assumed to be stored on a GCS bucket. The default is None. is_multimodal : bool, optional @@ -65,26 +66,22 @@ def __init__( None """ + # Sanity check + if is_multimodal and not labels_path: + raise("Must provide label mask to use multimodal model!") + # General instance attributes self.anisotropy = anisotropy self.multiscale = multiscale self.is_multimodal = is_multimodal # Open images - driver = "n5" if ".n5" in img_path else "zarr" - self.img = img_util.open_tensorstore(img_path, driver=driver) - if label_path: - self.labels = img_util.open_tensorstore(label_path) - else: - self.labels = None - - # Set chunk shapes + self.img_reader = self.init_img_reader(img_path, "zarr") self.img_patch_shape = self.set_patch_shape(multiscale) - self.label_patch_shape = self.set_patch_shape(0) - - # Validate embedding requirements - if self.is_multimodal and not label_path: - raise("Must provide labels to generate image embeddings") + if labels_path is not None: + driver = "neuroglancer_precomputed" + self.labels_reader = self.init_img_reader(labels_path, driver) + self.label_patch_shape = self.set_patch_shape(0) @classmethod def set_patch_shape(cls, multiscale): @@ -98,21 +95,54 @@ def set_patch_shape(cls, 
multiscale): Returns ------- - list - Adjusted chunk shape with each dimension reduced by the downsample - factor. + List[int] + Chunk shape with each dimension reduced by the multiscale. """ return [s // 2 ** multiscale for s in cls.patch_shape] @classmethod def get_n_profile_points(cls): + """ + Gets the number of points on an image profile. + + Parameters + ---------- + None + + Returns + ------- + int + Number of points on an image profile. + + """ return cls.n_profile_points + def init_img_reader(self, img_path, driver=None): + """ + Initializes an image reader. + + Parameters + ---------- + img_path : str + Path to where the image is located. + driver : str, optional + Storage driver needed to read image. The default is "zarr". + + Returns + ------- + ImageReader + Image reader. + + """ + if "s3" in img_path: + return ZarrReader(img_path) + else: + return TensorStoreReader(img_path, driver) + def run(self, neurograph, proposals_dict, radius): """ - Generates feature vectors for nodes, edges, and - proposals in a graph. + Generates feature vectors for nodes, edges, and proposals in a graph. Parameters ---------- @@ -157,7 +187,7 @@ def run_on_nodes(self, neurograph, computation_graph): Parameters ---------- neurograph : FragmentsGraph - NeuroGraph generated from a predicted segmentation. + FragmentsGraph generated from a predicted segmentation. computation_graph : networkx.Graph Graph used by GNN to classify proposals. @@ -176,7 +206,7 @@ def run_on_branches(self, neurograph, computation_graph): Parameters ---------- neurograph : FragmentsGraph - NeuroGraph generated from a predicted segmentation. + FragmentsGraph generated from a predicted segmentation. computation_graph : networkx.Graph Graph used by GNN to classify proposals. @@ -195,7 +225,7 @@ def run_on_proposals(self, neurograph, proposals, radius): Parameters ---------- neurograph : FragmentsGraph - NeuroGraph generated from a predicted segmentation. 
+ FragmentsGraph generated from a predicted segmentation. proposals : list[frozenset] List of proposals for which features will be generated. radius : float @@ -215,14 +245,14 @@ def run_on_proposals(self, neurograph, proposals, radius): return features # -- Skeletal Features -- - def node_skeletal(self, neurograph, computation_graph): + def node_skeletal(self, fragments_graph, computation_graph): """ Generates skeleton-based features for nodes in "computation_graph". Parameters ---------- - neurograph : FragmentsGraph - NeuroGraph generated from a predicted segmentation. + fragments_graph : FragmentsGraph + FragmentsGraph generated from a predicted segmentation. computation_graph : networkx.Graph Graph used by GNN to classify proposals. @@ -236,21 +266,21 @@ def node_skeletal(self, neurograph, computation_graph): for i in computation_graph.nodes: node_skeletal_features[i] = np.concatenate( ( - neurograph.degree[i], - neurograph.nodes[i]["radius"], - len(neurograph.nodes[i]["proposals"]), + fragments_graph.degree[i], + fragments_graph.nodes[i]["radius"], + len(fragments_graph.nodes[i]["proposals"]), ), axis=None, ) return node_skeletal_features - def branch_skeletal(self, neurograph, computation_graph): + def branch_skeletal(self, fragments_graph, computation_graph): """ Generates skeleton-based features for edges in "computation_graph". Parameters ---------- - neurograph : FragmentsGraph + fragments_graph : FragmentsGraph Fragments graph that features are to be generated from. computation_graph : networkx.Graph Graph used by GNN to classify proposals. 
@@ -262,24 +292,24 @@ def branch_skeletal(self, neurograph, computation_graph): """ branch_skeletal_features = dict() - for edge in neurograph.edges: + for edge in fragments_graph.edges: branch_skeletal_features[frozenset(edge)] = np.array( [ - np.mean(neurograph.edges[edge]["radius"]), - min(neurograph.edges[edge]["length"], 500) / 500, + np.mean(fragments_graph.edges[edge]["radius"]), + min(fragments_graph.edges[edge]["length"], 500) / 500, ], ) return branch_skeletal_features - def proposal_skeletal(self, neurograph, proposals, radius): + def proposal_skeletal(self, fragments_graph, proposals, radius): """ Generates skeleton-based features for "proposals". Parameters ---------- - neurograph : FragmentsGraph - NeuroGraph generated from a predicted segmentation. - proposals : list[frozenset] + fragments_graph : FragmentsGraph + Graph generated from a predicted segmentation. + proposals : List[Frozenset[int]] List of proposals for which features will be generated. radius : float Search radius used to generate proposals. 
@@ -294,27 +324,27 @@ def proposal_skeletal(self, neurograph, proposals, radius): for proposal in proposals: proposal_skeletal_features[proposal] = np.concatenate( ( - neurograph.proposal_length(proposal) / radius, - neurograph.n_nearby_leafs(proposal, radius), - neurograph.proposal_radii(proposal), - neurograph.proposal_directionals(proposal, 16), - neurograph.proposal_directionals(proposal, 32), - neurograph.proposal_directionals(proposal, 64), - neurograph.proposal_directionals(proposal, 128), + fragments_graph.proposal_length(proposal) / radius, + fragments_graph.n_nearby_leafs(proposal, radius), + fragments_graph.proposal_radii(proposal), + fragments_graph.proposal_directionals(proposal, 16), + fragments_graph.proposal_directionals(proposal, 32), + fragments_graph.proposal_directionals(proposal, 64), + fragments_graph.proposal_directionals(proposal, 128), ), axis=None, ) return proposal_skeletal_features # --- Image features --- - def node_profiles(self, neurograph, computation_graph): + def node_profiles(self, fragments_graph, computation_graph): """ Generates image profiles for nodes in "computation_graph". Parameters ---------- - neurograph : FragmentsGraph - NeuroGraph generated from a predicted segmentation. + fragments_graph : FragmentsGraph + Graph generated from a predicted segmentation. computation_graph : networkx.Graph Graph used by GNN to classify proposals. 
@@ -329,10 +359,10 @@ def node_profiles(self, neurograph, computation_graph): threads = computation_graph.number_of_nodes() * [None] for idx, i in enumerate(computation_graph.nodes): # Get profile path - if neurograph.is_leaf(i): - xyz_path = self.get_leaf_path(neurograph, i) + if fragments_graph.is_leaf(i): + xyz_path = self.get_leaf_path(fragments_graph, i) else: - xyz_path = self.get_branching_path(neurograph, i) + xyz_path = self.get_branching_path(fragments_graph, i) # Assign threads[idx] = executor.submit( @@ -345,15 +375,15 @@ def node_profiles(self, neurograph, computation_graph): node_profile_features.update(thread.result()) return node_profile_features - def proposal_profiles(self, neurograph, proposals): + def proposal_profiles(self, fragments_graph, proposals): """ Generates an image intensity profile along the proposal. Parameters ---------- - neurograph : FragmentsGraph + fragments_graph : FragmentsGraph Graph that "proposals" belong to. - proposals : list[frozenset] + proposals : List[Frozenset[int]] List of proposals for which features will be generated. Returns @@ -368,7 +398,7 @@ def proposal_profiles(self, neurograph, proposals): threads = list() for p in proposals: n_points = self.get_n_profile_points() - xyz_1, xyz_2 = neurograph.proposal_xyz(p) + xyz_1, xyz_2 = fragments_graph.proposal_xyz(p) xyz_path = geometry.make_line(xyz_1, xyz_2, n_points) threads.append(executor.submit(self.get_profile, xyz_path, p)) @@ -378,13 +408,13 @@ def proposal_profiles(self, neurograph, proposals): profiles.update(thread.result()) return profiles - def proposal_patches(self, neurograph, proposals): + def proposal_patches(self, fragments_graph, proposals): """ Generates an image intensity profile along the proposal. Parameters ---------- - neurograph : FragmentsGraph + fragments_graph : FragmentsGraph Graph that "proposals" belong to. proposals : list[frozenset] List of proposals for which features will be generated. 
@@ -400,8 +430,8 @@ def proposal_patches(self, neurograph, proposals): # Assign threads threads = list() for p in proposals: - labels = neurograph.proposal_labels(p) - xyz_path = np.vstack(neurograph.proposal_xyz(p)) + labels = fragments_graph.proposal_labels(p) + xyz_path = np.vstack(fragments_graph.proposal_xyz(p)) threads.append( executor.submit(self.get_patch, labels, xyz_path, p) ) @@ -431,7 +461,7 @@ def get_profile(self, xyz_path, profile_id): profile. """ - profile = img_util.read_profile(self.img, self.get_spec(xyz_path)) + profile = self.img_reader.read_profile(self.get_spec(xyz_path)) profile.extend(list(util.get_avg_std(profile))) return {profile_id: profile} @@ -451,31 +481,11 @@ def get_spec(self, xyz_path): Specifications needed to compute a profile. """ - voxels = self.transform_path(xyz_path) + voxels = np.vstack([self.to_voxels(xyz) for xyz in xyz_path]) bbox = self.get_bbox(voxels) profile_path = geometry.shift_path(voxels, bbox["min"]) return {"bbox": bbox, "profile_path": profile_path} - def transform_path(self, xyz_path): - """ - Converts "xyz_path" from world to voxel coordinates. - - Parameters - ---------- - xyz_path : numpy.ndarray - xyz coordinates of a profile path. - - Returns - ------- - numpy.ndarray - Voxel coordinates of given path. 
- - """ - voxels = np.zeros((len(xyz_path), 3), dtype=int) - for i, xyz in enumerate(xyz_path): - voxels[i] = img_util.to_voxels(xyz, self.anisotropy, self.multiscale) - return voxels - def get_bbox(self, voxels, is_img=True): center = np.round(np.mean(voxels, axis=0)).astype(int) shape = self.img_patch_shape if is_img else self.label_patch_shape @@ -488,17 +498,16 @@ def get_bbox(self, voxels, is_img=True): def get_patch(self, labels, xyz_path, proposal): # Initializations center = np.mean(xyz_path, axis=0) - voxels = [img_util.to_voxels(xyz, self.anisotropy) for xyz in xyz_path] + voxels = [self.to_voxels(xyz) for xyz in xyz_path] # Read patches img_patch = self.read_img_patch(center) label_patch = self.read_label_patch(voxels, labels) return {proposal: np.stack([img_patch, label_patch], axis=0)} - def read_img_patch(self, xyz_centroid): - center = img_util.to_voxels(xyz_centroid, self.anisotropy, self.multiscale) + def read_img_patch(self, xyz): img_patch = img_util.read_tensorstore( - self.img, center, self.img_patch_shape + self.img, self.voxels(xyz), self.img_patch_shape ) return img_util.normalize(img_patch) @@ -511,8 +520,7 @@ def read_label_patch(self, voxels, labels): def relabel(self, label_patch, voxels, labels): # Initializations n_points = self.get_n_profile_points() - scaling_factor = 2 ** self.multiscale - label_patch = zoom(label_patch, 1.0 / scaling_factor, order=0) + label_patch = zoom(label_patch, 1.0 / 2 ** self.multiscale, order=0) for i, voxel in enumerate(voxels): voxels[i] = [v // scaling_factor for v in voxel] @@ -523,18 +531,20 @@ def relabel(self, label_patch, voxels, labels): line = geometry.make_line(voxels[0], voxels[-1], n_points) return geometry.fill_path(relabel_patch, line, val=-1) + def to_voxels(self, xyz): + return img_util.to_voxels(xyz, self.anisotropy, self.multiscale) # --- Profile utils --- -def get_leaf_path(neurograph, i): +def get_leaf_path(fragments_graph, i): """ Gets path that profile will be computed over for the 
leaf node "i". Parameters ---------- - neurograph : FragmentsGraph - NeuroGraph generated from a predicted segmentation. + fragments_graph : FragmentsGraph + Graph that node belongs to. i : int - Leaf node in "neurograph". + Leaf node in "fragments_graph". Returns ------- @@ -542,21 +552,21 @@ def get_leaf_path(neurograph, i): Voxel coordinates that profile is generated from. """ - j = neurograph.leaf_neighbor(i) - xyz_path = neurograph.oriented_edge((i, j), i) + j = fragments_graph.leaf_neighbor(i) + xyz_path = fragments_graph.oriented_edge((i, j), i) return geometry.truncate_path(xyz_path) -def get_branching_path(neurograph, i): +def get_branching_path(fragments_graph, i): """ Gets path that profile will be computed over for the branching node "i". Parameters ---------- - neurograph : FragmentsGraph - NeuroGraph generated from a predicted segmentation. + fragments_graph : FragmentsGraph + Graph generated from a predicted segmentation. i : int - branching node in "neurograph". + Branching node in "fragments_graph". Returns ------- @@ -564,9 +574,9 @@ def get_branching_path(neurograph, i): Voxel coordinates that profile is generated from. 
""" - j_1, j_2 = tuple(neurograph.neighbors(i)) - voxels_1 = geometry.truncate_path(neurograph.oriented_edge((i, j_1), i)) - voxles_2 = geometry.truncate_path(neurograph.oriented_edge((i, j_2), i)) + j1, j2 = tuple(fragments_graph.neighbors(i)) + voxels_1 = geometry.truncate_path(fragments_graph.oriented_edge((i, j1), i)) + voxles_2 = geometry.truncate_path(fragments_graph.oriented_edge((i, j2), i)) return np.vstack([np.flip(voxels_1, axis=0), voxles_2]) @@ -594,29 +604,16 @@ def get_patches_matrix(patches, id_to_idx): return x -def stack_matrices(neurographs, features, blocks): - x, y = None, None - for block_id in blocks: - x_i, y_i, _ = get_matrix(features[block_id]) - if x is None: - x = deepcopy(x_i) - y = deepcopy(y_i) - else: - x = np.concatenate((x, x_i), axis=0) - y = np.concatenate((y, y_i), axis=0) - return x, y - - def init_idx_mapping(idx_to_id): """ Adds dictionary item called "edge_to_index" which maps a branch/proposal - in a neurograph to an idx that represents it's position in the feature + in a FragmentsGraph to an idx that represents it's position in the feature matrix. Parameters ---------- idxs : dict - Dictionary that maps indices to edges in some neurograph. + Dictionary that maps indices to edges in a FragmentsGraph. Returns ------- diff --git a/src/deep_neurographs/utils/graph_util.py b/src/deep_neurographs/utils/graph_util.py index 7bf40c8a..fa1a2e47 100644 --- a/src/deep_neurographs/utils/graph_util.py +++ b/src/deep_neurographs/utils/graph_util.py @@ -7,11 +7,11 @@ Overview -------- -Code that reads and preprocesses neuron fragments stored as swc files, then +Code that loads and preprocesses neuron fragments stored as swc files, then constructs a custom graph object called a "FragmentsGraph" from the fragments. Graph Construction Algorithm: - 1. Read Neuron Fragments + 1. Load Neuron Fragments to do... 2. 
Extract Irreducibles @@ -177,7 +177,7 @@ def extract_irreducibles(self, graph): """ irreducibles = None self.prune_branches(graph) - if compute_path_length(graph) > self.min_size: + if compute_path_length(graph, self.min_size) > self.min_size: # Irreducible nodes leafs, branchings = get_irreducible_nodes(graph) @@ -355,7 +355,7 @@ def smooth_path(graph, path, xyz_list): return graph -def compute_path_length(graph): +def compute_path_length(graph, max_length=np.inf): """ Computes the path length of the given graph. @@ -374,6 +374,8 @@ def compute_path_length(graph): path_length = 0 for i, j in nx.dfs_edges(graph): path_length += compute_dist(graph, i, j) + if path_length > max_length: + break return path_length diff --git a/src/deep_neurographs/utils/img_util.py b/src/deep_neurographs/utils/img_util.py index ea0f0954..1e7c1af8 100644 --- a/src/deep_neurographs/utils/img_util.py +++ b/src/deep_neurographs/utils/img_util.py @@ -9,162 +9,228 @@ """ -from copy import deepcopy +from abc import ABC, abstractmethod +from skimage.color import label2rgb import numpy as np import tensorstore as ts -from skimage.color import label2rgb from deep_neurographs.utils import util -SUPPORTED_DRIVERS = ["neuroglancer_precomputed", "n5", "zarr"] - -# --- io utils --- -def open_tensorstore(path, driver="neuroglancer_precomputed"): +class ImageReader(ABC): """ - Opens an image that is assumed to be stored as a directory of shard files. + Abstract class to create image readers classes. - Parameters - ---------- - path : str - Path to directory containing shard files. - driver : str, optional - Storage driver needed to read data at "path". The default is - "neuroglancer_precomputed". - - Returns - ------- - tensorstore.TensorStore - Pointer to image stored at "path" in a GCS bucket. - - """ - assert driver in SUPPORTED_DRIVERS, "Driver is not supported!" 
- img = ts.open( - { - "driver": driver, - "kvstore": { - "driver": "gcs", - "bucket": "allen-nd-goog", - "path": path, - }, - "context": { - "cache_pool": {"total_bytes_limit": 1000000000}, - "cache_pool#remote": {"total_bytes_limit": 1000000000}, - "data_copy_concurrency": {"limit": 8}, - }, - "recheck_cached_data": "open", - } - ).result() - if driver == "neuroglancer_precomputed": - return img[ts.d["channel"][0]] - elif driver == "zarr": - img = img[0, 0, :, :, :] - img = img[ts.d[0].transpose[2]] - img = img[ts.d[0].transpose[1]] - return img - - -def read(img, voxel, shape, from_center=True): """ - Reads a chunk of data from an image given a voxel coordinate and shape. - Parameters - ---------- - img : numpy.ndarray - Image to be read. - voxel : tuple - Voxel coordinate that specifies either the center or top, left, front - corner of the chunk to be read. - shape : tuple - Shape (dimensions) of the chunk to be read. - from_center : bool, optional - Indication of whether the provided coordinates represent the center of - the chunk or the top, left, front corner. The default is True. - - Returns - ------- - numpy.ndarray - Chunk of data read from an image. - - """ - start, end = get_start_end(voxel, shape, from_center=from_center) - return deepcopy( - img[start[0]: end[0], start[1]: end[1], start[2]: end[2]] - ) - - -def read_tensorstore(img, voxel, shape, from_center=True): + def __init__(self, img_path): + """ + Class constructor of image reader. + + Parameters + ---------- + img_path : str + Path to image. + + Returns + ------- + None + + """ + self.img = None + self.img_path = img_path + self._load_image() + + @abstractmethod + def _load_image(self): + """ + This method should be implemented by subclasses to load the image + based on img_path. + + Parameters + ---------- + None + + Returns + ------- + None + + """ + pass + + def read(self, voxel, shape, from_center=True): + """ + Reads a patch from an image given a voxel coordinate and patch shape. 
+ + Parameters + ---------- + voxel : Tuple[int] + Voxel coordinate that is either the center or top-left-front + corner of the image patch to be read. + shape : Tuple[int] + Shape of the image patch to be read. + from_center : bool, optional + Indication of whether "voxel" is the center or top-left-front + corner of the image patch to be read. The default is True. + + Returns + ------- + ArrayLike + Image patch. + + """ + s, e = get_start_end(voxel, shape, from_center=from_center) + if len(self.shape()) == 3: + return self.img[s[0]: e[0], s[1]: e[1], s[2]: e[2]] + elif len(self.shape()) == 5: + return self.img[0, 0, s[0]: e[0], s[1]: e[1], s[2]: e[2]] + + def read_with_bbox(self, bbox): + """ + Reads an image patch by using a "bbox". + + Parameters + ---------- + bbox : dict + Dictionary that contains min and max coordinates of a bounding + box. + + Returns + ------- + numpy.ndarray + Image patch. + + """ + try: + shape = [bbox["max"][i] - bbox["min"][i] for i in range(3)] + return self.read(bbox["min"], shape, from_center=False) + except Exception: + return np.zeros(shape) + + def read_profile(self, spec): + """ + Reads an intensity profile from an image (i.e. image profile). + + Parameters + ---------- + spec : dict + Dictionary that stores the bounding box of patch to be read and the + voxel coordinates of the profile path. + + Returns + ------- + List[float] + Image profile. + + """ + img_patch = normalize(self.read_with_bbox(spec["bbox"])) + return [img_patch[voxel] for voxel in map(tuple, spec["profile_path"])] + + def shape(self): + """ + Gets the shape of image. + + Parameters + ---------- + None + + Returns + ------- + Tuple[int] + Shape of image. + + """ + return self.img.shape + + +class TensorStoreReader(ImageReader): """ - Reads a chunk from an image given a voxel coordinate and the desired shape - of the chunk. - - Parameters - ---------- - img : tensorstore.TensorStore - Image to be read. 
-    voxel : tuple
-        Voxel coordinate that specifies either the center or top, left, front
-        corner of the chunk to be read.
-    shape : tuple
-        Shape (dimensions) of the chunk to be read.
-    from_center : bool, optional
-        Indication of whether the provided coordinates represent the center of
-        the chunk or the starting point. The default is True.
-
-    Returns
-    -------
-    numpy.ndarray
-        Chunk of data read from an image.
+    Class that reads image with tensorstore.
 
     """
-    return read(img, voxel, shape, from_center=from_center).read().result()
-
-def read_tensorstore_with_bbox(img, bbox, normalize=True):
+    def __init__(self, img_path, driver):
+        """
+        Constructs a TensorStore image reader.
+
+        Parameters
+        ----------
+        img_path : str
+            Path to image.
+        driver : str
+            Storage driver needed to read image at "path".
+
+        Returns
+        -------
+        None
+
+        """
+        self.driver = driver
+        super().__init__(img_path)
+
+    def _load_image(self):
+        """
+        This method should be implemented by subclasses to load the image
+        based on img_path.
+
+        Parameters
+        ----------
+        None
+
+        Returns
+        -------
+        None
+
+        """
+        self.img = ts.open(
+            {
+                "driver": self.driver,
+                "kvstore": {
+                    "driver": "gcs",
+                    "bucket": "allen-nd-goog",
+                    "path": self.img_path,
+                },
+                "context": {
+                    "cache_pool": {"total_bytes_limit": 1000000000},
+                    "cache_pool#remote": {"total_bytes_limit": 1000000000},
+                    "data_copy_concurrency": {"limit": 8},
+                },
+                "recheck_cached_data": "open",
+            }
+        ).result()
+        if self.driver == "neuroglancer_precomputed":
+            self.img = self.img[ts.d["channel"][0]]
+        elif self.driver == "zarr":
+            self.img = self.img[ts.d[2].transpose[4]]
+            self.img = self.img[ts.d[2].transpose[3]]
+
+    def read(self, voxel, shape, from_center=True):
+        img_patch = super().read(voxel, shape, from_center)
+        return img_patch.read().result()
+
+
+class ZarrReader(ImageReader):
     """
-    Reads a chunk from a subarray that is determined by "bbox".
-
-    Parameters
-    ----------
-    img : tensorstore.TensorStore
-        Image to be read.
-    bbox : dict
-        Dictionary that contains min and max coordinates of a bounding box.
-
-    Returns
-    -------
-    numpy.ndarray
-        Chunk of data read from an image.
+    Class that reads image with zarr.
 
     """
-    try:
-        shape = [bbox["max"][i] - bbox["min"][i] for i in range(3)]
-        return read_tensorstore(img, bbox["min"], shape, from_center=False)
-    except Exception:
-        return np.zeros(shape)
-

-def read_profile(img, spec):
-    """
-    Reads an intensity profile from an image (i.e. image profile).
+    def __init__(self, img_path):
+        """
+        Constructs a Zarr image reader.
 
-    Parameters
-    ----------
-    img : tensorstore.TensorStore
-        Image to be read.
-    spec : dict
-        Dictionary that stores the bounding box of chunk to be read and the
-        voxel coordinates of the profile path.
+        Parameters
+        ----------
+        img_path : str
+            Path to image.
 
-    Returns
-    -------
-    numpy.ndarray
-        Image profile.
+        Returns
+        -------
+        None
 
-    """
-    img_patch = normalize(read_tensorstore_with_bbox(img, spec["bbox"]))
-    return [img_patch[voxel] for voxel in map(tuple, spec["profile_path"])]
+        """
+        super().__init__(img_path)
 
 
 def get_start_end(voxel, shape, from_center=True):
@@ -174,18 +240,18 @@ def get_start_end(voxel, shape, from_center=True):
     Parameters
     ----------
     voxel : tuple
-        Voxel coordinate that specifies either the center or top, left, front
-        corner of the chunk to be read.
-    shape : tuple
-        Shape (dimensions) of the chunk to be read.
+        Voxel coordinate that is either the center or top-left-front corner of
+        the image patch to be read.
+    shape : Tuple[int]
+        Shape of the image patch to be read.
     from_center : bool, optional
-        Indication of whether the provided coordinates represent the center of
-        the chunk or the starting point. The default is True.
+        Indication of whether "voxel" is the center or top-left-front corner
+        of the image patch to be read. The default is True.
 
     Return
     ------
-    tuple[list[int]]
-        Start and end indices of the chunk to be read.
+ Tuple[List[int]] + Start and end indices of the image patch to be read. """ if from_center: @@ -197,7 +263,7 @@ def get_start_end(voxel, shape, from_center=True): return start, end -# -- operations -- +# -- Operations -- def normalize(img): """ Normalizes an image so that the minimum and maximum intensity values are 0 @@ -241,7 +307,7 @@ def get_mip(img, axis=0): def get_labels_mip(img, axis=0): """ Compute the maximum intensity projection (MIP) of a segmentation along - "axis". This routine differs from "get_mip" because it retuns an rgb + "axis". This routine differs from "get_mip" because it retuns an RGB image. Parameters @@ -262,18 +328,18 @@ def get_labels_mip(img, axis=0): return (255 * mip).astype(np.uint8) -def get_profile(img, spec, profile_id): +def get_profile(img_reader, spec, profile_id): """ Gets the image profile for a given proposal. Parameters ---------- - img : tensorstore.TensorStore - Image that profiles are generated from. + img_reader : ImageReader + Image reader. spec : dict Dictionary that contains the image bounding box and coordinates of the image profile path. - profile_id : frozenset + profile_id : Frozenset[int] Identifier of profile. Returns @@ -283,13 +349,13 @@ def get_profile(img, spec, profile_id): profile. """ - profile = read_profile(img, spec) + profile = img_reader.read_profile(spec) avg, std = util.get_avg_std(profile) profile.extend([avg, std]) return {profile_id: profile} -# --- coordinate conversions --- +# --- Coordinate Conversions --- def to_physical(voxel, anisotropy, shift=[0, 0, 0]): """ Converts a voxel coordinate to a physical coordinate by applying the @@ -307,7 +373,7 @@ def to_physical(voxel, anisotropy, shift=[0, 0, 0]): Returns ------- - tuple + Tuple[float] Converted coordinates. """ @@ -316,7 +382,7 @@ def to_physical(voxel, anisotropy, shift=[0, 0, 0]): def to_voxels(xyz, anisotropy, multiscale=0): """ - Converts coordinates from world to voxel. + Converts coordinate from a physical to voxel space. 
Parameters ---------- @@ -332,11 +398,11 @@ def to_voxels(xyz, anisotropy, multiscale=0): Returns ------- numpy.ndarray - Voxel coordinate of the input. + Voxel coordinate. """ scaling_factor = 1.0 / 2 ** multiscale - voxel = scaling_factor * xyz / np.array(anisotropy) + voxel = [scaling_factor * xyz[i] / anisotropy[i] for i in range(3)] return np.round(voxel).astype(int)