diff --git a/.gitignore b/.gitignore index 81d47d52..7bd29b15 100644 --- a/.gitignore +++ b/.gitignore @@ -93,3 +93,6 @@ pixi.lock # uv environments uv.lock + +# Claude +CLAUDE.md diff --git a/docs/features.md b/docs/features.md index 478e3995..7eae0c3c 100644 --- a/docs/features.md +++ b/docs/features.md @@ -150,7 +150,7 @@ classDiagram } class Tracks { - +graph: nx.DiGraph + +graph: td.graph.GraphView +segmentation: ndarray|None +features: FeatureDict +annotators: AnnotatorRegistry @@ -203,23 +203,30 @@ These features are **automatically checked** during initialization: **Scenario 1: Loading tracks from CSV with pre-computed features** ```python -# CSV has columns: id, time, y, x, area, track_id -graph = load_graph_from_csv(df) # Nodes already have area, track_id -tracks = SolutionTracks(graph, segmentation=seg) +from funtracks.import_export import tracks_from_df + +# CSV/DataFrame has columns: id, time, y, x, area, track_id, parent_id +tracks = tracks_from_df(df, segmentation=seg) # Auto-detection: pos, area, track_id exist → activate without recomputing ``` **Scenario 2: Creating tracks from raw segmentation** ```python -# Graph has no features yet -graph = nx.DiGraph() -graph.add_node(1, time=0) +from funtracks.utils import create_empty_graphview_graph +from funtracks.data_model import Tracks + +# Create empty graph and add nodes +graph = create_empty_graphview_graph() +graph.add_node(index=1, attrs={"t": 0}) tracks = Tracks(graph, segmentation=seg) -# Auto-detection: pos, area don't exist → compute them +# Auto-detection: pos, area don't exist → compute them from segmentation ``` **Scenario 3: Explicit feature control with FeatureDict** ```python +from funtracks.features import FeatureDict, Time, Position, Area +from funtracks.data_model import Tracks + # Bypass auto-detection entirely feature_dict = FeatureDict({"t": Time(), "pos": Position(), "area": Area()}) tracks = Tracks(graph, segmentation=seg, features=feature_dict) @@ -227,8 +234,9 @@ tracks = Tracks(graph, segmentation=seg, features=feature_dict) ``` **Scenario 4: Enable a new feature** - ```python +from funtracks.data_model import Tracks + tracks = Tracks(graph, segmentation) # Initially has: time, pos, area (auto-detected or computed) @@ -240,8 +248,7 @@ print(tracks.features.keys()) # All features in FeatureDict (including static) print(tracks.annotators.features.keys()) # Only active computed features ``` -**Scenario 4: Disable a feature** - +**Scenario 5: Disable a feature** ```python tracks.disable_features(["area"]) # Removes from FeatureDict, deactivates in annotators @@ -272,9 +279,9 @@ tracks.disable_features(["area"]) def compute(self, feature_keys=None): # Compute feature values in bulk if "custom" in self.features: - for node in self.tracks.graph.nodes(): + for node in self.tracks.graph.node_ids(): value = self._compute_custom(node) - self.tracks.graph.nodes[node]["custom"] = value + self.tracks[node]["custom"] = value def update(self, action): # Incremental update when graph changes diff --git a/pyproject.toml b/pyproject.toml index 3fbde313..de075cec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,8 @@ dependencies =[ "pandas>=2.3.3", "zarr>=2.18,<4", "numcodecs>=0.13,<0.16", + "tracksdata[spatial]@git+https://github.com/JoOkuma/tracksdata@9b09154c1257b6b526389f7de606e050567d9601", + # This will soon be main, and I will then update the commit hash ] [project.urls] @@ -107,6 +109,7 @@ unfixable = [ [tool.mypy] ignore_missing_imports = true python_version = "3.10" +explicit_package_bases = true 
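Note on the feature scenarios above: enabling or disabling a feature now also edits the tracksdata graph schema (see `add_feature`/`delete_feature` in `tracks.py` further down in this diff). A minimal sketch of what Scenarios 4 and 5 imply for the graph attribute keys, assuming `tracks` is a segmentation-backed `Tracks`/`SolutionTracks` built as in Scenario 1:

```python
# enabling a node feature registers its attribute key on the graph schema
tracks.enable_features(["area"])
print("area" in tracks.graph.node_attr_keys())  # True

# disabling removes it from both the FeatureDict and the graph schema
tracks.disable_features(["area"])
print("area" in tracks.features)                # False
print("area" in tracks.graph.node_attr_keys())  # False
```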
[tool.coverage.report] exclude_also = [ diff --git a/src/funtracks/actions/add_delete_edge.py b/src/funtracks/actions/add_delete_edge.py index 0c678b51..f3745736 100644 --- a/src/funtracks/actions/add_delete_edge.py +++ b/src/funtracks/actions/add_delete_edge.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from ._base import BasicAction @@ -10,6 +10,8 @@ from funtracks.data_model import Tracks from funtracks.data_model.tracks import Edge +import tracksdata as td + class AddEdge(BasicAction): """Action for adding a new edge. Endpoints must exist already.""" @@ -52,7 +54,24 @@ def _apply(self) -> None: f"Cannot add edge {self.edge}: endpoint {node} not in graph yet" ) - self.tracks.graph.add_edge(self.edge[0], self.edge[1], **self.attributes) + if self.tracks.graph.has_edge(*self.edge): + raise ValueError(f"Edge {self.edge} already exists in the graph") + + # Add required solution attribute + attrs = self.attributes + attrs[td.DEFAULT_ATTR_KEYS.SOLUTION] = 1 + + required_attrs = self.tracks.graph.edge_attr_keys() + for attr in required_attrs: + if attr not in attrs: + attrs[attr] = self.tracks.features[attr]["default_value"] + + # Create edge attributes for this specific edge + self.tracks.graph.add_edge( + source_id=self.edge[0], + target_id=self.edge[1], + attrs=attrs, + ) # Notify annotators to recompute features (will overwrite computed ones) self.tracks.notify_annotators(self) diff --git a/src/funtracks/actions/add_delete_node.py b/src/funtracks/actions/add_delete_node.py index bec4dc0d..c47a3c42 100644 --- a/src/funtracks/actions/add_delete_node.py +++ b/src/funtracks/actions/add_delete_node.py @@ -7,13 +7,17 @@ if TYPE_CHECKING: from typing import Any - from funtracks.data_model import SolutionTracks - from funtracks.data_model.tracks import Node, SegMask + from funtracks.data_model.solution_tracks import SolutionTracks + from funtracks.data_model.tracks import Node + +import numpy as np +import tracksdata as td +from tracksdata.nodes._mask import Mask class AddNode(BasicAction): """Action for adding new nodes. If a segmentation should also be added, the - pixels for each node should be provided. The label to set the pixels will + mask for the node should be provided. The label to set the mask will be taken from the node id. The existing pixel values are assumed to be zero - you must explicitly update any other segmentations that were overwritten using an UpdateNodes action if you want to be able to undo the action. @@ -24,7 +28,7 @@ def __init__( tracks: SolutionTracks, node: Node, attributes: dict[str, Any], - pixels: SegMask | None = None, + mask: Mask | None = None, ): """Create an action to add a new node, with optional segmentation @@ -32,12 +36,12 @@ def __init__( tracks (Tracks): The Tracks to add the node to node (Node): A node id attributes (Attrs): Includes times, track_ids, and optionally positions - pixels (SegMask | None, optional): The segmentation associated with + mask (Mask | None, optional): The segmentation mask associated with the node. Defaults to None. Raises: ValueError: If time attribute is not in attributes. ValueError: If track_id is not in attributes. - ValueError: If pixels is None and position is not in attributes. + ValueError: If mask is None and position is not in attributes. 
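For context on the attribute defaulting in `AddEdge._apply` above: the hunk fills every schema-registered edge attribute the caller did not supply, presumably because tracksdata expects all keys at insertion time. A compressed rephrasing of that logic, copying the dict so the caller's `attributes` is not mutated (a sketch, not the code in this PR):

```python
import tracksdata as td

attrs = dict(self.attributes)
attrs[td.DEFAULT_ATTR_KEYS.SOLUTION] = 1  # new edges are part of the solution

# fill any schema-registered edge attributes from the feature defaults
for key in self.tracks.graph.edge_attr_keys():
    if key not in attrs:
        attrs[key] = self.tracks.features[key]["default_value"]

self.tracks.graph.add_edge(source_id=self.edge[0], target_id=self.edge[1], attrs=attrs)
```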
""" super().__init__(tracks) self.tracks: SolutionTracks # Narrow type from base class @@ -55,7 +59,7 @@ def __init__( raise ValueError(f"Must provide a {track_id_key} attribute for node {node}") # Check for position - handle both single key and list of keys - if pixels is None: + if mask is None: if isinstance(pos_key, list): # Multi-axis position keys if not all(key in attributes for key in pos_key): @@ -68,7 +72,7 @@ def __init__( raise ValueError( f"Must provide position or segmentation for node {node}" ) - self.pixels = pixels + self.mask = mask self.attributes = attributes self._apply() @@ -77,31 +81,48 @@ def inverse(self) -> BasicAction: return DeleteNode(self.tracks, self.node) def _apply(self) -> None: - """Apply the action, and set segmentation if provided in self.pixels""" - if self.pixels is not None: - self.tracks.set_pixels(self.pixels, self.node) + """Apply the action, and set segmentation if provided in self.mask""" attrs = self.attributes - self.tracks.graph.add_node(self.node) - # set all user provided attributes including time and position - for attr, value in attrs.items(): - self.tracks._set_node_attr(self.node, attr, value) + if self.tracks.segmentation is not None: + if self.mask is not None: + attrs[td.DEFAULT_ATTR_KEYS.MASK] = self.mask + attrs[td.DEFAULT_ATTR_KEYS.BBOX] = self.mask.bbox + else: + # TODO Teun: remove this defaulting behavior, see new tracksdata PR + # update: default behaviour in td has a bug rn, will remove later + if len(self.tracks.segmentation.shape) == 3: + attrs[td.DEFAULT_ATTR_KEYS.MASK] = Mask( + np.array([[False]]), bbox=[0, 0, 1, 1] + ) + attrs[td.DEFAULT_ATTR_KEYS.BBOX] = [0, 0, 1, 1] + elif len(self.tracks.segmentation.shape) == 4: + attrs[td.DEFAULT_ATTR_KEYS.MASK] = Mask( + np.array([[[False]]]), bbox=[0, 0, 0, 1, 1, 1] + ) + attrs[td.DEFAULT_ATTR_KEYS.BBOX] = [0, 0, 0, 1, 1, 1] + else: + raise ValueError( + "Must provide mask when adding node to tracks with seg" + ) + + self.tracks.graph.add_node(attrs=attrs, index=self.node, validate_keys=False) # Always notify annotators - they will check their own preconditions self.tracks.notify_annotators(self) class DeleteNode(BasicAction): - """Action of deleting existing nodes + """Action of deleting existing node If the tracks contain a segmentation, this action also constructs a reversible - operation for setting involved pixels to zero + operation for setting involved masks to zero """ def __init__( self, tracks: SolutionTracks, node: Node, - pixels: SegMask | None = None, + mask: Mask | None = None, ): super().__init__(tracks) self.tracks: SolutionTracks # Narrow type from base class @@ -114,24 +135,34 @@ def __init__( if val is not None: self.attributes[key] = val - self.pixels = self.tracks.get_pixels(node) if pixels is None else pixels + if td.DEFAULT_ATTR_KEYS.MASK in self.tracks.graph.node_attr_keys(): + self.attributes[td.DEFAULT_ATTR_KEYS.MASK] = self.tracks.get_nodes_attr( + [self.node], td.DEFAULT_ATTR_KEYS.MASK + )[0] + self.attributes[td.DEFAULT_ATTR_KEYS.BBOX] = self.tracks.get_nodes_attr( + [self.node], td.DEFAULT_ATTR_KEYS.BBOX + )[0] + self.attributes[td.DEFAULT_ATTR_KEYS.SOLUTION] = self.tracks.get_nodes_attr( + [self.node], td.DEFAULT_ATTR_KEYS.SOLUTION + )[0] + + mask = self.tracks.get_mask(node) if mask is None else mask + + self.mask = mask self._apply() def inverse(self) -> BasicAction: """Invert this action, and provide inverse segmentation operation if given""" - return AddNode(self.tracks, self.node, self.attributes, pixels=self.pixels) + return 
AddNode(self.tracks, self.node, self.attributes, mask=self.mask) def _apply(self) -> None: """ASSUMES THERE ARE NO INCIDENT EDGES - raises valueerror if an edge will be removed by this operation Steps: - - For each node - set pixels to 0 if self.pixels is provided - Remove nodes from graph + - Update annotators """ - if self.pixels is not None: - self.tracks.set_pixels(self.pixels, 0) self.tracks.graph.remove_node(self.node) self.tracks.notify_annotators(self) diff --git a/src/funtracks/actions/update_segmentation.py b/src/funtracks/actions/update_segmentation.py index ce275bd7..2eef9ef0 100644 --- a/src/funtracks/actions/update_segmentation.py +++ b/src/funtracks/actions/update_segmentation.py @@ -2,37 +2,40 @@ from typing import TYPE_CHECKING +import tracksdata as td +from tracksdata.nodes._mask import Mask + from ._base import BasicAction if TYPE_CHECKING: from funtracks.data_model import Tracks - from funtracks.data_model.tracks import Node, SegMask + from funtracks.data_model.tracks import Node class UpdateNodeSeg(BasicAction): """Action for updating the segmentation associated with a node. - New nodes call AddNode with pixels instead of this action. + New nodes call AddNode with mask instead of this action. """ def __init__( self, tracks: Tracks, node: Node, - pixels: SegMask, + mask: Mask, added: bool = True, ): """ Args: - tracks (Tracks): The tracks to update the segmenatations for - node (Node): The node with updated segmenatation - pixels (SegMask): The pixels that were updated for the node - added (bool, optional): If the provided pixels were added (True) or deleted + tracks (Tracks): The tracks to update the segmentations for + node (Node): The node with updated segmentation + mask (Mask): The mask that was updated for the node + added (bool, optional): If the provided mask were added (True) or deleted (False) from this node. 
Defaults to True """ super().__init__(tracks) self.node = node - self.pixels = pixels + self.mask = mask self.added = added self._apply() @@ -41,12 +44,42 @@ def inverse(self) -> BasicAction: return UpdateNodeSeg( self.tracks, self.node, - pixels=self.pixels, + mask=self.mask, added=not self.added, ) def _apply(self) -> None: """Set new attributes""" value = self.node if self.added else 0 - self.tracks.set_pixels(self.pixels, value) + + mask_new = self.mask + + if value == 0: + # val=0 means deleting (part of) the mask + mask_old = self.tracks.graph[self.node][td.DEFAULT_ATTR_KEYS.MASK] + mask_subtracted = mask_old.__isub__(mask_new) + self.tracks.graph.update_node_attrs( + attrs={ + td.DEFAULT_ATTR_KEYS.MASK: [mask_subtracted], + td.DEFAULT_ATTR_KEYS.BBOX: [mask_subtracted.bbox], + }, + node_ids=[self.node], + ) + + elif self.tracks.graph.has_node(value): + # if node already exists: + mask_old = self.tracks.graph[value][td.DEFAULT_ATTR_KEYS.MASK] + mask_combined = mask_old.__or__(mask_new) + self.tracks.graph.update_node_attrs( + attrs={ + td.DEFAULT_ATTR_KEYS.MASK: [mask_combined], + td.DEFAULT_ATTR_KEYS.BBOX: [mask_combined.bbox], + }, + node_ids=[value], + ) + + # Invalidate cache for affected chunks + time = self.tracks.get_time(self.node) + self.tracks._update_segmentation_cache(mask=mask_new, time=time) + self.tracks.notify_annotators(self) diff --git a/src/funtracks/annotators/_compute_ious.py b/src/funtracks/annotators/_compute_ious.py index 7f848c2f..0ffb10f0 100644 --- a/src/funtracks/annotators/_compute_ious.py +++ b/src/funtracks/annotators/_compute_ious.py @@ -1,35 +1,15 @@ -import numpy as np +from tracksdata.nodes._mask import Mask -def _compute_ious(frame1: np.ndarray, frame2: np.ndarray) -> list[tuple[int, int, float]]: - """Compute label IOUs between two label arrays of the same shape. Ignores background - (label 0). +def _compute_iou(mask1: Mask, mask2: Mask) -> list[tuple[int, int, float]]: + """Compute label IOUs between two Mask objects. Args: - frame1 (np.ndarray): Array with integer labels - frame2 (np.ndarray): Array with integer labels + mask1 (Mask): First mask object + mask2 (Mask): Second mask object Returns: - list[tuple[int, int, float]]: List of tuples of label in frame 1, label in - frame 2, and iou values. Labels that have no overlap are not included. 
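The new `_compute_iou` helper delegates to `Mask.iou`, i.e. the standard intersection-over-union that the removed `_compute_ious` computed per label pair. In plain numpy terms (illustration only, not the tracksdata implementation):

```python
import numpy as np

def iou(a: np.ndarray, b: np.ndarray) -> float:
    """Intersection-over-union of two boolean masks on the same pixel grid."""
    intersection = np.logical_and(a, b).sum()
    union = np.logical_or(a, b).sum()
    return float(intersection / union) if union else 0.0
```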
+ iou (int): IOU value between the two masks """ - frame1 = frame1.flatten() - frame2 = frame2.flatten() - # get indices where both are not zero (ignore background) - # this speeds up computation significantly - non_zero_indices = np.logical_and(frame1, frame2) - flattened_stacked = np.array([frame1[non_zero_indices], frame2[non_zero_indices]]) - - values, counts = np.unique(flattened_stacked, axis=1, return_counts=True) - frame1_values, frame1_counts = np.unique(frame1, return_counts=True) - frame1_label_sizes = dict(zip(frame1_values, frame1_counts, strict=True)) - frame2_values, frame2_counts = np.unique(frame2, return_counts=True) - frame2_label_sizes = dict(zip(frame2_values, frame2_counts, strict=True)) - ious: list[tuple[int, int, float]] = [] - for index in range(values.shape[1]): - pair = values[:, index] - intersection = counts[index] - id1, id2 = pair - union = frame1_label_sizes[id1] + frame2_label_sizes[id2] - intersection - ious.append((id1, id2, intersection / union)) - return ious + iou = mask1.iou(mask2) + return iou diff --git a/src/funtracks/annotators/_edge_annotator.py b/src/funtracks/annotators/_edge_annotator.py index fdd241e5..8082a869 100644 --- a/src/funtracks/annotators/_edge_annotator.py +++ b/src/funtracks/annotators/_edge_annotator.py @@ -4,13 +4,11 @@ from collections import defaultdict from typing import TYPE_CHECKING -import numpy as np - from funtracks.actions.add_delete_edge import AddEdge from funtracks.actions.update_segmentation import UpdateNodeSeg from funtracks.features import Feature, IoU -from ._compute_ious import _compute_ious +from ._compute_ious import _compute_iou from ._graph_annotator import GraphAnnotator if TYPE_CHECKING: @@ -82,44 +80,36 @@ def compute(self, feature_keys: list[str] | None = None) -> None: if not keys_to_compute: return - seg = self.tracks.segmentation # TODO: add skip edges if self.iou_key in keys_to_compute: nodes_by_frame = defaultdict(list) - for n in self.tracks.nodes(): + for n in self.tracks.graph.node_ids(): nodes_by_frame[self.tracks.get_time(n)].append(n) - for t in range(seg.shape[0] - 1): + for t in range(self.tracks.segmentation.shape[0] - 1): nodes_in_t = nodes_by_frame[t] - edges = list(self.tracks.graph.out_edges(nodes_in_t)) - self._iou_update(edges, seg[t], seg[t + 1]) + edges = [] + for node in nodes_in_t: + for succ in self.tracks.graph.successors(node): + edges.append((node, succ)) + self._iou_update(edges) def _iou_update( self, edges: list[tuple[int, int]], - seg_frame: np.ndarray, - seg_next_frame: np.ndarray, ) -> None: """Perform the IoU computation and update all feature values for a - single pair of frames of segmentation data. + list of edges. 
Args: edges (list[tuple[int, int]]): A list of edges between two frames - seg_frame (np.ndarray): A 2D or 3D numpy array representing the seg for the - starting time of the edges - seg_next_frame (np.ndarray): A 2D or 3D numpy array representing the seg for - the ending time of the edges """ - ious = _compute_ious(seg_frame, seg_next_frame) # list of (id1, id2, iou) - for id1, id2, iou in ious: - edge = (id1, id2) - if edge in edges: - self.tracks._set_edge_attr(edge, self.iou_key, iou) - edges.remove(edge) - - # anything left has IOU of 0 for edge in edges: - self.tracks._set_edge_attr(edge, self.iou_key, 0) + source, target = edge + mask1 = self.tracks.graph[source]["mask"] + mask2 = self.tracks.graph[target]["mask"] + iou = _compute_iou(mask1, mask2) + self.tracks._set_edge_attr(edge, self.iou_key, iou) def update(self, action: BasicAction): """Update the edge features based on the action. @@ -146,29 +136,30 @@ def update(self, action: BasicAction): else: # UpdateNodeSeg # Get all incident edges to the modified node node = action.node - edges_to_update = list(self.tracks.graph.in_edges(node)) + list( - self.tracks.graph.out_edges(node) - ) + + edges_to_update = [] + for node in self.tracks.graph.node_ids(): + # Add edges from predecessors + for pred in self.tracks.graph.predecessors(node): + edges_to_update.append((pred, node)) + # Add edges from successors + for succ in self.tracks.graph.successors(node): + edges_to_update.append((node, succ)) # Update IoU for each edge for edge in edges_to_update: source, target = edge - start_time = self.tracks.get_time(source) - end_time = self.tracks.get_time(target) - start_seg = self.tracks.segmentation[start_time] - end_seg = self.tracks.segmentation[end_time] - masked_start = np.where(start_seg == source, source, 0) - masked_end = np.where(end_seg == target, target, 0) - if np.max(masked_start) == 0 or np.max(masked_end) == 0: + mask1 = self.tracks.graph[source]["mask"] + mask2 = self.tracks.graph[target]["mask"] + if mask1.mask.sum() == 0 or mask2.mask.sum() == 0: warnings.warn( - f"Cannot find label {source} in frame {start_time} or label {target} " - f"in frame {end_time}: updating edge IOU value to 0", + f"Cannot find label {source} in segmentation" + f": updating edge IOU value to 0", stacklevel=2, ) - self.tracks._set_edge_attr(edge, self.iou_key, 0) + self.tracks._set_edge_attr(edge, self.iou_key, 0.0) else: - iou_list = _compute_ious(masked_start, masked_end) - iou = 0 if len(iou_list) == 0 else iou_list[0][2] + iou = _compute_iou(mask1, mask2) self.tracks._set_edge_attr(edge, self.iou_key, iou) def change_key(self, old_key: str, new_key: str) -> None: diff --git a/src/funtracks/annotators/_graph_annotator.py b/src/funtracks/annotators/_graph_annotator.py index 7fb8583c..8d3c0dd2 100644 --- a/src/funtracks/annotators/_graph_annotator.py +++ b/src/funtracks/annotators/_graph_annotator.py @@ -3,6 +3,10 @@ import logging from typing import TYPE_CHECKING +import polars as pl + +from funtracks.utils.tracksdata_utils import to_polars_dtype + if TYPE_CHECKING: from funtracks.actions import BasicAction from funtracks.data_model import Tracks @@ -66,6 +70,31 @@ def activate_features(self, keys: list[str]) -> None: feat, _ = self.all_features[key] self.all_features[key] = (feat, True) + # Ensure attribute key exists in graph schema + if ( + feat["feature_type"] == "node" + and key not in self.tracks.graph.node_attr_keys() + ): + # Get the dtype from the feature dict + # unless the feature has multiple values, in which case use Array type + dtype 
= to_polars_dtype(feat["value_type"]) + if feat["num_values"] is not None and feat["num_values"] > 1: + dtype = pl.Array(pl.Float64, feat["num_values"]) + self.tracks.graph.add_node_attr_key( + key, + default_value=feat["default_value"], + dtype=dtype, + ) + elif ( + feat["feature_type"] == "edge" + and key not in self.tracks.graph.edge_attr_keys() + ): + self.tracks.graph.add_edge_attr_key( + key, + default_value=feat["default_value"], + dtype=to_polars_dtype(feat["value_type"]), + ) + def deactivate_features(self, keys: list[str]) -> None: """Deactivate computation of the given features in the annotation process. diff --git a/src/funtracks/annotators/_regionprops_annotator.py b/src/funtracks/annotators/_regionprops_annotator.py index 30a5759b..0c103a1b 100644 --- a/src/funtracks/annotators/_regionprops_annotator.py +++ b/src/funtracks/annotators/_regionprops_annotator.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, NamedTuple import numpy as np +from tracksdata.nodes._mask import Mask from funtracks.actions.add_delete_node import AddNode from funtracks.actions.update_segmentation import UpdateNodeSeg @@ -165,30 +166,37 @@ def compute(self, feature_keys: list[str] | None = None) -> None: if not keys_to_compute: return - seg = self.tracks.segmentation - for t in range(seg.shape[0]): - self._regionprops_update(seg[t], keys_to_compute) + for node_id in self.tracks.graph.node_ids(): + mask = self.tracks.graph[node_id]["mask"] + self._regionprops_update(node_id, mask, keys_to_compute) - def _regionprops_update(self, seg_frame: np.ndarray, feature_keys: list[str]) -> None: + def _regionprops_update( + self, node_id: int, mask: Mask, feature_keys: list[str] + ) -> None: """Perform the regionprops computation and update all feature values for a - single frame of segmentation data. + single mask. Args: - seg_frame (np.ndarray): A 2D or 3D numpy array representing one time point + node_id (int): The node ID to update features for. + mask (Mask): A Mask object representing one time point of segmentation data. - feature_keys: List of feature keys to compute (already filtered to enabled). + feature_keys (list): List of feature keys to compute + (already filtered to enabled). """ spacing = None if self.tracks.scale is None else tuple(self.tracks.scale[1:]) - for region in regionprops_extended(seg_frame, spacing=spacing): - node = region.label + for region in regionprops_extended(mask, spacing=spacing): # Skip labels that aren't nodes in the graph (e.g., unselected detections) - if node not in self.tracks.graph: + if not self.tracks.graph.has_node(node_id): continue for key in feature_keys: value = getattr(region, self.regionprops_names[key]) if isinstance(value, tuple): - value = list(value) - self.tracks._set_node_attr(node, key, value) + value = [ + float(v) for v in value + ] # cannot be a list of np.arrays with single values + elif isinstance(value, np.floating): + value = float(value) + self.tracks._set_node_attr(node_id, key, value) def update(self, action: BasicAction): """Update the regionprops features based on the action. 
@@ -214,10 +222,8 @@ def update(self, action: BasicAction): return time = self.tracks.get_time(node) - seg_frame = self.tracks.segmentation[time] - masked_frame = np.where(seg_frame == node, node, 0) - if np.max(masked_frame) == 0: + if self.tracks.graph[node]["mask"].mask.sum() == 0: warnings.warn( f"Cannot find label {node} in frame {time}: " "updating regionprops values to None", @@ -227,7 +233,8 @@ def update(self, action: BasicAction): value = None self.tracks._set_node_attr(node, key, value) else: - self._regionprops_update(masked_frame, keys_to_compute) + mask = self.tracks.graph[node]["mask"] + self._regionprops_update(node, mask, keys_to_compute) def change_key(self, old_key: str, new_key: str) -> None: """Rename a feature key in this annotator, and related mappings. diff --git a/src/funtracks/annotators/_regionprops_extended.py b/src/funtracks/annotators/_regionprops_extended.py index 3f4fa473..9a1a780d 100644 --- a/src/funtracks/annotators/_regionprops_extended.py +++ b/src/funtracks/annotators/_regionprops_extended.py @@ -1,8 +1,9 @@ import math import numpy as np -from skimage.measure import marching_cubes, mesh_surface_area, regionprops +from skimage.measure import marching_cubes, mesh_surface_area from skimage.measure._regionprops import RegionProperties +from tracksdata.nodes._mask import Mask class ExtendedRegionProperties(RegionProperties): @@ -138,8 +139,15 @@ def perimeter(self): if self._label_image.ndim == 2: return super().perimeter else: # 3D + # Create binary mask and pad with background to ensure a surface boundary + # exists. This prevents marching_cubes from failing when the mask fills + # the entire volume + binary_mask = self._label_image == self.label + padded_mask = np.pad( + binary_mask, pad_width=1, mode="constant", constant_values=False + ) verts, faces, _, _ = marching_cubes( - self._label_image == self.label, level=0.5, spacing=self._spacing + padded_mask, level=0.5, spacing=self._spacing ) return mesh_surface_area(verts, faces) @@ -170,34 +178,32 @@ def voxel_count(self): def regionprops_extended( - img: np.ndarray, + mask: Mask, spacing: tuple[float, ...] | None, - intensity_image: np.ndarray | None = None, ) -> list[ExtendedRegionProperties]: """ Create instances of ExtendedRegionProperties that extend skimage.measure.RegionProperties. Args: - img (np.ndarray): The labeled image. + mask (Mask): The labeled mask. spacing (tuple[float, ...]| None): The spacing between voxels in each dimension. If None, each voxel is assumed to be 1 in all dimensions. - intensity_image (np.ndarray, optional): The intensity image. Returns: list[ExtendedRegionProperties]: A list of ExtendedRegionProperties instances. 
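The padding added to `perimeter` above exists because `marching_cubes` needs background on every side of the label to close the surface; a mask that fills its whole (sub)volume would otherwise make the call fail. A standalone illustration of the padding step, independent of the `Mask` class:

```python
import numpy as np
from skimage.measure import marching_cubes, mesh_surface_area

binary_mask = np.ones((4, 4, 4), dtype=bool)  # label fills the whole volume
padded = np.pad(binary_mask, pad_width=1, mode="constant", constant_values=False)

# without the padding, level=0.5 lies outside the data range and marching_cubes raises
verts, faces, _, _ = marching_cubes(padded, level=0.5, spacing=(1.0, 1.0, 1.0))
surface_area = mesh_surface_area(verts, faces)
```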
""" - results = regionprops(img, intensity_image=intensity_image, spacing=spacing) - for i, _ in enumerate(results): - a = results[i] - b = ExtendedRegionProperties( - slice=a.slice, - label=a.label, - label_image=a._label_image, - intensity_image=a._intensity_image, - cache_active=a._cache_active, - spacing=a._spacing, - ) - results[i] = b - - return results + + region = mask.regionprops(spacing=spacing) + + extended_region = ExtendedRegionProperties( + slice=region.slice, + label=region.label, + label_image=region._label_image, + intensity_image=region._intensity_image, + cache_active=region._cache_active, + spacing=region._spacing, + offset=region._offset, + ) + + return [extended_region] diff --git a/src/funtracks/annotators/_track_annotator.py b/src/funtracks/annotators/_track_annotator.py index 8ce639bb..9af4dd31 100644 --- a/src/funtracks/annotators/_track_annotator.py +++ b/src/funtracks/annotators/_track_annotator.py @@ -3,7 +3,7 @@ from collections import defaultdict from typing import TYPE_CHECKING -import networkx as nx +import rustworkx as rx from funtracks.actions import AddNode, DeleteNode, UpdateTrackID from funtracks.data_model import SolutionTracks @@ -110,13 +110,13 @@ def __init__( self.max_lineage_id = 0 # Initialize tracklet bookkeeping if track IDs already exist in the graph - if tracks.graph.number_of_nodes() > 0: + if tracks.graph.num_nodes() > 0: max_id, id_to_nodes = self._get_max_id_and_map(self.tracklet_key) self.max_tracklet_id = max_id self.tracklet_id_to_nodes = id_to_nodes # Initialize lineage bookkeeping if lineage IDs already exist - if lineage_key is not None and tracks.graph.number_of_nodes() > 0: + if lineage_key is not None and tracks.graph.num_nodes() > 0: max_id, id_to_nodes = self._get_max_id_and_map(self.lineage_key) self.max_lineage_id = max_id self.lineage_id_to_nodes = id_to_nodes @@ -133,7 +133,9 @@ def _get_max_id_and_map(self, key: str) -> tuple[int, dict[int, list[int]]]: """ id_to_nodes = defaultdict(list) max_id = 0 - for node in self.tracks.nodes(): + for node in self.tracks.graph.node_ids(): + if key not in self.tracks.graph.node_attr_keys(): + continue _id: int = self.tracks.get_node_attr(node, key) if _id is None: continue @@ -195,8 +197,16 @@ def _assign_lineage_ids(self) -> None: Each connected component will get a unique id, and the relevant class attributes will be updated. """ - lineages = nx.weakly_connected_components(self.tracks.graph) - max_id, ids_to_nodes = self._assign_ids(lineages, self.lineage_key) + lineages_internal = rx.weakly_connected_components(self.tracks.graph.rx_graph) + lineages_external = [] + for lin in lineages_internal: + node_ids_internal = list(lin) + node_ids_external = [ + self.tracks.graph.node_ids()[nid] for nid in node_ids_internal + ] + lineages_external.append(node_ids_external) + + max_id, ids_to_nodes = self._assign_ids(lineages_external, self.lineage_key) self.max_lineage_id = max_id self.lineage_id_to_nodes = ids_to_nodes @@ -206,19 +216,37 @@ def _assign_tracklet_ids(self) -> None: After removing division edges, each connected component will get a unique ID, and the relevant class attributes will be updated. 
""" - graph_copy = self.tracks.graph.copy() - parents = [node for node, degree in self.tracks.graph.out_degree() if degree >= 2] + graph_copy = self.tracks.graph.detach().filter().subgraph() + + parents = [ + node + for node, degree in zip( + self.tracks.graph.node_ids(), self.tracks.graph.out_degree(), strict=True + ) + if degree >= 2 + ] # Remove all intertrack edges from a copy of the original graph for parent in parents: - daughters = self.tracks.successors(parent) + all_edges = self.tracks.graph.edge_list() + daughters = [edge[1] for edge in all_edges if edge[0] == parent] + for daughter in daughters: graph_copy.remove_edge(parent, daughter) - tracklets = nx.weakly_connected_components(graph_copy) - max_id, ids_to_nodes = self._assign_ids(tracklets, self.tracklet_key) - self.max_tracklet_id = max_id - self.tracklet_id_to_nodes = ids_to_nodes + track_id = 1 + for tracklet in rx.weakly_connected_components(graph_copy.rx_graph): + node_ids_internal = list(tracklet) + node_ids_external = [graph_copy.node_ids()[nid] for nid in node_ids_internal] + self.tracks.graph.update_node_attrs( + attrs={ + self.tracks.features.tracklet_key: [track_id] * len(node_ids_external) + }, + node_ids=node_ids_external, + ) + self.tracklet_id_to_nodes[track_id] = node_ids_external + track_id += 1 + self.max_tracklet_id = track_id - 1 def update(self, action: BasicAction) -> None: """Update track-level features based on the action. diff --git a/src/funtracks/data_model/solution_tracks.py b/src/funtracks/data_model/solution_tracks.py index 3c186af8..cce8779e 100644 --- a/src/funtracks/data_model/solution_tracks.py +++ b/src/funtracks/data_model/solution_tracks.py @@ -2,8 +2,7 @@ from typing import TYPE_CHECKING -import networkx as nx -import numpy as np +import tracksdata as td from funtracks.features import FeatureDict @@ -20,14 +19,14 @@ class SolutionTracks(Tracks): def __init__( self, - graph: nx.DiGraph, - segmentation: np.ndarray | None = None, + graph: td.graph.GraphView, time_attr: str | None = None, pos_attr: str | tuple[str] | list[str] | None = None, tracklet_attr: str | None = None, scale: list[float] | None = None, ndim: int | None = None, features: FeatureDict | None = None, + _segmentation: td.array.GraphArrayView | None = None, ): """Initialize a SolutionTracks object. @@ -35,10 +34,8 @@ def __init__( TrackAnnotator is automatically added to manage track IDs. Args: - graph (nx.DiGraph): NetworkX directed graph with nodes as detections and - edges as links. - segmentation (np.ndarray | None): Optional segmentation array where labels - match node IDs. Required for computing region properties (area, etc.). + graph (td.graph.GraphView): NetworkX directed graph with nodes as detections + and edges as links. time_attr (str | None): Graph attribute name for time. Defaults to "time" if None. pos_attr (str | tuple[str, ...] | list[str] | None): Graph attribute @@ -57,16 +54,18 @@ def __init__( Assumes that all features in the dict already exist on the graph (will be activated but not recomputed). If None, core computed features (pos, area, track_id) are auto-detected by checking if they exist on the graph. + _segmentation (GraphArrayView | None): Internal parameter for reusing an + existing GraphArrayView instance. Not intended for public use. 
""" super().__init__( graph, - segmentation=segmentation, time_attr=time_attr, pos_attr=pos_attr, tracklet_attr=tracklet_attr, scale=scale, ndim=ndim, features=features, + _segmentation=_segmentation, ) self.track_annotator = self._get_track_annotator() @@ -92,19 +91,25 @@ def _get_track_annotator(self) -> TrackAnnotator: @classmethod def from_tracks(cls, tracks: Tracks): force_recompute = False - if (tracklet_key := tracks.features.tracklet_key) is not None: - # Check if all nodes have track_id before trusting existing track IDs - # Short circuit on first missing track_id - for node in tracks.graph.nodes(): - if tracks.get_node_attr(node, tracklet_key) is None: - force_recompute = True - break + # Check if all nodes have track_id before trusting existing track IDs + if ( + tracks.features.tracklet_key is not None + and ( + tracks.graph.node_attrs(attr_keys=tracks.features.tracklet_key)[ + tracks.features.tracklet_key + ] + == -1 + ).any() + # Attributes are no longer None, so 0 now means non-computed + ): + force_recompute = True + soln_tracks = cls( tracks.graph, - segmentation=tracks.segmentation, scale=tracks.scale, ndim=tracks.ndim, features=tracks.features, + _segmentation=tracks.segmentation, ) if force_recompute: soln_tracks.enable_features([soln_tracks.features.tracklet_key]) # type: ignore @@ -168,7 +173,10 @@ def get_track_neighbors( elif self.get_time(cand) > time: succ = cand break - return pred, succ + return ( + int(pred) if pred is not None else None, + int(succ) if succ is not None else None, + ) def has_track_id_at_time(self, track_id: int, time: int) -> bool: """Function to check if a node with given track id exists at given time point. diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index 1ae44b4e..7260f328 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -1,5 +1,6 @@ from __future__ import annotations +import itertools import logging from collections.abc import Iterable, Sequence from typing import ( @@ -9,13 +10,21 @@ ) from warnings import warn -import networkx as nx import numpy as np +import tracksdata as td from psygnal import Signal +from tracksdata.array import GraphArrayView +from tracksdata.nodes._mask import Mask from funtracks.features import Feature, FeatureDict, Position, Time +from funtracks.utils.tracksdata_utils import ( + td_get_single_attr_from_edge, + to_polars_dtype, +) if TYPE_CHECKING: + import tracksdata as td + from funtracks.actions import BasicAction from funtracks.annotators import AnnotatorRegistry, GraphAnnotator @@ -36,11 +45,8 @@ class Tracks: position attribute. Edges in the graph represent links across time. Attributes: - graph (nx.DiGraph): A graph with nodes representing detections and + graph (td.graph.GraphView): A graph with nodes representing detections and and edges representing links across time. - segmentation (np.ndarray | None): An optional segmentation that - accompanies the tracking graph. If a segmentation is provided, - the node ids in the graph must match the segmentation labels. features (FeatureDict): Dictionary of features tracked on graph nodes/edges. annotators (AnnotatorRegistry): List of annotators that compute features. scale (list[float] | None): How much to scale each dimension by, including time. @@ -51,22 +57,20 @@ class Tracks: def __init__( self, - graph: nx.DiGraph, - segmentation: np.ndarray | None = None, + graph: td.graph.GraphView, time_attr: str | None = None, pos_attr: str | tuple[str, ...] 
| list[str] | None = None, tracklet_attr: str | None = None, scale: list[float] | None = None, ndim: int | None = None, features: FeatureDict | None = None, + _segmentation: GraphArrayView | None = None, ): """Initialize a Tracks object. Args: - graph (nx.DiGraph): NetworkX directed graph with nodes as detections and - edges as links. - segmentation (np.ndarray | None): Optional segmentation array where labels - match node IDs. Required for computing region properties (area, etc.). + graph (td.graph.GraphView): NetworkX directed graph with nodes as detections + and edges as links. time_attr (str | None): Graph attribute name for time. Defaults to "time" if None. pos_attr (str | tuple[str, ...] | list[str] | None): Graph attribute @@ -85,11 +89,36 @@ def __init__( Assumes that all features in the dict already exist on the graph (will be activated but not recomputed). If None, core computed features (pos, area, track_id) are auto-detected by checking if they exist on the graph. + _segmentation (GraphArrayView | None): Internal parameter for reusing an + existing GraphArrayView instance. Not intended for public use. """ self.graph = graph - self.segmentation = segmentation + if _segmentation is not None: + # Reuse provided segmentation instance (internal use only) + self.segmentation = _segmentation + elif "mask" in graph.node_attr_keys(): + # Create new GraphArrayView from graph metadata + try: + array_view = GraphArrayView( + graph=graph, + shape=graph.metadata()["segmentation_shape"], + attr_key="node_id", + offset=0, + ) + self.segmentation = array_view + except (ValueError, KeyError) as err: + raise ValueError( + "segmentation_shape is incompatible with graph, " + "check if mask and bbox attrs exist on nodes" + ) from err + else: + self.segmentation = None self.scale = scale - self.ndim = self._compute_ndim(segmentation, scale, ndim) + self.ndim = self._compute_ndim( + self.segmentation.shape if self.segmentation is not None else None, + scale, + ndim, + ) self.axis_names = ["z", "y", "x"] if self.ndim == 4 else ["y", "x"] # initialization steps: @@ -254,12 +283,11 @@ def _check_existing_feature(self, key: str) -> bool: bool: True if the key is on the first sampled node or there are no nodes, and False if missing from the first node. 
""" - if self.graph.number_of_nodes() == 0: + if self.graph.num_nodes() == 0: return True - # Get a sample node to check which attributes exist - sample_node = next(iter(self.graph.nodes())) - node_attrs = set(self.graph.nodes[sample_node].keys()) + # Check which attributes exist + node_attrs = set(self.graph.node_attr_keys()) return key in node_attrs def _setup_core_computed_features(self) -> None: @@ -291,26 +319,35 @@ def _setup_core_computed_features(self) -> None: # Add to FeatureDict if not already there if key not in self.features: feature, _ = self.annotators.all_features[key] - self.features[key] = feature + self.add_feature(key, feature) self.annotators.activate_features([key]) else: # enable it (compute it) self.enable_features([key]) def nodes(self): - return np.array(self.graph.nodes()) + return np.array(self.graph.node_ids()) def edges(self): - return np.array(self.graph.edges()) + return np.array(self.graph.edge_ids()) def in_degree(self, nodes: np.ndarray | None = None) -> np.ndarray: + """Get the in-degree edge_ids of the nodes in the graph.""" if nodes is not None: + # make sure nodes is a numpy array + if not isinstance(nodes, np.ndarray): + nodes = np.array(nodes) + return np.array([self.graph.in_degree(node.item()) for node in nodes]) else: return np.array(self.graph.in_degree()) def out_degree(self, nodes: np.ndarray | None = None) -> np.ndarray: if nodes is not None: + # make sure nodes is a numpy array + if not isinstance(nodes, np.ndarray): + nodes = np.array(nodes) + return np.array([self.graph.out_degree(node.item()) for node in nodes]) else: return np.array(self.graph.out_degree()) @@ -388,7 +425,7 @@ def set_positions( for idx, key in enumerate(self.features.position_key): self._set_nodes_attr(nodes, key, positions[:, idx].tolist()) else: - self._set_nodes_attr(nodes, self.features.position_key, positions.tolist()) + self._set_nodes_attr(nodes, self.features.position_key, positions) def set_position(self, node: Node, position: list | np.ndarray): self.set_positions([node], np.expand_dims(np.array(position), axis=0)) @@ -408,6 +445,22 @@ def get_time(self, node: Node) -> int: """ return int(self.get_times([node])[0]) + def get_mask(self, node: Node) -> Mask | None: + """Get the segmentation mask associated with a given node. + + Args: + node (Node): The node to get the mask for. + + Returns: + Mask | None: The segmentation mask for the node, or None if no + segmentation is available. + """ + if self.segmentation is None: + return None + + mask = self.graph[node][td.DEFAULT_ATTR_KEYS.MASK] + return mask + def get_pixels(self, node: Node) -> tuple[np.ndarray, ...] | None: """Get the pixels corresponding to each node in the nodes list. @@ -422,32 +475,73 @@ def get_pixels(self, node: Node) -> tuple[np.ndarray, ...] | None: """ if self.segmentation is None: return None + + # Get time and mask for the node time = self.get_time(node) - loc_pixels = np.nonzero(self.segmentation[time] == node) - time_array = np.ones_like(loc_pixels[0]) * time - return (time_array, *loc_pixels) + mask = self.graph[node][td.DEFAULT_ATTR_KEYS.MASK] + + # Get local coordinates and convert to global using bbox offset + local_coords = np.nonzero(mask.mask) + global_coords = [coord + mask.bbox[dim] for dim, coord in enumerate(local_coords)] + + # Create time array matching the number of points + time_array = np.full_like(global_coords[0], time) - def set_pixels(self, pixels: tuple[np.ndarray, ...], value: int) -> None: - """Set the given pixels in the segmentation to the given value. 
+ return (time_array, *global_coords) + + def _update_segmentation_cache(self, mask: td.Mask, time: int) -> None: + """Invalidate cached chunks that overlap with the given mask. Args: - pixels (Iterable[tuple[np.ndarray]]): The pixels that should be set, - formatted like the output of np.nonzero (each element of the tuple - represents one dimension, containing an array of indices in that - dimension). Can be used to directly index the segmentation. - value (Iterable[int | None]): The value to set each pixel to + mask: Mask object with .bbox attribute defining the affected region + time: Time point of the mask """ if self.segmentation is None: - raise ValueError("Cannot set pixels when segmentation is None") - self.segmentation[pixels] = value + return + + cache = self.segmentation._cache + + # Only invalidate if this time point is in the cache + if time not in cache._store: + return + + # Convert bbox to slices directly + # bbox format: [z_min, y_min, x_min, z_max, y_max, x_max] (3D) + # or [y_min, x_min, y_max, x_max] (2D) + ndim = len(mask.bbox) // 2 + volume_slicing = tuple( + slice(mask.bbox[i], mask.bbox[i + ndim] + 1) for i in range(ndim) + ) + + # Use cache's method to get chunk bounds (same logic as cache.get()) + bounds = cache._chunk_bounds(volume_slicing) + chunk_ranges = [range(lo, hi + 1) for lo, hi in bounds] + + # Invalidate all affected chunks + cache_entry = cache._store[time] + for chunk_idx in itertools.product(*chunk_ranges): + if all( + 0 <= idx < grid_size + for idx, grid_size in zip(chunk_idx, cache.grid_shape, strict=True) + ): + cache_entry.ready[chunk_idx] = False + # Clear the buffer to ensure stale data isn't used + # when the chunk is recomputed + chunk_slc = tuple( + slice(ci * cs, min((ci + 1) * cs, fs)) + for ci, cs, fs in zip( + chunk_idx, cache.chunk_shape, cache.shape, strict=True + ) + ) + cache_entry.buffer[chunk_slc] = 0 def _compute_ndim( self, - seg: np.ndarray | None, + segmentation_shape: tuple[int, ...] 
| None, scale: list[float] | None, provided_ndim: int | None, ): - seg_ndim = seg.ndim if seg is not None else None + seg_ndim = len(segmentation_shape) if segmentation_shape is not None else None scale_ndim = len(scale) if scale is not None else None ndims = [seg_ndim, scale_ndim, provided_ndim] ndims = [d for d in ndims if d is not None] @@ -467,35 +561,33 @@ def _compute_ndim( def _set_node_attr(self, node: Node, attr: str, value: Any): if isinstance(value, np.ndarray): value = list(value) - self.graph.nodes[node][attr] = value + self.graph[node][attr] = [value] def _set_nodes_attr(self, nodes: Iterable[Node], attr: str, values: Iterable[Any]): for node, value in zip(nodes, values, strict=False): - if isinstance(value, np.ndarray): - value = list(value) - self.graph.nodes[node][attr] = value + self.graph[node][attr] = [value] def get_node_attr(self, node: Node, attr: str, required: bool = False): - if required: - return self.graph.nodes[node][attr] - else: - return self.graph.nodes[node].get(attr, None) + return self.graph[int(node)][attr] def get_nodes_attr(self, nodes: Iterable[Node], attr: str, required: bool = False): return [self.get_node_attr(node, attr, required=required) for node in nodes] def _set_edge_attr(self, edge: Edge, attr: str, value: Any): - self.graph.edges[edge][attr] = value + edge_id = self.graph.edge_id(edge[0], edge[1]) + self.graph.update_edge_attrs(attrs={attr: value}, edge_ids=[edge_id]) def _set_edges_attr(self, edges: Iterable[Edge], attr: str, values: Iterable[Any]): for edge, value in zip(edges, values, strict=False): - self.graph.edges[edge][attr] = value + edge_id = self.graph.edge_id(edge[0], edge[1]) + self.graph.update_edge_attrs(attrs={attr: value}, edge_ids=[edge_id]) def get_edge_attr(self, edge: Edge, attr: str, required: bool = False): - if required: - return self.graph.edges[edge][attr] - else: - return self.graph.edges[edge].get(attr, None) + if attr not in self.graph.edge_attr_keys(): + if required: + raise KeyError(attr) + return None + return td_get_single_attr_from_edge(self.graph, edge=edge, attrs=[attr]) def get_edges_attr(self, edges: Iterable[Edge], attr: str, required: bool = False): return [self.get_edge_attr(edge, attr, required=required) for edge in edges] @@ -541,7 +633,7 @@ def enable_features(self, feature_keys: list[str], recompute: bool = True) -> No for key in feature_keys: if key not in self.features: feature, _ = self.annotators.all_features[key] - self.features[key] = feature + self.add_feature(key, feature) # Compute the features if requested if recompute: @@ -564,4 +656,51 @@ def disable_features(self, feature_keys: list[str]) -> None: # Remove from FeatureDict for key in feature_keys: if key in self.features: - del self.features[key] + self.delete_feature(key) + + def add_feature(self, key: str, feature: Feature) -> None: + """Add a feature to the features dictionary and perform graph operations. + + This is the preferred way to add new features as it ensures both the + features dictionary is updated and any necessary graph operations are performed. 
+ + Args: + key: The key for the new feature + feature: The Feature object to add + """ + # Add to the features dictionary + self.features[key] = feature + + # Perform custom graph operations when a feature is added + if feature["feature_type"] == "node" and key not in self.graph.node_attr_keys(): + self.graph.add_node_attr_key( + key, + default_value=feature["default_value"], + dtype=to_polars_dtype(feature["value_type"]), + ) + elif feature["feature_type"] == "edge" and key not in self.graph.edge_attr_keys(): + self.graph.add_edge_attr_key( + key, + default_value=feature["default_value"], + dtype=to_polars_dtype(feature["value_type"]), + ) + + def delete_feature(self, key: str) -> None: + """Delete a feature from the features dictionary and perform graph operations. + + This is the preferred way to delete features as it ensures both the + features dictionary is updated and any necessary graph operations are performed. + + Args: + key: The key for the feature to delete + """ + # Remove from the features dictionary + del self.features[key] + + # Perform custom graph operations when a feature is deleted + if feature := self.annotators.all_features.get(key): + feature_type = feature[0]["feature_type"] + if feature_type == "node" and key in self.graph.node_attr_keys(): + self.graph.remove_node_attr_key(key) + elif feature_type == "edge" and key in self.graph.edge_attr_keys(): + self.graph.remove_edge_attr_key(key) diff --git a/src/funtracks/import_export/_import_segmentation.py b/src/funtracks/import_export/_import_segmentation.py index ef66d048..f750a741 100644 --- a/src/funtracks/import_export/_import_segmentation.py +++ b/src/funtracks/import_export/_import_segmentation.py @@ -10,10 +10,11 @@ from typing import TYPE_CHECKING import dask.array as da -import networkx as nx import numpy as np +import tracksdata as td from funtracks.import_export.magic_imread import magic_imread +from funtracks.utils.tracksdata_utils import td_relabel_nodes if TYPE_CHECKING: from numpy.typing import ArrayLike @@ -45,11 +46,11 @@ def load_segmentation(segmentation: Path | np.ndarray | da.Array) -> da.Array: def relabel_segmentation( seg_array: da.Array | np.ndarray, - graph: nx.DiGraph, + graph: td.graph.GraphView, node_ids: ArrayLike, seg_ids: ArrayLike, time_values: ArrayLike, -) -> np.ndarray: +) -> tuple[np.ndarray, td.graph.GraphView]: """Relabel segmentation from seg_id to node_id. Handles the case where node_id 0 exists by offsetting all node IDs by 1, @@ -57,13 +58,14 @@ def relabel_segmentation( Args: seg_array: Segmentation array (dask or numpy) - graph: NetworkX graph (modified in-place if node_id 0 exists) + graph: tracksdata GraphView (will be relabeled if node_id 0 exists) node_ids: Array of node IDs seg_ids: Array of segmentation IDs corresponding to each node time_values: Array of time values for each node Returns: - Relabeled segmentation as numpy array with dtype uint64 + Tuple of (relabeled segmentation as numpy array with dtype uint64, + graph (potentially relabeled if node_id 0 existed)) """ # Convert to numpy arrays for processing node_ids = np.asarray(node_ids) @@ -77,8 +79,9 @@ def relabel_segmentation( # in segmentation arrays. We also need to relabel the graph nodes. 
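A small worked example of the offset behaviour described above, with illustrative values only and assuming `seg_array` and `graph` are the raw segmentation and graph being imported:

```python
import numpy as np

# node id 0 exists, so every node id is shifted by +1 to keep 0 as background
node_ids = np.array([0, 1, 2])
seg_ids = np.array([7, 8, 9])       # labels present in the raw segmentation
time_values = np.array([0, 0, 0])

new_seg, graph = relabel_segmentation(seg_array, graph, node_ids, seg_ids, time_values)
# pixels labelled 7/8/9 in seg_array come back labelled 1/2/3,
# and the returned graph has its nodes relabelled 0->1, 1->2, 2->3
```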
offset = 1 if 0 in node_ids else 0 if offset: - mapping = {old_id: old_id + offset for old_id in graph.nodes()} - nx.relabel_nodes(graph, mapping, copy=False) + mapping = {old_id: old_id + offset for old_id in graph.node_ids()} + # nx.relabel_nodes modified graph in-place, but td_relabel_nodes returns new graph + graph = td_relabel_nodes(graph, mapping) # Update node_ids array to match node_ids = node_ids + offset @@ -98,7 +101,7 @@ def relabel_segmentation( for seg_id, node_id in seg_to_node.items(): new_segmentation[t][computed_seg[t] == seg_id] = node_id - return new_segmentation + return new_segmentation, graph # TODO: export segmentation with check to relabel to track_id diff --git a/src/funtracks/import_export/_tracks_builder.py b/src/funtracks/import_export/_tracks_builder.py index 92506a44..186da140 100644 --- a/src/funtracks/import_export/_tracks_builder.py +++ b/src/funtracks/import_export/_tracks_builder.py @@ -9,11 +9,10 @@ from abc import ABC, abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Any, Literal -import geff -import networkx as nx import numpy as np +import tracksdata as td from geff._typing import InMemoryGeff from funtracks.data_model.solution_tracks import SolutionTracks @@ -38,6 +37,10 @@ validate_node_name_map, validate_spatial_dims, ) +from funtracks.utils.tracksdata_utils import ( + add_masks_and_bboxes_to_graph, + create_empty_graphview_graph, +) if TYPE_CHECKING: import pandas as pd @@ -87,7 +90,7 @@ class TracksBuilder(ABC): TIME_ATTR = "time" - def __init__(self): + def __init__(self) -> None: """Initialize builder state.""" # State transferred between steps self.in_memory_geff: InMemoryGeff | None = None @@ -383,7 +386,9 @@ def validate(self) -> None: # Validate graph structure and optional properties validate_in_memory_geff(self.in_memory_geff) - def construct_graph(self) -> nx.DiGraph: + def construct_graph( + self, node_name_map: dict[str, str | list[str]] | None = None + ) -> td.graph.GraphView: """Construct NetworkX graph from validated InMemoryGeff data. Common logic shared across all formats. @@ -396,14 +401,86 @@ def construct_graph(self) -> nx.DiGraph: """ if self.in_memory_geff is None: raise ValueError("No data loaded. 
Call load_source() first.") - return geff.construct(**self.in_memory_geff) + + if node_name_map is not None: + node_attributes = list(self.in_memory_geff["node_props"].keys()) + node_first_values = [ + self.in_memory_geff["node_props"][key]["values"][0] + for key in node_attributes + ] + + node_default_dtypes = [type(value) for value in node_first_values] + node_default_values = [] + for i, dtype in enumerate(node_default_dtypes): + default_value: Any + if issubclass(dtype, np.integer): + default_value = -1 + elif issubclass(dtype, np.floating): + default_value = 0.0 + elif issubclass(dtype, np.str_): + default_value = "" + elif issubclass(dtype, np.ndarray): + default_value = np.array([0.0 for _ in node_first_values[i]]) + else: + default_value = 0 + node_default_values.append(default_value) + + graph = create_empty_graphview_graph( + node_attributes=list(self.in_memory_geff["node_props"].keys()), + edge_attributes=list(self.in_memory_geff["edge_props"].keys()), + node_default_values=node_default_values, + database=":memory:", + ) + + node_ids = [int(i) for i in self.in_memory_geff["node_ids"]] + node_attrs = [] + for idx in range(len(self.in_memory_geff["node_ids"])): + node_attr = {} + node_attr[td.DEFAULT_ATTR_KEYS.SOLUTION] = 1 # Add default solution value + for key, prop in self.in_memory_geff["node_props"].items(): + # force time key to be "t" in graph + if key == self.TIME_ATTR: + key = "t" + value = prop["values"][idx] + # set missing attribute to None + if prop.get("missing") is not None and prop["missing"][idx]: + value = None + node_attr[key] = value + for key in graph.node_attr_keys(): + if key not in node_attr: + node_attr[key] = None # type: ignore[assignment] + node_attrs.append(node_attr) + + edge_attrs = [] + for idx in range(len(self.in_memory_geff["edge_ids"])): + edge_attr = {} + edge_attr["source_id"] = int(self.in_memory_geff["edge_ids"][idx][0]) + edge_attr["target_id"] = int(self.in_memory_geff["edge_ids"][idx][1]) + edge_attr[td.DEFAULT_ATTR_KEYS.SOLUTION] = 1 # Default solution value + for key, prop in self.in_memory_geff["edge_props"].items(): + value = prop["values"][idx] + if prop.get("missing") is not None and prop["missing"][idx]: + value = None + edge_attr[key] = value + for key in graph.edge_attr_keys(): + if key not in edge_attr: + edge_attr[key] = None # type: ignore[assignment] + edge_attrs.append(edge_attr) + + graph.bulk_add_nodes(nodes=node_attrs, indices=node_ids) + graph.bulk_add_edges(edge_attrs) + + if self.TIME_ATTR != "t": + graph.remove_node_attr_key(self.TIME_ATTR) + + return graph def handle_segmentation( self, - graph: nx.DiGraph, + graph: td.graph.GraphView, segmentation: Path | np.ndarray | None, scale: list[float] | None, - ) -> tuple[np.ndarray | None, list[float] | None]: + ) -> tuple[np.ndarray | None, list[float] | None, td.graph.GraphView]: """Load, validate, and optionally relabel segmentation. Common logic shared across all formats. @@ -414,13 +491,14 @@ def handle_segmentation( scale: Spatial scale for coordinate transformation Returns: - Tuple of (segmentation array, scale) or (None, scale) + Tuple of (segmentation array, scale, graph). The graph may be relabeled + if node_id 0 exists in the original graph. Raises: ValueError: If segmentation validation fails """ if segmentation is None: - return None, scale + return None, scale, graph if self.in_memory_geff is None: raise ValueError("No data loaded. 
Call load_source() first.") @@ -441,8 +519,8 @@ def handle_segmentation( # Validate segmentation matches graph (only if position is loaded) # If position is not in graph, it will be computed from segmentation - sample_node = next(iter(graph.nodes())) - has_position = "pos" in graph.nodes[sample_node] + # sample_node = next(iter(graph.node_ids())) + has_position = "pos" in graph.node_attr_keys() if has_position: from funtracks.import_export._validation import validate_graph_seg_match @@ -452,7 +530,7 @@ def handle_segmentation( node_props = self.in_memory_geff["node_props"] if "seg_id" not in node_props: # No seg_id property, assume segmentation labels match node IDs - return seg_array.compute(), scale + return seg_array.compute(), scale, graph node_ids = self.in_memory_geff["node_ids"] seg_ids = node_props["seg_id"]["values"] @@ -460,16 +538,15 @@ def handle_segmentation( # Check if any seg_id differs from node_id if np.array_equal(seg_ids, node_ids): # No relabeling needed - return seg_array.compute(), scale + return seg_array.compute(), scale, graph # Relabel segmentation: seg_id -> node_id - time_attr = "time" - time_values = node_props[time_attr]["values"] - new_segmentation = relabel_segmentation( + time_values = node_props[self.TIME_ATTR]["values"] + new_segmentation, graph = relabel_segmentation( seg_array, graph, node_ids, seg_ids, time_values ) - return new_segmentation, scale + return new_segmentation, scale, graph def enable_features( self, @@ -553,6 +630,7 @@ def build( scale: list[float] | None = None, node_features: dict[str, bool] | None = None, edge_features: dict[str, bool] | None = None, + node_name_map: dict[str, str | list[str]] | None = None, ) -> SolutionTracks: """Orchestrate the full construction process. @@ -562,6 +640,7 @@ def build( scale: Optional spatial scale node_features: Optional node features to enable/load edge_features: Optional edge features to enable/load + node_name_map: Optional node_name_map to override self.node_name_map Returns: Fully constructed SolutionTracks object @@ -625,22 +704,28 @@ def build( self.validate() # 4. Construct graph - graph = self.construct_graph() + graph = self.construct_graph(node_name_map) # 5. Handle segmentation - segmentation_array, scale = self.handle_segmentation(graph, segmentation, scale) + segmentation_array, scale, graph = self.handle_segmentation( + graph, segmentation, scale + ) + + # 6. Add segmentation to the graph + if segmentation_array is not None: + graph = add_masks_and_bboxes_to_graph(graph, segmentation_array) + graph.update_metadata(segmentation_shape=segmentation_array.shape) - # 6. Create SolutionTracks + # 7. Create SolutionTracks tracks = SolutionTracks( graph=graph, - segmentation=segmentation_array, pos_attr="pos", time_attr=self.TIME_ATTR, ndim=self.ndim, scale=scale, ) - # 7. Enable and register features + # 8. 
Enable and register features if node_features is not None: self.enable_features(tracks, node_features, feature_type="node") if edge_features is not None: diff --git a/src/funtracks/import_export/_utils.py b/src/funtracks/import_export/_utils.py index 18e888d9..334cf42d 100644 --- a/src/funtracks/import_export/_utils.py +++ b/src/funtracks/import_export/_utils.py @@ -2,8 +2,8 @@ from typing import TYPE_CHECKING -import networkx as nx import numpy as np +import tracksdata as td from funtracks.data_model.tracks import Tracks @@ -70,7 +70,9 @@ def infer_dtype_from_array(arr: ArrayLike) -> ValueType: return "str" -def filter_graph_with_ancestors(graph: nx.DiGraph, nodes_to_keep: set[int]) -> list[int]: +def filter_graph_with_ancestors( + graph: td.graph.GraphView, nodes_to_keep: set[int] +) -> list[int]: """Filter a graph to keep only the nodes in `nodes_to_keep` and their ancestors. Args: @@ -82,10 +84,20 @@ def filter_graph_with_ancestors(graph: nx.DiGraph, nodes_to_keep: set[int]) -> l in `nodes_to_keep` and their ancestors. """ all_nodes_to_keep = set(nodes_to_keep) + import rustworkx as rx - for node in nodes_to_keep: - ancestors = nx.ancestors(graph, node) - all_nodes_to_keep.update(ancestors) + # Map external node ID to internal RustWorkX index + nodes_to_keep_internal = graph._vectorized_map_to_local(list(nodes_to_keep)) + + # Collect all internal ancestor IDs + all_ancestors_internal = set() + for internal_node in nodes_to_keep_internal: + ancestors = rx.ancestors(graph.rx_graph, internal_node) + all_ancestors_internal.update(ancestors) + + # Convert ancestor indices back to external node IDs + ancestors_external = graph._vectorized_map_to_external(list(all_ancestors_internal)) + all_nodes_to_keep.update(int(a) for a in ancestors_external) return list(all_nodes_to_keep) @@ -108,7 +120,7 @@ def rename_feature(tracks: Tracks, old_key: str, new_key: str) -> None: # Register it to the feature dictionary, removing old key if necessary if old_key in tracks.features: tracks.features.pop(old_key) - tracks.features[new_key] = feature_dict + tracks.add_feature(new_key, feature_dict) # Update FeatureDict special key attributes if we renamed position or tracklet if tracks.features.position_key == old_key: diff --git a/src/funtracks/import_export/_v1_format.py b/src/funtracks/import_export/_v1_format.py index f6372626..a32166a0 100644 --- a/src/funtracks/import_export/_v1_format.py +++ b/src/funtracks/import_export/_v1_format.py @@ -9,6 +9,10 @@ import numpy as np from funtracks.features import FeatureDict +from funtracks.utils.tracksdata_utils import ( + add_masks_and_bboxes_to_graph, + convert_graph_nx_to_td, +) if TYPE_CHECKING: from ..data_model import SolutionTracks, Tracks @@ -18,88 +22,6 @@ ATTRS_FILE = "attrs.json" -def _save_v1_tracks(tracks: Tracks, directory: Path) -> None: - """Only used for testing backward compatibility! - - Currently, saves the graph as a json file in networkx node link data format, - saves the segmentation as a numpy npz file, and saves the time and position - attributes and scale information in an attributes json file. - Will make the directory if it doesn't exist. - - Args: - tracks (Tracks): the tracks to save - directory (Path): The directory to save the tracks in. - """ - directory.mkdir(exist_ok=True, parents=True) - _save_graph(tracks, directory) - _save_seg(tracks, directory) - _save_attrs(tracks, directory) - - -def _save_graph(tracks: Tracks, directory: Path) -> None: - """Save the graph to file. 
Currently uses networkx node link data - format (and saves it as json). - - Args: - tracks (Tracks): the tracks to save the graph of - directory (Path): The directory in which to save the graph file. - """ - graph_file = directory / GRAPH_FILE - graph_data = nx.node_link_data(tracks.graph, edges="links") - - def convert_np_types(data): - """Recursively convert numpy types to native Python types.""" - - if isinstance(data, dict): - return {key: convert_np_types(value) for key, value in data.items()} - elif isinstance(data, list): - return [convert_np_types(item) for item in data] - elif isinstance(data, np.ndarray): - return data.tolist() # Convert numpy arrays to Python lists - elif isinstance(data, np.integer): - return int(data) # Convert numpy integers to Python int - elif isinstance(data, np.floating): - return float(data) # Convert numpy floats to Python float - else: - return data # Return the data as-is if it's already a native Python type - - graph_data = convert_np_types(graph_data) - with open(graph_file, "w") as f: - json.dump(graph_data, f) - - -def _save_seg(tracks: Tracks, directory: Path) -> None: - """Save a segmentation as a numpy array using np.save. In the future, - could be changed to use zarr or other file types. - - Args: - tracks (Tracks): the tracks to save the segmentation of - directory (Path): The directory in which to save the segmentation - """ - if tracks.segmentation is not None: - out_path = directory / SEG_FILE - np.save(out_path, tracks.segmentation) - - -def _save_attrs(tracks: Tracks, directory: Path) -> None: - """Save the and scale, ndim, and features in a json file in the given directory. - - Args: - tracks (Tracks): the tracks to save the attributes of - directory (Path): The directory in which to save the attributes - """ - out_path = directory / ATTRS_FILE - attrs_dict = { - "scale": tracks.scale - if not isinstance(tracks.scale, np.ndarray) - else tracks.scale.tolist(), - "ndim": tracks.ndim, - "features": tracks.features.dump_json(), - } - with open(out_path, "w") as f: - json.dump(attrs_dict, f) - - def load_v1_tracks( directory: Path, seg_required: bool = False, solution: bool = False ) -> Tracks | SolutionTracks: @@ -119,7 +41,7 @@ def load_v1_tracks( Tracks: A tracks object loaded from the given directory """ graph_file = directory / GRAPH_FILE - graph = _load_graph(graph_file) + graph_nx = _load_graph(graph_file) seg_file = directory / SEG_FILE seg = _load_seg(seg_file, seg_required=seg_required) @@ -127,6 +49,13 @@ def load_v1_tracks( attrs_file = directory / ATTRS_FILE attrs = _load_attrs(attrs_file) + graph_td = convert_graph_nx_to_td(graph_nx) + + # Add mask and bbox attributes to graph if segmentation is available + if seg is not None: + graph_td = add_masks_and_bboxes_to_graph(graph_td, seg) + graph_td.update_metadata(segmentation_shape=seg.shape) + # filtering the warnings because the default values of time_attr and pos_attr are # not None. Therefore, new style Tracks attrs that have features instead of # pos_attr and time_attr will always trigger the warning. 
Updating default values @@ -141,9 +70,9 @@ def load_v1_tracks( ) tracks: Tracks if solution: - tracks = SolutionTracks(graph, seg, **attrs) + tracks = SolutionTracks(graph_td, **attrs) else: - tracks = Tracks(graph, seg, **attrs) + tracks = Tracks(graph_td, **attrs) return tracks diff --git a/src/funtracks/import_export/_validation.py b/src/funtracks/import_export/_validation.py index 5031f4f4..6d6ab4d9 100644 --- a/src/funtracks/import_export/_validation.py +++ b/src/funtracks/import_export/_validation.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from warnings import warn -import networkx as nx +import tracksdata as td from geff._typing import InMemoryGeff from geff.validate.graph import ( validate_no_repeated_edges, @@ -22,7 +22,7 @@ def validate_graph_seg_match( - graph: nx.DiGraph, + graph: td.graph.GraphView, segmentation: da.Array, scale: list[float], position_attr: list[str], @@ -34,7 +34,7 @@ def validate_graph_seg_match( of the segmentation to match node id values is required. Args: - graph: NetworkX graph with standard keys + graph: tracksdata graph with standard keys segmentation: Segmentation data (dask array) scale: Scaling information (pixel to world coordinates) position_attr: Position keys (e.g., ["y", "x"] or ["z", "y", "x"]) @@ -51,20 +51,20 @@ def validate_graph_seg_match( ) # Get the last node for validation - node_ids = list(graph.nodes()) + node_ids = list(graph.node_ids()) if not node_ids: raise ValueError("Graph has no nodes") last_node_id = node_ids[-1] - last_node_data = graph.nodes[last_node_id] + last_node_data = graph[last_node_id] # Check if seg_id exists; if not, assume it matches node_id - seg_id = last_node_data.get(SEG_KEY, last_node_id) + seg_id = last_node_data["seg_id"] # Get the coordinates for the last node (using standard keys) # Position may be stored as composite "pos" attribute or separate y/x/z attributes - coord = [int(last_node_data["time"])] - if "pos" in last_node_data: + coord = [int(last_node_data["t"])] + if "pos" in graph.node_attr_keys(): # Composite position: [z, y, x] or [y, x] pos = last_node_data["pos"] coord.extend(pos) diff --git a/src/funtracks/import_export/csv/_export.py b/src/funtracks/import_export/csv/_export.py index b2eb5c6d..e865b57d 100644 --- a/src/funtracks/import_export/csv/_export.py +++ b/src/funtracks/import_export/csv/_export.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any import numpy as np +import polars as pl from .._utils import filter_graph_with_ancestors @@ -86,7 +87,7 @@ def convert_numpy_to_python(value): # Determine which nodes to export if node_ids is None: - node_to_keep = tracks.graph.nodes() + node_to_keep = tracks.graph.node_ids() else: node_to_keep = filter_graph_with_ancestors(tracks.graph, node_ids) @@ -104,6 +105,10 @@ def convert_numpy_to_python(value): feature_value = tracks.get_node_attr(node_id, feature_name) if isinstance(feature_value, list | tuple): features.extend(feature_value) + elif feature_name == "pos" and isinstance( + feature_value, pl.series.Series + ): + features.extend(feature_value.to_list()) else: features.append(feature_value) row = [node_id, parent_id, *features] diff --git a/src/funtracks/import_export/csv/_import.py b/src/funtracks/import_export/csv/_import.py index 85dfb45b..f4025bf2 100644 --- a/src/funtracks/import_export/csv/_import.py +++ b/src/funtracks/import_export/csv/_import.py @@ -268,9 +268,12 @@ def tracks_from_df( # Auto-infer name mapping from DataFrame columns builder.prepare(df) + # instead of a separate segmentation array + return 
builder.build( df, segmentation, scale=scale, node_features=node_features, + node_name_map=builder.node_name_map, ) diff --git a/src/funtracks/import_export/geff/_export.py b/src/funtracks/import_export/geff/_export.py index 2edff0a3..2d7644b3 100644 --- a/src/funtracks/import_export/geff/_export.py +++ b/src/funtracks/import_export/geff/_export.py @@ -6,10 +6,10 @@ Literal, ) -import geff import geff_spec -import networkx as nx import numpy as np +import polars as pl +import tracksdata as td from geff_spec import GeffMetadata from funtracks.utils import remove_tilde, setup_zarr_array, setup_zarr_group @@ -29,7 +29,7 @@ def export_to_geff( node_ids: set[int] | None = None, zarr_format: Literal[2, 3] = 2, ): - """Export the Tracks nxgraph to geff. + """Export the Tracks graph to geff. Args: tracks (Tracks): Tracks object containing a graph to save. @@ -71,11 +71,23 @@ def export_to_geff( if tracks.scale is None: tracks.scale = (1.0,) * tracks.ndim + # Create axes metadata + axes = [] + for name, axis_type, scale in zip(axis_names, axis_types, tracks.scale, strict=True): + axes.append( + { + "name": name, + "type": axis_type, + "scale": scale, + } + ) + metadata = GeffMetadata( geff_version=geff_spec.__version__, - directed=isinstance(graph, nx.DiGraph), + directed=True, node_props_metadata={}, edge_props_metadata={}, + axes=axes, ) # Save segmentation if present @@ -135,23 +147,14 @@ def export_to_geff( # Filter the graph if node_ids is provided if node_ids is not None: - graph = graph.subgraph(nodes_to_keep).copy() + graph = graph.filter(node_ids=nodes_to_keep).subgraph() # Save the graph in a 'tracks' folder tracks_path = directory / "tracks" - geff.write( - graph=graph, - store=tracks_path, - metadata=metadata, - axis_names=axis_names, - axis_types=axis_types, - axis_scales=tracks.scale, - overwrite=overwrite, - zarr_format=zarr_format, - ) + graph.to_geff(geff_store=tracks_path, geff_metadata=metadata, zarr_format=zarr_format) -def split_position_attr(tracks: Tracks) -> tuple[nx.DiGraph, list[str] | None]: +def split_position_attr(tracks: Tracks) -> tuple[td.graph.GraphView, list[str] | None]: # TODO: this exists in unsqueeze in geff somehow? """Spread the spatial coordinates to separate node attrs in order to export to geff format. @@ -161,23 +164,44 @@ def split_position_attr(tracks: Tracks) -> tuple[nx.DiGraph, list[str] | None]: converted. 
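    Example:
        A minimal usage sketch (illustrative only; assumes a Tracks object
        whose graph stores the composite "pos" attribute, as the rest of
        this module does):

            graph_with_axes, axis_names = split_position_attr(tracks)
            # axis_names is ["y", "x"] for 2D data or ["z", "y", "x"] for 3D;
            # graph_with_axes carries one float column per axis instead of "pos".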
Returns: - tuple[nx.DiGraph, list[str]]: graph with a separate positional attribute for each - coordinate, and the axis names used to store the separate attributes + tuple[td.graph.GraphView, list[str] | None]: graph with a separate positional + attribute for each coordinate, and the axis names used to store the + separate attributes """ pos_key = tracks.features.position_key if isinstance(pos_key, str): # Position is stored as a single attribute, need to split - new_graph = tracks.graph.copy() - new_keys = ["y", "x"] - if tracks.ndim == 4: - new_keys.insert(0, "z") - for _, attrs in new_graph.nodes(data=True): - pos = attrs.pop(pos_key) - for i in range(len(new_keys)): - attrs[new_keys[i]] = pos[i] - + new_graph = tracks.graph.detach() + new_graph = new_graph.filter().subgraph() + + # Register new attribute keys + new_graph.add_node_attr_key("x", default_value=0.0, dtype=pl.Float64) + new_graph.add_node_attr_key("y", default_value=0.0, dtype=pl.Float64) + + # Get all position values at once + pos_values = new_graph.node_attrs()["pos"].to_numpy() + ndim = pos_values.shape[1] + + if ndim == 2: + new_keys = ["y", "x"] + new_graph.update_node_attrs( + attrs={"x": pos_values[:, 1], "y": pos_values[:, 0]}, + node_ids=new_graph.node_ids(), + ) + elif ndim == 3: + new_keys = ["z", "y", "x"] + new_graph.add_node_attr_key("z", default_value=0.0, dtype=pl.Float64) + new_graph.update_node_attrs( + attrs={ + "x": pos_values[:, 2], + "y": pos_values[:, 1], + "z": pos_values[:, 0], + }, + node_ids=new_graph.node_ids(), + ) + new_graph.remove_node_attr_key(pos_key) return new_graph, new_keys elif pos_key is not None: # Position is already split into separate attributes diff --git a/src/funtracks/import_export/geff/_import.py b/src/funtracks/import_export/geff/_import.py index 6a7ddd16..680e451d 100644 --- a/src/funtracks/import_export/geff/_import.py +++ b/src/funtracks/import_export/geff/_import.py @@ -287,4 +287,5 @@ def import_from_geff( scale=scale, node_features=node_features, edge_features=edge_features, + node_name_map=node_name_map, ) diff --git a/src/funtracks/user_actions/user_add_edge.py b/src/funtracks/user_actions/user_add_edge.py index 701fcc04..c830fbe7 100644 --- a/src/funtracks/user_actions/user_add_edge.py +++ b/src/funtracks/user_actions/user_add_edge.py @@ -53,7 +53,8 @@ def __init__( forceable=True, ) else: - merge_edge = list(self.tracks.graph.in_edges(target))[0] + pred = next(iter(self.tracks.graph.predecessors(target))) + merge_edge = (pred, target) warnings.warn( f"Removing edge {merge_edge} to add new edge without merging.", stacklevel=2, diff --git a/src/funtracks/user_actions/user_add_node.py b/src/funtracks/user_actions/user_add_node.py index 82da2787..44b714a3 100644 --- a/src/funtracks/user_actions/user_add_node.py +++ b/src/funtracks/user_actions/user_add_node.py @@ -6,6 +6,7 @@ import numpy as np from funtracks.exceptions import InvalidActionError +from funtracks.utils.tracksdata_utils import pixels_to_td_mask from ..actions._base import ActionGroup from ..actions.add_delete_edge import AddEdge, DeleteEdge @@ -91,7 +92,7 @@ def __init__( pred, succ = self.tracks.get_track_neighbors(track_id, time) # check if you are adding a node to a track that divided previously - if pred is not None and self.tracks.graph.out_degree(pred) == 2: + if pred is not None and self.tracks.graph.out_degree(int(pred)) == 2: if not force: raise InvalidActionError( "Cannot add node here - upstream division event detected.", @@ -107,7 +108,8 @@ def __init__( # downstream elif succ is not None: # 
check pred of succ - pred_of_succ = next(self.tracks.graph.predecessors(succ), None) + preds = self.tracks.graph.predecessors(succ) + pred_of_succ = preds[0] if preds else None if ( pred_of_succ is not None and self.tracks.graph.out_degree(pred_of_succ) == 2 @@ -125,7 +127,8 @@ def __init__( if pred is not None and succ is not None: self.actions.append(DeleteEdge(tracks, (pred, succ))) # add predecessor and successor edges - self.actions.append(AddNode(tracks, node, attributes, pixels)) + mask = pixels_to_td_mask(pixels, self.tracks.ndim) if pixels is not None else None + self.actions.append(AddNode(tracks, node, attributes, mask)) if pred is not None: self.actions.append(AddEdge(tracks, (pred, node))) if succ is not None: diff --git a/src/funtracks/user_actions/user_delete_node.py b/src/funtracks/user_actions/user_delete_node.py index c5d5cec7..9e6284da 100644 --- a/src/funtracks/user_actions/user_delete_node.py +++ b/src/funtracks/user_actions/user_delete_node.py @@ -4,6 +4,10 @@ import numpy as np +from funtracks.utils.tracksdata_utils import ( + pixels_to_td_mask, +) + from ..actions._base import ActionGroup from ..actions.add_delete_edge import AddEdge, DeleteEdge from ..actions.add_delete_node import DeleteNode @@ -45,4 +49,9 @@ def __init__( self.actions.append(AddEdge(tracks, (predecessor, successor))) # delete node - self.actions.append(DeleteNode(tracks, node, pixels=pixels)) + mask = ( + pixels_to_td_mask(pixels, ndim=self.tracks.ndim) + if pixels is not None + else None + ) + self.actions.append(DeleteNode(tracks, node, mask=mask)) diff --git a/src/funtracks/user_actions/user_update_segmentation.py b/src/funtracks/user_actions/user_update_segmentation.py index ab35337a..73769f1f 100644 --- a/src/funtracks/user_actions/user_update_segmentation.py +++ b/src/funtracks/user_actions/user_update_segmentation.py @@ -12,6 +12,8 @@ if TYPE_CHECKING: from funtracks.data_model import SolutionTracks +from funtracks.utils.tracksdata_utils import pixels_to_td_mask + class UserUpdateSegmentation(ActionGroup): def __init__( @@ -50,12 +52,15 @@ def __init__( continue time = pixels[0][0] # check if all pixels of old_value are removed - # TODO: this assumes the segmentation is already updated, but then we can't - # recover the pixels, so we have to pass them here for undo purposes - if np.sum(self.tracks.segmentation[time] == old_value) == 0: + mask_pixels = pixels_to_td_mask(pixels, self.tracks.ndim) + mask_old_value = self.tracks.graph[old_value]["mask"] + # If pixels fully overlaps with old_value mask, delete node + if mask_pixels.intersection(mask_old_value) == mask_old_value.mask.sum(): self.actions.append(UserDeleteNode(tracks, old_value, pixels=pixels)) else: - self.actions.append(UpdateNodeSeg(tracks, old_value, pixels, added=False)) + self.actions.append( + UpdateNodeSeg(tracks, old_value, mask_pixels, added=False) + ) if new_value != 0: all_pixels = tuple( np.concatenate([pixels[dim] for pixels, _ in updated_pixels]) @@ -66,8 +71,9 @@ def __init__( ) time = all_pixels[0][0] if self.tracks.graph.has_node(new_value): + mask_pixels = pixels_to_td_mask(all_pixels, self.tracks.ndim) self.actions.append( - UpdateNodeSeg(tracks, new_value, all_pixels, added=True) + UpdateNodeSeg(tracks, new_value, mask_pixels, added=True) ) else: time_key = tracks.features.time_key diff --git a/src/funtracks/utils/__init__.py b/src/funtracks/utils/__init__.py index 5985c241..6635ac11 100644 --- a/src/funtracks/utils/__init__.py +++ b/src/funtracks/utils/__init__.py @@ -9,8 +9,10 @@ setup_zarr_array, 
setup_zarr_group, ) +from .tracksdata_utils import create_empty_graphview_graph __all__ = [ + "create_empty_graphview_graph", "detect_zarr_spec_version", "get_store_path", "is_zarr_v3", diff --git a/src/funtracks/utils/tracksdata_utils.py b/src/funtracks/utils/tracksdata_utils.py new file mode 100644 index 00000000..62010903 --- /dev/null +++ b/src/funtracks/utils/tracksdata_utils.py @@ -0,0 +1,644 @@ +import tempfile +import uuid +from collections.abc import Sequence +from typing import Any + +import networkx as nx +import numpy as np +import polars as pl +import scipy.ndimage as ndi +import tracksdata as td +from polars.testing import assert_frame_equal +from tracksdata.nodes._mask import Mask + + +def to_polars_dtype(dtype_or_value: str | Any) -> pl.DataType: + """Convert a type string or value to polars dtype. + + Args: + dtype_or_value: Either a type string ("int", "float", "str", "bool") + or a value whose type should be inferred + + Returns: + Corresponding polars dtype + + Raises: + ValueError: If the type is not supported + + Examples: + >>> to_polars_dtype("int") + Int64 + >>> to_polars_dtype(5) + Int64 + >>> to_polars_dtype(np.int64(5)) + Int64 + >>> to_polars_dtype("") # String value + String + """ + # Check if it's a known type string first + type_string_mapping = { + "str": pl.String, + "int": pl.Int64, + "float": pl.Float64, + "bool": pl.Boolean, + "datetime": pl.Datetime, + "date": pl.Date, + } + + if dtype_or_value in type_string_mapping: + return type_string_mapping[dtype_or_value] + + # If it's a string but not a type name, try as polars type name (backward compat) + if isinstance(dtype_or_value, str): + try: + return getattr(pl, dtype_or_value) + except AttributeError: + # It's a string value, not a type name - return String dtype + return pl.String + + # Otherwise, infer from the value's type + if isinstance(dtype_or_value, (bool, np.bool_)): + return pl.Boolean + elif isinstance(dtype_or_value, (int, np.integer)): + return pl.Int64 + elif isinstance(dtype_or_value, (float, np.floating)): + return pl.Float64 + else: + raise ValueError(f"Unsupported type: {type(dtype_or_value)}") + + +def create_empty_graphview_graph( + node_attributes: list[str] | None = None, + edge_attributes: list[str] | None = None, + node_default_values: list[Any] | None = None, + edge_default_values: list[Any] | None = None, + database: str | None = None, + position_attrs: list[str] | None = None, + ndim: int = 3, +) -> td.graph.GraphView: + """ + Create an empty tracksdata GraphView with standard node and edge attributes. + Parameters + ---------- + node_attributes : list[str] | None + List of node attribute names to include. (providing time attribute not necessary) + edge_attributes : list[str] | None + List of edge attribute names to include. + node_default_values : list[Any] | None + List of default values for each node attribute. + Must match length of node_attributes. + edge_default_values : list[Any] | None + List of default values for each edge attribute. + Must match length of edge_attributes. + database : str | None + Path to the SQLite database file. If None, creates a unique temporary file. + Use ':memory:' for in-memory database (may cause issues with pickling in pytest). + position_attrs : list[str] | None + List of position attribute names, e.g. ['pos'] or ['x', 'y', 'z']. + Defaults to ['pos'] if None. + ndim : int + Number of dimensions including time, so 2D+T dataset has ndim = 3. + Defaults to 3 (2D+time). 
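    Example
    -------
    A minimal sketch of typical use; the attribute names below are
    illustrative only, not required by the function::

        graph = create_empty_graphview_graph(
            node_attributes=["pos", "track_id"],
            edge_attributes=["iou"],
            ndim=3,
        )
        assert graph.num_nodes() == 0
        assert "track_id" in graph.node_attr_keys()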
+ + Returns + ------- + td.graph.GraphView + An empty tracksdata GraphView with standard node and edge attributes. + """ + if position_attrs is None: + position_attrs = ["pos"] + + # Generate unique database path if not specified + if database is None: + temp_dir = tempfile.gettempdir() + unique_id = uuid.uuid4().hex[:8] + database = f"{temp_dir}/funtracks_test_{unique_id}.db" + + if node_default_values is not None: + assert len(node_default_values) == len(node_attributes or []), ( + "Length of node_default_values must match length of node_attributes" + ) + else: + node_default_values = [0.0] * len(node_attributes or []) + + if edge_default_values is not None: + assert len(edge_default_values) == len(edge_attributes or []), ( + "Length of edge_default_values must match length of edge_attributes" + ) + else: + edge_default_values = [0.0] * len(edge_attributes or []) + + kwargs = { + "drivername": "sqlite", + "database": database, + "overwrite": True, + } + graph_sql = td.graph.SQLGraph(**kwargs) + + # Add standard node and edge attributes + if "pos" in (node_attributes or []) or any( + attr in (node_attributes or []) for attr in position_attrs + ): + if "pos" in position_attrs: + graph_sql.add_node_attr_key("pos", pl.Array(pl.Float64, ndim - 1)) + else: + if "x" in position_attrs: + graph_sql.add_node_attr_key("x", default_value=0.0, dtype=pl.Float64) + if "y" in position_attrs: + graph_sql.add_node_attr_key("y", default_value=0.0, dtype=pl.Float64) + if "z" in position_attrs: + graph_sql.add_node_attr_key("z", default_value=0.0, dtype=pl.Float64) + if "mask" in (node_attributes or []): + graph_sql.add_node_attr_key("mask", pl.Object) + if "bbox" in (node_attributes or []): + graph_sql.add_node_attr_key("bbox", pl.Array(pl.Int64, 2 * (ndim - 1))) + if "track_id" in (node_attributes or []): + graph_sql.add_node_attr_key("track_id", default_value=-1, dtype=pl.Int64) + + for attr in node_attributes or []: + if attr not in graph_sql.node_attr_keys(): + default_value = node_default_values[(node_attributes or []).index(attr)] + graph_sql.add_node_attr_key( + attr, + default_value=default_value + if not isinstance(default_value, np.ndarray) + else None, + dtype=to_polars_dtype(default_value) + if not isinstance(default_value, np.ndarray) + else pl.Array(pl.Float64, len(default_value)), # type: ignore + ) + + for attr in edge_attributes or []: + if attr not in graph_sql.edge_attr_keys(): + default_value = edge_default_values[(edge_attributes or []).index(attr)] + graph_sql.add_edge_attr_key( + attr, + default_value=default_value, + dtype=to_polars_dtype(default_value), + ) + graph_sql.add_node_attr_key( + td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1, dtype=pl.Int64 + ) + graph_sql.add_edge_attr_key( + td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1, dtype=pl.Int64 + ) + + graph_td_sub = graph_sql.filter( + td.NodeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + td.EdgeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + ).subgraph() + + return graph_td_sub + + +def assert_node_attrs_equal_with_masks( + object1, object2, check_column_order: bool = False, check_row_order: bool = False +): + """ + Fully compare the content of two graphs (node attributes and Masks) + """ + + if isinstance(object1, td.graph.GraphView) and ( + isinstance(object2, td.graph.GraphView) + ): + node_attrs1 = object1.node_attrs() + node_attrs2 = object2.node_attrs() + elif isinstance(object1, pl.DataFrame) and isinstance(object2, pl.DataFrame): + node_attrs1 = object1 + node_attrs2 = object2 + else: + raise ValueError( + "Both objects must 
be either tracksdata graphs or polars DataFrames" + ) + + # Check all fields, except masks + assert_frame_equal( + node_attrs1.drop("mask"), + node_attrs2.drop("mask"), + check_column_order=check_column_order, + check_row_order=check_row_order, + check_dtypes=False, + ) + # Check masks separately + for node in node_attrs1["node_id"]: + mask1 = node_attrs1.filter(pl.col("node_id") == node)["mask"].item() + mask2 = node_attrs2.filter(pl.col("node_id") == node)["mask"].item() + assert np.array_equal(mask1.bbox, mask2.bbox) + assert np.array_equal(mask1.mask, mask2.mask) + + +def pixels_to_td_mask( + pix: tuple[np.ndarray, ...], + ndim: int, + scale: list[float] | None = None, + include_area: bool = False, +) -> Mask | tuple[Mask, float]: + """ + Convert pixel coordinates to tracksdata mask format. + + Args: + pix: Pixel coordinates for 1 node! + ndim: Number of dimensions (2D or 3D). + scale: Scale factors for each dimension, used for area calculation + include_area: Whether to compute and return the area. + + Returns: + Mask | tuple[Mask, float]: A tuple containing the + tracksdata mask and the area if include_area is True. + Otherwise, just the tracksdata mask. + """ + + if include_area and scale is None: + raise ValueError("Scale must be provided when include_area is True.") + + spatial_dims = ndim - 1 # Handle both 2D and 3D + + # Calculate position and bounding box more efficiently + bbox = np.zeros(2 * spatial_dims, dtype=int) + + # Calculate bbox and shape in one pass + for dim in range(spatial_dims): + pix_dim = dim + 1 + min_val = np.min(pix[pix_dim]) + max_val = np.max(pix[pix_dim]) + bbox[dim] = min_val + bbox[dim + spatial_dims] = max_val + 1 + + # Calculate mask shape from bbox + mask_shape = bbox[spatial_dims:] - bbox[:spatial_dims] + + # Convert coordinates to mask-local coordinates + local_coords = [pix[dim + 1] - bbox[dim] for dim in range(spatial_dims)] + mask_array = np.zeros(mask_shape, dtype=bool) + mask_array[tuple(local_coords)] = True + mask = Mask(mask_array, bbox=bbox) + + if include_area: + area = np.sum(mask_array) + if scale is not None: + area *= np.prod(scale[1:]) + return mask, area + else: + return mask + + +def td_mask_to_pixels(mask: Mask, time: int, ndim: int) -> tuple[np.ndarray, ...]: + """ + Convert tracksdata mask to pixel coordinates. + + This is the inverse of pixels_to_td_mask. + + Args: + mask: Tracksdata Mask object with .mask (boolean array) and .bbox attributes + time: Time point for this mask + ndim: Number of dimensions (3 for 2D+time, 4 for 3D+time) + + Returns: + Tuple of numpy arrays: (time_array, *spatial_coords) + For 2D: (t, y, x) where each is a 1D array of pixel coordinates + For 3D: (t, z, y, x) where each is a 1D array of pixel coordinates + + Example: + >>> mask = Mask(np.array([[True, False], [False, True]]), + ... bbox=np.array([10, 20, 12, 22])) + >>> pixels = td_mask_to_pixels(mask, time=5, ndim=3) + >>> # Returns: (array([5, 5]), array([10, 11]), array([20, 21])) + """ + spatial_dims = ndim - 1 + + # Find all True pixels in the local mask + local_coords = np.nonzero(mask.mask) + + # Convert local coordinates to global coordinates by adding bbox offset + global_coords = [] + for dim in range(spatial_dims): + global_coords.append(local_coords[dim] + mask.bbox[dim]) + + # Create time array with same length as spatial coordinates + num_pixels = len(local_coords[0]) + time_array = np.full(num_pixels, time, dtype=int) + + # Return as tuple: (time, spatial_dim_0, spatial_dim_1, ...) 
+ return (time_array, *global_coords) + + +def segmentation_to_masks( + segmentation: np.ndarray, +) -> list[tuple[int, int, Mask]]: + """Convert a segmentation array to individual masks and bounding boxes. + + Parameters + ---------- + segmentation : np.ndarray + Segmentation array of shape (T, Z, Y, X) or (T, Y, X) + Each unique value represents a different segment/object. + + Returns + ------- + list[tuple[int, int, Mask]] + List of tuples, one per segment, containing: + - label (int): original label ID + - time (int): time point + - mask (Mask): tracksdata Mask object with boolean mask and bbox + """ + results = [] + + # Process each time point + for t in range(segmentation.shape[0]): + time_slice = segmentation[t] + + # Get unique labels + labels = np.unique(time_slice) + labels = labels[labels != 0] + + # Find objects for each label + for label in labels: + # Create binary mask for this label + binary_mask = time_slice == label + + # Find bounding box using scipy (same as Ultrack uses) + slices = ndi.find_objects(binary_mask.astype(int))[0] + + if slices is None: + continue + + # Extract the cropped mask and ensure C-contiguous for blosc2 serialization + cropped_mask = np.ascontiguousarray(binary_mask[slices]) + + # Convert slices to bbox format (min_*, max_*) + ndim = len(slices) + bbox = np.array( + [slices[i].start for i in range(ndim)] # min coordinates + + [slices[i].stop for i in range(ndim)] # max coordinates + ) + + # Create Mask object + mask = Mask(cropped_mask, bbox=bbox) + + results.append((int(label), t, mask)) + + return results + + +def add_masks_and_bboxes_to_graph( + graph: td.graph.GraphView, + segmentation: np.ndarray, +) -> td.graph.GraphView: + """Add mask and bbox attributes to graph nodes from segmentation. + + Parameters + ---------- + graph : td.graph.GraphView + Graph to add attributes to + segmentation : np.ndarray + Segmentation array of shape (T, Z, Y, X) or (T, Y, X) + + Returns + ------- + td.graph.GraphView + Graph with 'mask' and 'bbox' attributes added to nodes + """ + + # Convert segmentation to masks and bounding boxes + list_of_masks = segmentation_to_masks(segmentation) + + # Add 'mask' and 'bbox' attributes to graph nodes + graph.add_node_attr_key("mask", pl.Object) + graph.add_node_attr_key("bbox", pl.Array(pl.Int64, 2 * (segmentation.ndim - 1))) + + for label, _, mask in list_of_masks: + if graph.has_node(label): + graph[label]["mask"] = [mask] + graph[label]["bbox"] = [mask.bbox] + + return graph + + +def td_get_single_attr_from_edge(graph, edge: tuple[int, int], attrs: Sequence[str]): + """Get a single attribute from a edge in a tracksdata graph.""" + + # TODO Teun: later opdate to: graph.edges[edge_id][attr] (after td update) + item = graph.filter(node_ids=[edge[0], edge[1]]).edge_attrs()[attrs].item() + return item + + +def td_relabel_nodes(graph, mapping: dict[int, int]) -> td.graph.SQLGraph: + """Relabel nodes in a tracksdata graph according to a mapping. 
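    A minimal usage sketch (node IDs and mapping values are illustrative;
    the mapping must cover every node in the graph)::

        relabeled = td_relabel_nodes(graph, {1: 101, 2: 102})
        assert set(relabeled.node_ids()) == {101, 102}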
+ + Args: + graph: A tracksdata graph + mapping: Dictionary mapping old node IDs to new node IDs + + Returns: + A new SQLGraph with relabeled nodes + """ + + # For IndexedRXGraph or SQLGraph + old_graph = graph + + kwargs = { + "drivername": "sqlite", + "database": ":memory:", + "overwrite": True, + } + new_graph = td.graph.SQLGraph(**kwargs) + + # Copy attribute key registrations with defaults and dtypes + node_schemas = graph._node_attr_schemas() + for key, schema in node_schemas.items(): + if key not in ["node_id", "t"]: # Skip system columns + new_graph.add_node_attr_key( + key, default_value=schema.default_value, dtype=schema.dtype + ) + + edge_schemas = graph._edge_attr_schemas() + for key, schema in edge_schemas.items(): + if key not in ["edge_id", "source_id", "target_id"]: # Skip system columns + new_graph.add_edge_attr_key( + key, default_value=schema.default_value, dtype=schema.dtype + ) + + # Get all data + old_nodes = old_graph.node_attrs() + old_edges = old_graph.edge_attrs() + + # Use the provided mapping + id_mapping = mapping + + # Add nodes with new IDs + for row in old_nodes.iter_rows(named=True): + old_id = row.pop("node_id") + new_id = id_mapping[old_id] + new_graph.add_node(row, index=new_id) + + # Add edges with remapped IDs + for row in old_edges.iter_rows(named=True): + source_id = id_mapping[row["source_id"]] + target_id = id_mapping[row["target_id"]] + attrs = { + k: v for k, v in row.items() if k not in ["edge_id", "source_id", "target_id"] + } + new_graph.add_edge(source_id, target_id, attrs) + + new_graph_sub = new_graph.filter( + td.NodeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + td.EdgeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + ).subgraph() + return new_graph_sub + + +def get_node_attr_defaults(graph) -> dict[str, Any]: + """Get node attribute keys and their default values from SQLGraph.""" + # Unwrap GraphView if needed + actual_graph = graph._root if hasattr(graph, "_root") else graph + + defaults = {} + for col in actual_graph.Node.__table__.columns: + col_name = col.name + # Skip system columns + if col_name in ["node_id", "t"]: + continue + + # Extract default value from SQLAlchemy column + default_val = None + if ( + hasattr(col, "default") + and col.default is not None + and hasattr(col.default, "arg") + ): + default_val = col.default.arg + + defaults[col_name] = default_val + return defaults + + +def get_edge_attr_defaults(graph) -> dict[str, Any]: + """Get edge attribute keys and their default values from SQLGraph.""" + # Unwrap GraphView if needed + actual_graph = graph._root if hasattr(graph, "_root") else graph + + defaults = {} + for col in actual_graph.Edge.__table__.columns: + col_name = col.name + # Skip system columns + if col_name in ["edge_id", "source_id", "target_id"]: + continue + + # Extract default value + default_val = None + if ( + hasattr(col, "default") + and col.default is not None + and hasattr(col.default, "arg") + ): + default_val = col.default.arg + + defaults[col_name] = default_val + return defaults + + +def convert_graph_nx_to_td(graph_nx: nx.DiGraph) -> td.graph.GraphView: + """Convert a NetworkX DiGraph to a tracksdata SQLGraph. + + Args: + graph_nx: The NetworkX DiGraph to convert. + + Returns: + A tracksdata SQLGraph representing the same graph. 
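    Example:
        A minimal sketch; the node and edge attributes below are illustrative
        test data, not a required schema:

            import networkx as nx

            nx_graph = nx.DiGraph()
            nx_graph.add_node(1, t=0, pos=[10.0, 20.0])
            nx_graph.add_node(2, t=1, pos=[11.0, 21.0])
            nx_graph.add_edge(1, 2, iou=0.8)

            td_graph = convert_graph_nx_to_td(nx_graph)
            assert set(td_graph.node_ids()) == {1, 2}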
+ """ + + # Initialize an empty tracksdata SQLGraph + kwargs = { + "drivername": "sqlite", + "database": ":memory:", + "overwrite": True, + } + graph_td = td.graph.SQLGraph(**kwargs) + + # Get all nodes and edges with attributes + all_nodes = list(graph_nx.nodes(data=True)) + all_edges = list(graph_nx.edges(data=True)) + + # Add node attribute keys to tracksdata graph + for attr, value in all_nodes[0][1].items(): + if attr not in graph_td.node_attr_keys(): + default_value: Any # mypy necessities + dtype: pl.DataType + if isinstance(value, list): + # Array type - always use Float64 for numeric arrays from NetworkX + # since NetworkX doesn't enforce type consistency across nodes + default_value = None + dtype = pl.Array(pl.Float64, len(value)) + else: + # Scalar type - always use Float64 for numeric types from NetworkX + # since NetworkX doesn't enforce type consistency across nodes + if isinstance(value, (int, float, np.integer, np.floating)): + default_value = 0.0 + dtype = pl.Float64 + else: + default_value = "" + dtype = pl.String + graph_td.add_node_attr_key(attr, default_value=default_value, dtype=dtype) + else: + if attr != "t": + raise Warning( + f"Node attribute '{attr}' already exists in " + f"tracksdata graph. Skipping addition." + ) + graph_td.add_node_attr_key( + td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1, dtype=pl.Int64 + ) + + # Add edge attribute keys to tracksdata graph + for attr, value in all_edges[0][2].items(): + if attr not in graph_td.edge_attr_keys(): + if isinstance(value, list): + # Array type - always use Float64 for numeric arrays from NetworkX + default_value = None + dtype = pl.Array(pl.Float64, len(value)) + else: + # Scalar type - always use Float64 for numeric types from NetworkX + if isinstance(value, (int, float, np.integer, np.floating)): + default_value = 0.0 + dtype = pl.Float64 + else: + default_value = "" + dtype = pl.String + graph_td.add_edge_attr_key(attr, default_value=default_value, dtype=dtype) + else: + raise Warning( + f"Edge attribute '{attr}' already exists in tracksdata graph. " + f"Skipping addition." 
+ ) + graph_td.add_edge_attr_key( + td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1, dtype=pl.Int64 + ) + + # Add node attributes + for node_id, attrs in all_nodes: + attrs_copy = dict(attrs) + # Convert lists to numpy arrays to work around tracksdata SQLGraph bug + # where Python lists with floats get truncated + for key, value in attrs_copy.items(): + if isinstance(value, list): + attrs_copy[key] = np.array(value, dtype=np.float64) + attrs_copy[td.DEFAULT_ATTR_KEYS.SOLUTION] = 1 + graph_td.add_node(attrs_copy, index=node_id) + + # Add edges + for source_id, target_id, attrs in all_edges: + attrs_copy = dict(attrs) + # Convert lists to numpy arrays to work around tracksdata SQLGraph bug + for key, value in attrs_copy.items(): + if isinstance(value, list): + attrs_copy[key] = np.array(value, dtype=np.float64) + attrs_copy[td.DEFAULT_ATTR_KEYS.SOLUTION] = 1 + graph_td.add_edge(source_id, target_id, attrs_copy) + + # Create subgraph (GraphView) with only solution nodes and edges + graph_td_sub = graph_td.filter( + td.NodeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + td.EdgeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + ).subgraph() + + return graph_td_sub diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..818e0478 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,2 @@ +# This file makes the tests directory a Python package +# to support relative imports diff --git a/tests/actions/__init__.py b/tests/actions/__init__.py new file mode 100644 index 00000000..1b68a100 --- /dev/null +++ b/tests/actions/__init__.py @@ -0,0 +1 @@ +# This file makes the tests/actions directory a Python package to support relative imports diff --git a/tests/actions/test_action_history.py b/tests/actions/test_action_history.py index 7e4575d3..4365f22c 100644 --- a/tests/actions/test_action_history.py +++ b/tests/actions/test_action_history.py @@ -1,17 +1,20 @@ -import networkx as nx - from funtracks.actions import AddNode from funtracks.actions.action_history import ActionHistory from funtracks.data_model import SolutionTracks +from funtracks.utils.tracksdata_utils import create_empty_graphview_graph # https://github.com/zaboople/klonk/blob/master/TheGURQ.md def test_action_history(): history = ActionHistory() - tracks = SolutionTracks(nx.DiGraph(), ndim=3, tracklet_attr="track_id") + empty_graph = create_empty_graphview_graph( + node_attributes=["track_id", "pos"], + edge_attributes=[], + ) + tracks = SolutionTracks(empty_graph, ndim=3, tracklet_attr="track_id", time_attr="t") pos = [0, 1] - action1 = AddNode(tracks, node=0, attributes={"time": 0, "pos": pos, "track_id": 1}) + action1 = AddNode(tracks, node=0, attributes={"t": 0, "pos": pos, "track_id": 1}) # empty history has no undo or redo assert not history.undo() @@ -21,7 +24,7 @@ def test_action_history(): history.add_new_action(action1) # undo the action assert history.undo() - assert tracks.graph.number_of_nodes() == 0 + assert tracks.graph.num_nodes() == 0 assert len(history.undo_stack) == 1 assert len(history.redo_stack) == 1 assert history._undo_pointer == -1 @@ -31,7 +34,7 @@ def test_action_history(): # redo the action assert history.redo() - assert tracks.graph.number_of_nodes() == 1 + assert tracks.graph.num_nodes() == 1 assert len(history.undo_stack) == 1 assert len(history.redo_stack) == 0 assert history._undo_pointer == 0 @@ -41,9 +44,9 @@ def test_action_history(): # undo and then add new action assert history.undo() - action2 = AddNode(tracks, node=10, attributes={"time": 10, "pos": pos, "track_id": 2}) + action2 = 
AddNode(tracks, node=10, attributes={"t": 10, "pos": pos, "track_id": 2}) history.add_new_action(action2) - assert tracks.graph.number_of_nodes() == 1 + assert tracks.graph.num_nodes() == 1 # there are 3 things on the stack: action1, action1's inverse, and action 2 assert len(history.undo_stack) == 3 assert len(history.redo_stack) == 0 @@ -52,7 +55,7 @@ def test_action_history(): # undo back to after action 1 assert history.undo() assert history.undo() - assert tracks.graph.number_of_nodes() == 1 + assert tracks.graph.num_nodes() == 1 assert len(history.undo_stack) == 3 assert len(history.redo_stack) == 2 diff --git a/tests/actions/test_add_delete_edge.py b/tests/actions/test_add_delete_edge.py index c3bb9131..06296636 100644 --- a/tests/actions/test_add_delete_edge.py +++ b/tests/actions/test_add_delete_edge.py @@ -1,14 +1,16 @@ -import copy - -import networkx as nx +import numpy as np import pytest from numpy.testing import assert_array_almost_equal +from polars.testing import assert_frame_equal from funtracks.actions import ( ActionGroup, AddEdge, DeleteEdge, ) +from funtracks.utils.tracksdata_utils import ( + td_get_single_attr_from_edge, +) iou_key = "iou" @@ -17,39 +19,49 @@ @pytest.mark.parametrize("with_seg", [True, False]) def test_add_delete_edges(get_tracks, ndim, with_seg): tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) - reference_graph = copy.deepcopy(tracks.graph) - reference_seg = copy.deepcopy(tracks.segmentation) + reference_graph = tracks.graph + reference_seg = np.asarray(tracks.segmentation).copy() # Create an empty tracks with just nodes (no edges) - node_graph = nx.create_empty_copy(tracks.graph, with_data=True) - tracks.graph = node_graph + for edge in tracks.graph.edge_list(): + tracks.graph.remove_edge(*edge) edges = [(1, 2), (1, 3), (3, 4), (4, 5)] action = ActionGroup(tracks=tracks, actions=[AddEdge(tracks, edge) for edge in edges]) + + with pytest.raises(ValueError, match="Edge .* already exists in the graph"): + AddEdge(tracks, (1, 2)) + # TODO: What if adding an edge that already exists? # TODO: test all the edge cases, invalid operations, etc. 
for all actions - assert set(tracks.graph.nodes()) == set(reference_graph.nodes()) + assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids()) + assert_frame_equal( + tracks.graph.edge_attrs(), + reference_graph.edge_attrs(), + check_row_order=False, + check_column_order=False, + ) if with_seg: - for edge in tracks.graph.edges(): - assert tracks.graph.edges[edge][iou_key] == pytest.approx( - reference_graph.edges[edge][iou_key], abs=0.01 - ) assert_array_almost_equal(tracks.segmentation, reference_seg) inverse = action.inverse() - assert set(tracks.graph.edges()) == set() + + assert set(tracks.graph.edge_ids()) == set() if tracks.segmentation is not None: assert_array_almost_equal(tracks.segmentation, reference_seg) inverse.inverse() - assert set(tracks.graph.nodes()) == set(reference_graph.nodes()) - assert set(tracks.graph.edges()) == set(reference_graph.edges()) + assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids()) + assert set(tracks.graph.edge_ids()) == set(reference_graph.edge_ids()) + assert sorted(tracks.graph.edge_list()) == sorted(reference_graph.edge_list()) + assert_frame_equal( + tracks.graph.edge_attrs(), + reference_graph.edge_attrs(), + check_row_order=False, + check_column_order=False, + ) if with_seg: - for edge in tracks.graph.edges(): - assert tracks.graph.edges[edge][iou_key] == pytest.approx( - reference_graph.edges[edge][iou_key], abs=0.01 - ) assert_array_almost_equal(tracks.segmentation, reference_seg) @@ -103,7 +115,7 @@ def test_custom_edge_attributes_preserved(get_tracks, ndim, with_seg): ), } for key, feature in custom_features.items(): - tracks.features[key] = feature + tracks.add_feature(key, feature) # Define custom edge attributes custom_attrs = { @@ -113,13 +125,13 @@ def test_custom_edge_attributes_preserved(get_tracks, ndim, with_seg): } # Add an edge with custom attributes - edge = (1, 2) + edge = (1, 5) action = AddEdge(tracks, edge, attributes=custom_attrs) # Verify all attributes are present after adding assert tracks.graph.has_edge(*edge) for key, value in custom_attrs.items(): - assert tracks.graph.edges[edge][key] == value, ( + assert td_get_single_attr_from_edge(tracks.graph, edge, key) == value, ( f"Attribute {key} not preserved after add" ) @@ -133,6 +145,6 @@ def test_custom_edge_attributes_preserved(get_tracks, ndim, with_seg): # Verify all custom attributes are still present after re-adding for key, value in custom_attrs.items(): - assert tracks.graph.edges[edge][key] == value, ( + assert td_get_single_attr_from_edge(tracks.graph, edge, key) == value, ( f"Attribute {key} not preserved after delete/re-add cycle" ) diff --git a/tests/actions/test_add_delete_nodes.py b/tests/actions/test_add_delete_nodes.py index fe47ce01..8fcb87b6 100644 --- a/tests/actions/test_add_delete_nodes.py +++ b/tests/actions/test_add_delete_nodes.py @@ -1,14 +1,20 @@ -import copy - -import networkx as nx import numpy as np import pytest -from numpy.testing import assert_array_almost_equal +from numpy.testing import assert_array_almost_equal, assert_array_equal +from polars.testing import assert_frame_equal +from tracksdata.array import GraphArrayView from funtracks.actions import ( ActionGroup, AddNode, ) +from funtracks.utils.tracksdata_utils import ( + assert_node_attrs_equal_with_masks, + create_empty_graphview_graph, + pixels_to_td_mask, +) + +from ..conftest import make_2d_disk_mask, make_3d_sphere_mask @pytest.mark.parametrize("ndim", [3, 4]) @@ -17,45 +23,99 @@ def test_add_delete_nodes(get_tracks, ndim, with_seg): # Get a 
tracks instance tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) reference_graph = tracks.graph - reference_seg = copy.deepcopy(tracks.segmentation) + reference_seg = np.asarray(tracks.segmentation).copy() if with_seg else None # Start with an empty Tracks - empty_graph = nx.DiGraph() + node_attributes = [ + tracks.features.time_key, + tracks.features.tracklet_key, + tracks.features.position_key, + ] + edge_attributes = ["iou"] if with_seg else [] + empty_graph = create_empty_graphview_graph( + node_attributes=node_attributes + (["area", "bbox", "mask"] if with_seg else []), + edge_attributes=edge_attributes, + ndim=ndim, + ) empty_seg = np.zeros_like(tracks.segmentation) if with_seg else None tracks.graph = empty_graph - if with_seg: - tracks.segmentation = empty_seg + segmentation_shape = (5, 100, 100) if ndim == 3 else (5, 100, 100, 100) + tracks.segmentation = ( + GraphArrayView( + graph=tracks.graph, shape=segmentation_shape, attr_key="node_id", offset=0 + ) + if with_seg + else None + ) + + # add all the nodes from graph_2d/seg_2d + nodes = list(reference_graph.node_ids()) - nodes = list(reference_graph.nodes()) actions = [] for node in nodes: - pixels = np.nonzero(reference_seg == node) if with_seg else None - actions.append( - AddNode(tracks, node, dict(reference_graph.nodes[node]), pixels=pixels) - ) + if with_seg: + pixels = np.nonzero(reference_seg == node) + mask = pixels_to_td_mask(pixels, ndim=ndim) + else: + mask = None + + attrs = {} + attrs[tracks.features.time_key] = reference_graph[node][tracks.features.time_key] + if tracks.features.position_key == "pos": + attrs[tracks.features.position_key] = reference_graph[node][ + tracks.features.position_key + ].to_list() + else: + attrs[tracks.features.position_key] = reference_graph[node][ + tracks.features.position_key + ] + attrs[tracks.features.tracklet_key] = reference_graph[node][ + tracks.features.tracklet_key + ] + if with_seg: + attrs["bbox"] = reference_graph[node]["bbox"] + attrs["mask"] = reference_graph[node]["mask"] + + actions.append(AddNode(tracks, node, attributes=attrs, mask=mask)) action = ActionGroup(tracks=tracks, actions=actions) - assert set(tracks.graph.nodes()) == set(reference_graph.nodes()) - for node, data in tracks.graph.nodes(data=True): - reference_data = reference_graph.nodes[node] - assert data == reference_data + assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids()) + data_tracks = tracks.graph.node_attrs() + data_reference = reference_graph.node_attrs() if with_seg: assert_array_almost_equal(tracks.segmentation, reference_seg) + assert_node_attrs_equal_with_masks(data_tracks, data_reference) + else: + assert_frame_equal( + data_reference, # .drop(["mask", "bbox", "area"]), + data_tracks, # .drop(["mask", "bbox", "area"]), + check_column_order=False, + check_row_order=False, + check_dtypes=False, + ) # Invert the action to delete all the nodes del_nodes = action.inverse() - assert set(tracks.graph.nodes()) == set(empty_graph.nodes()) + assert set(tracks.graph.node_ids()) == set(empty_graph.node_ids()) if with_seg: assert_array_almost_equal(tracks.segmentation, empty_seg) # Re-invert the action to add back all the nodes and their attributes del_nodes.inverse() - assert set(tracks.graph.nodes()) == set(reference_graph.nodes()) - for node, data in tracks.graph.nodes(data=True): - reference_data = copy.deepcopy(reference_graph.nodes[node]) - assert data == reference_data + assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids()) + data_tracks = 
tracks.graph.node_attrs() + data_reference = reference_graph.node_attrs() if with_seg: assert_array_almost_equal(tracks.segmentation, reference_seg) + assert_node_attrs_equal_with_masks(data_tracks, data_reference) + else: + assert_frame_equal( + data_reference, # .drop(["mask", "bbox", "area"]), + data_tracks, # .drop(["mask", "bbox", "area"]), + check_column_order=False, + check_row_order=False, + check_dtypes=False, + ) def test_add_node_missing_time(get_tracks): @@ -114,7 +174,7 @@ def test_custom_attributes_preserved(get_tracks, ndim, with_seg): ), } for key, feature in custom_features.items(): - tracks.features[key] = feature + tracks.add_feature(key, feature) # Define attributes including custom ones custom_attrs = { @@ -129,41 +189,52 @@ def test_custom_attributes_preserved(get_tracks, ndim, with_seg): # Create segmentation if needed if with_seg: - from conftest import sphere - from skimage.draw import disk - if ndim == 3: - rr, cc = disk(center=(50, 50), radius=5, shape=(100, 100)) - pixels = (np.array([2]), rr, cc) + # Create 2D mask centered at (50, 50) with radius 5 + mask = make_2d_disk_mask(center=(50, 50), radius=5) else: - mask = sphere(center=(50, 50, 50), radius=5, shape=(100, 100, 100)) # Create proper 4D pixel coordinates (t, z, y, x) - pixels = (np.array([2]), *np.nonzero(mask)) + mask = make_3d_sphere_mask(center=(50, 50, 50), radius=5) + custom_attrs["mask"] = mask + custom_attrs["bbox"] = mask.bbox custom_attrs.pop("pos") # pos will be computed from segmentation else: - pixels = None + mask = None # Add a node with custom attributes node_id = 100 - action = AddNode(tracks, node_id, custom_attrs, pixels=pixels) - + action = AddNode(tracks, node_id, custom_attrs.copy(), mask=mask) # Verify all attributes are present after adding assert tracks.graph.has_node(node_id) for key, value in custom_attrs.items(): - assert tracks.graph.nodes[node_id][key] == value, ( - f"Attribute {key} not preserved after add" - ) + if key == "pos": + assert_array_almost_equal(tracks.graph[node_id][key], np.array(value)) + elif key == "mask": + continue + elif key == "bbox": + assert_array_equal(np.asarray(tracks.graph[node_id][key]), value) + else: + assert tracks.graph[node_id][key] == value, ( + f"Attribute {key} not preserved after add" + ) # Delete the node delete_action = action.inverse() - assert not tracks.graph.has_node(node_id) + assert node_id not in tracks.graph.node_ids() # Re-add the node by inverting the delete delete_action.inverse() - assert tracks.graph.has_node(node_id) + assert node_id in tracks.graph.node_ids() # Verify all custom attributes are still present after re-adding for key, value in custom_attrs.items(): - assert tracks.graph.nodes[node_id][key] == value, ( - f"Attribute {key} not preserved after delete/re-add cycle" - ) + if key == "pos": + assert_array_almost_equal(tracks.graph[node_id][key], np.array(value)) + elif key == "mask": + continue + elif key == "bbox": + assert_array_equal(np.asarray(tracks.graph[node_id][key]), value) + else: + assert tracks.graph[node_id][key] == value, ( + f"Attribute {key} not preserved after delete/re-add cycle" + ) diff --git a/tests/actions/test_update_node_attrs.py b/tests/actions/test_update_node_attrs.py index b3ead7ed..4f2dfb30 100644 --- a/tests/actions/test_update_node_attrs.py +++ b/tests/actions/test_update_node_attrs.py @@ -3,19 +3,29 @@ from funtracks.actions import ( UpdateNodeAttrs, ) +from funtracks.features import Feature @pytest.mark.parametrize("ndim", [3, 4]) def test_update_node_attrs(get_tracks, ndim): 
tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) node = 1 - new_attr = {"score": 1.0} - action = UpdateNodeAttrs(tracks, node, new_attr) + new_feature = Feature( + feature_type="node", + value_type="float", + num_values=1, + display_name="Score", + required=False, + default_value=None, + ) + tracks.add_feature("score", new_feature) + + action = UpdateNodeAttrs(tracks, node, {"score": 1.0}) assert tracks.get_node_attr(node, "score") == 1.0 inverse = action.inverse() - assert tracks.get_node_attr(node, "score") is None + assert tracks.get_node_attr(node, "score") == -1.0 inverse.inverse() assert tracks.get_node_attr(node, "score") == 1.0 diff --git a/tests/actions/test_update_node_segs.py b/tests/actions/test_update_node_segs.py index 97cc1629..a7a8bf3f 100644 --- a/tests/actions/test_update_node_segs.py +++ b/tests/actions/test_update_node_segs.py @@ -1,49 +1,66 @@ -import copy - import numpy as np import pytest from numpy.testing import assert_array_almost_equal +from polars.testing import assert_series_equal -from funtracks.actions import ( - UpdateNodeSeg, -) +from funtracks.actions import UpdateNodeSeg +from funtracks.utils.tracksdata_utils import pixels_to_td_mask @pytest.mark.parametrize("ndim", [3, 4]) def test_update_node_segs(get_tracks, ndim): # Get tracks with segmentation tracks = get_tracks(ndim=ndim, with_seg=True, is_solution=True) - reference_graph = copy.deepcopy(tracks.graph) + reference_graph = tracks.graph.detach().filter().subgraph() + + node = 1 + time = tracks.get_time(node) + + # Populate the cache by accessing segmentation at the node's time + # This ensures _update_segmentation_cache will test the cache invalidation logic + _ = np.asarray(tracks.segmentation[time]) + + # Verify cache is populated + assert time in tracks.segmentation._cache._store - original_seg = tracks.segmentation.copy() - original_area = tracks.graph.nodes[1]["area"] - original_pos = tracks.graph.nodes[1]["pos"] + original_seg = np.asarray(tracks.segmentation).copy() + original_area = tracks.graph[1]["area"] + original_pos = tracks.graph[1]["pos"] # Add a couple pixels to the first node - new_seg = tracks.segmentation.copy() + new_seg = np.asarray(tracks.segmentation).copy() if ndim == 3: - new_seg[0][0][0] = 1 # 2D spatial + new_seg[time][0][0] = node # Use node time and node ID else: - new_seg[0][0][0][0] = 1 # 3D spatial - node = 1 + new_seg[time][0][0][0] = node # Use node time and node ID pixels = np.nonzero(original_seg != new_seg) - action = UpdateNodeSeg(tracks, node, pixels=pixels, added=True) + mask = pixels_to_td_mask(pixels, ndim=ndim) + + action = UpdateNodeSeg(tracks, node, mask=mask, added=True) - assert set(tracks.graph.nodes()) == set(reference_graph.nodes()) - assert tracks.graph.nodes[1]["area"] == original_area + 1 - assert tracks.graph.nodes[1]["pos"] != original_pos + assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids()) + assert tracks.graph[1]["area"] == original_area + 1 + assert not np.allclose(tracks.graph[1]["pos"], original_pos) assert_array_almost_equal(tracks.segmentation, new_seg) + # Re-populate cache for inverse action test + _ = np.asarray(tracks.segmentation[time]) + inverse = action.inverse() - assert set(tracks.graph.nodes()) == set(reference_graph.nodes()) - for node, data in tracks.graph.nodes(data=True): - assert data == reference_graph.nodes[node] + assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids()) + assert_series_equal( + reference_graph[1]["pos"], + tracks.graph[1]["pos"], + ) 
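    # After undoing the segmentation edit, the materialized segmentation
    # should match the original array again (checked below).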
assert_array_almost_equal(tracks.segmentation, original_seg) + # Re-populate cache for second inverse test + _ = np.asarray(tracks.segmentation[time]) + inverse.inverse() - assert set(tracks.graph.nodes()) == set(reference_graph.nodes()) - assert tracks.graph.nodes[1]["area"] == original_area + 1 - assert tracks.graph.nodes[1]["pos"] != original_pos + assert set(tracks.graph.node_ids()) == set(reference_graph.node_ids()) + assert tracks.graph[1]["area"] == original_area + 1 + assert not np.allclose(tracks.graph[1]["pos"], original_pos) assert_array_almost_equal(tracks.segmentation, new_seg) diff --git a/tests/annotators/__init__.py b/tests/annotators/__init__.py new file mode 100644 index 00000000..cb930077 --- /dev/null +++ b/tests/annotators/__init__.py @@ -0,0 +1,2 @@ +# This file makes the tests/annotators directory a Python package +# to support relative imports diff --git a/tests/annotators/test_annotator_registry.py b/tests/annotators/test_annotator_registry.py index 46f5be6d..f60db7de 100644 --- a/tests/annotators/test_annotator_registry.py +++ b/tests/annotators/test_annotator_registry.py @@ -6,10 +6,16 @@ track_attrs = {"time_attr": "t", "tracklet_attr": "track_id"} -def test_annotator_registry_init_with_segmentation(graph_clean, segmentation_2d): +def test_annotator_registry_init_with_segmentation( + graph_2d_with_segmentation, +): """Test AnnotatorRegistry initializes regionprops and edge annotators with segmentation.""" - tracks = Tracks(graph_clean, segmentation=segmentation_2d, ndim=3, **track_attrs) + tracks = Tracks( + graph_2d_with_segmentation, + ndim=3, + **track_attrs, + ) annotator_types = [type(ann) for ann in tracks.annotators] assert RegionpropsAnnotator in annotator_types @@ -19,7 +25,7 @@ def test_annotator_registry_init_with_segmentation(graph_clean, segmentation_2d) def test_annotator_registry_init_without_segmentation(graph_2d_with_position): """Test AnnotatorRegistry doesn't create annotators without segmentation.""" - tracks = Tracks(graph_2d_with_position, segmentation=None, ndim=3, **track_attrs) + tracks = Tracks(graph_2d_with_position, ndim=3, **track_attrs) annotator_types = [type(ann) for ann in tracks.annotators] assert RegionpropsAnnotator not in annotator_types @@ -27,11 +33,15 @@ def test_annotator_registry_init_without_segmentation(graph_2d_with_position): assert TrackAnnotator not in annotator_types -def test_annotator_registry_init_solution_tracks(graph_clean, segmentation_2d): +def test_annotator_registry_init_solution_tracks( + graph_2d_with_segmentation, +): """Test AnnotatorRegistry creates all annotators for SolutionTracks with segmentation.""" tracks = SolutionTracks( - graph_clean, segmentation=segmentation_2d, ndim=3, **track_attrs + graph_2d_with_segmentation, + ndim=3, + **track_attrs, ) annotator_types = [type(ann) for ann in tracks.annotators] @@ -40,18 +50,22 @@ def test_annotator_registry_init_solution_tracks(graph_clean, segmentation_2d): assert TrackAnnotator in annotator_types -def test_enable_disable_features(graph_clean, segmentation_2d): - tracks = Tracks(graph_clean, segmentation=segmentation_2d, ndim=3, **track_attrs) +def test_enable_disable_features(graph_2d_with_segmentation): + tracks = Tracks( + graph_2d_with_segmentation, + ndim=3, + **track_attrs, + ) - nodes = list(tracks.graph.nodes()) - edges = list(tracks.graph.edges()) + nodes = list(tracks.graph.node_ids()) + edges = list(tracks.graph.edge_ids()) # Core features (time, pos, area) should be in tracks.features and computed assert "pos" in tracks.features assert 
"t" in tracks.features assert "area" in tracks.features # Core feature for backward compatibility - assert tracks.graph.nodes[nodes[0]].get("pos") is not None - assert tracks.graph.nodes[nodes[0]].get("area") is not None + assert tracks.graph[nodes[0]]["pos"] is not None + assert tracks.graph[nodes[0]]["area"] is not None # Other features should NOT be in tracks.features initially assert "iou" not in tracks.features @@ -65,9 +79,9 @@ def test_enable_disable_features(graph_clean, segmentation_2d): assert "circularity" in tracks.features # Verify values are actually computed on the graph - assert tracks.graph.nodes[nodes[0]].get("circularity") is not None + assert tracks.graph[nodes[0]]["circularity"] is not None if edges: - assert tracks.graph.edges[edges[0]].get("iou") is not None + assert None not in tracks.graph.edge_attrs()["iou"].to_list() # Disable one feature tracks.disable_features(["area"]) @@ -78,8 +92,8 @@ def test_enable_disable_features(graph_clean, segmentation_2d): assert "iou" in tracks.features assert "circularity" in tracks.features - # Values still exist on the graph (disabling doesn't erase computed values) - assert tracks.graph.nodes[nodes[0]].get("area") is not None + # Values no longer exist in the graph for tracksdata + # assert tracks.graph[1]["area"] is not None # Disable the remaining enabled features tracks.disable_features(["pos", "iou", "circularity"]) @@ -88,10 +102,12 @@ def test_enable_disable_features(graph_clean, segmentation_2d): assert "circularity" not in tracks.features -def test_get_available_features(graph_clean, segmentation_2d): +def test_get_available_features(graph_2d_with_segmentation): """Test get_available_features returns all features from all annotators.""" tracks = SolutionTracks( - graph_clean, segmentation=segmentation_2d, ndim=3, **track_attrs + graph_2d_with_segmentation, + ndim=3, + **track_attrs, ) available = tracks.get_available_features() @@ -103,25 +119,29 @@ def test_get_available_features(graph_clean, segmentation_2d): assert "track_id" in available # tracks -def test_enable_nonexistent_feature(graph_clean, segmentation_2d): +def test_enable_nonexistent_feature(graph_clean): """Test enabling a nonexistent feature raises KeyError.""" - tracks = Tracks(graph_clean, segmentation=segmentation_2d, ndim=3, **track_attrs) + tracks = Tracks(graph_clean, ndim=3, **track_attrs) with pytest.raises(KeyError, match="Features not available"): tracks.enable_features(["nonexistent"]) -def test_disable_nonexistent_feature(graph_clean, segmentation_2d): +def test_disable_nonexistent_feature(graph_clean): """Test disabling a nonexistent feature raises KeyError.""" - tracks = Tracks(graph_clean, segmentation=segmentation_2d, ndim=3, **track_attrs) + tracks = Tracks(graph_clean, ndim=3, **track_attrs) with pytest.raises(KeyError, match="Features not available"): tracks.disable_features(["nonexistent"]) -def test_compute_strict_validation(graph_clean, segmentation_2d): +def test_compute_strict_validation(graph_2d_with_segmentation): """Test that compute() strictly validates feature keys.""" - tracks = Tracks(graph_clean, segmentation=segmentation_2d, ndim=3, **track_attrs) + tracks = Tracks( + graph_2d_with_segmentation, + ndim=3, + **track_attrs, + ) # Get the RegionpropsAnnotator from the annotators rp_ann = next( diff --git a/tests/annotators/test_edge_annotator.py b/tests/annotators/test_edge_annotator.py index 2999b646..a567c43e 100644 --- a/tests/annotators/test_edge_annotator.py +++ b/tests/annotators/test_edge_annotator.py @@ -3,17 +3,24 @@ 
from funtracks.actions import UpdateNodeSeg, UpdateTrackID from funtracks.annotators import EdgeAnnotator from funtracks.data_model import SolutionTracks, Tracks +from funtracks.utils.tracksdata_utils import ( + pixels_to_td_mask, + td_get_single_attr_from_edge, +) track_attrs = {"time_attr": "t", "tracklet_attr": "track_id"} @pytest.mark.parametrize("ndim", [3, 4]) class TestEdgeAnnotator: - def test_init(self, get_graph, get_segmentation, ndim): + def test_init(self, get_graph, ndim): - # Start with clean graph, no existing features - graph = get_graph(ndim, with_features="clean") - seg = get_segmentation(ndim) - tracks = Tracks(graph, segmentation=seg, ndim=ndim, **track_attrs) + # Start with a graph that already has segmentation (mask/bbox) attributes + graph = get_graph(ndim, with_features="segmentation") + tracks = Tracks( + graph, + ndim=ndim, + **track_attrs, + ) ann = EdgeAnnotator(tracks) # Features start disabled by default assert len(ann.all_features) == 1 @@ -22,10 +29,13 @@ def test_init(self, get_graph, get_segmentation, ndim): ann.activate_features(list(ann.all_features.keys())) assert len(ann.features) == 1 - def test_compute_all(self, get_graph, get_segmentation, ndim): - graph = get_graph(ndim, with_features="clean") - seg = get_segmentation(ndim) - tracks = Tracks(graph, segmentation=seg, ndim=ndim, **track_attrs) + def test_compute_all(self, get_graph, ndim): + graph = get_graph(ndim, with_features="segmentation") + tracks = Tracks( + graph, + ndim=ndim, + **track_attrs, + ) ann = EdgeAnnotator(tracks) # Enable features ann.activate_features(list(ann.all_features.keys())) @@ -33,14 +43,16 @@ def test_compute_all(self, get_graph, get_segmentation, ndim): # Compute values ann.compute() - for edge in tracks.edges(): - for key in all_features: - assert key in tracks.graph.edges[edge] - - def test_update_all(self, get_graph, get_segmentation, ndim) -> None: - graph = get_graph(ndim, with_features="clean") - seg = get_segmentation(ndim) - tracks = Tracks(graph, segmentation=seg, ndim=ndim, **track_attrs) # type: ignore + for key in all_features: + assert key in tracks.graph.edge_attr_keys() + + def test_update_all(self, get_graph, ndim) -> None: + graph = get_graph(ndim, with_features="segmentation") + tracks = Tracks( + graph, + ndim=ndim, + **track_attrs, + ) # type: ignore # Get the EdgeAnnotator from the registry ann = next(ann for ann in tracks.annotators if isinstance(ann, EdgeAnnotator)) # Enable features through tracks (which updates the registry) @@ -56,24 +68,26 @@ def test_update_all(self, get_graph, get_segmentation, ndim) -> None: expected_iou = pytest.approx(0.0, abs=0.001) # Use UpdateNodeSeg action to modify segmentation and update edge - UpdateNodeSeg(tracks, node_id, pixels_to_remove, added=False) + mask_to_remove = pixels_to_td_mask(pixels_to_remove, ndim=ndim) + UpdateNodeSeg(tracks, node_id, mask_to_remove, added=False) assert tracks.get_edge_attr(edge_id, "iou", required=True) == expected_iou # segmentation is fully erased and you try to update node_id = 1 - pixels = tracks.get_pixels(node_id) - assert pixels is not None - with pytest.warns( - match="Cannot find label 1 in frame .*: updating edge IOU value to 0" - ): - UpdateNodeSeg(tracks, node_id, pixels, added=False) - - assert tracks.graph.edges[edge_id]["iou"] == 0 - - def test_add_remove_feature(self, get_graph, get_segmentation, ndim): - graph = get_graph(ndim, with_features="clean") - seg = get_segmentation(ndim) - tracks = Tracks(graph, segmentation=seg, ndim=ndim, **track_attrs) + mask = tracks.get_mask(node_id) + assert mask is not None + with
pytest.warns(match="Cannot find label 1 in frame .*"): + UpdateNodeSeg(tracks, node_id, mask, added=False) + + assert td_get_single_attr_from_edge(tracks.graph, edge_id, "iou") == 0 + + def test_add_remove_feature(self, get_graph, ndim): + graph = get_graph(ndim, with_features="segmentation") + tracks = Tracks( + graph, + ndim=ndim, + **track_attrs, + ) # Get the EdgeAnnotator from the registry ann = next(ann for ann in tracks.annotators if isinstance(ann, EdgeAnnotator)) # Enable features through tracks @@ -82,7 +96,6 @@ def test_add_remove_feature(self, get_graph, get_segmentation, ndim): node_id = 3 edge_id = (1, 3) to_remove_key = next(iter(ann.features)) - orig_iou = tracks.get_edge_attr(edge_id, to_remove_key, required=True) # remove the IOU from computation (tracks level) tracks.disable_features([to_remove_key]) @@ -90,19 +103,19 @@ def test_add_remove_feature(self, get_graph, get_segmentation, ndim): orig_pixels = tracks.get_pixels(node_id) assert orig_pixels is not None pixels_to_remove = tuple(orig_pixels[d][1:] for d in range(len(orig_pixels))) - tracks.set_pixels(pixels_to_remove, 0) # Compute at tracks level - this should not update the removed feature for a in tracks.annotators: if isinstance(a, EdgeAnnotator): a.compute() - # IoU was computed before removal, so value is still there - assert tracks.get_edge_attr(edge_id, to_remove_key, required=True) == orig_iou + # IoU feature was deleted, so IoU is no longer present on the graph + # assert tracks.get_edge_attr(edge_id, to_remove_key, required=True) == orig_iou # add it back in tracks.enable_features([to_remove_key]) # Use UpdateNodeSeg action to modify segmentation and update edge - UpdateNodeSeg(tracks, node_id, pixels_to_remove, added=False) + mask_to_remove = pixels_to_td_mask(pixels_to_remove, ndim=ndim) + UpdateNodeSeg(tracks, node_id, mask_to_remove, added=False) new_iou = pytest.approx(0.0, abs=0.001) # the feature is now updated assert tracks.get_edge_attr(edge_id, to_remove_key, required=True) == new_iou @@ -110,32 +123,26 @@ def test_add_remove_feature(self, get_graph, get_segmentation, ndim): def test_missing_seg(self, get_graph, ndim) -> None: """Test that EdgeAnnotator gracefully handles missing segmentation.""" graph = get_graph(ndim, with_features="clean") - tracks = Tracks(graph, segmentation=None, ndim=ndim, **track_attrs) # type: ignore + tracks = Tracks(graph, ndim=ndim, **track_attrs) # type: ignore ann = EdgeAnnotator(tracks) assert len(ann.features) == 0 # Should not raise an error, just return silently ann.compute() # No error expected - def test_ignores_irrelevant_actions(self, get_graph, get_segmentation, ndim): + def test_ignores_irrelevant_actions(self, get_graph, ndim): """Test that EdgeAnnotator ignores actions that don't affect edges.""" - graph = get_graph(ndim, with_features="clean") - seg = get_segmentation(ndim) - tracks = SolutionTracks(graph, segmentation=seg, ndim=ndim, **track_attrs) + graph = get_graph(ndim, with_features="segmentation") + tracks = SolutionTracks( + graph, + ndim=ndim, + **track_attrs, + ) tracks.enable_features(["iou", track_attrs["tracklet_attr"]]) - edge_id = (1, 3) - initial_iou = tracks.graph.edges[edge_id]["iou"] - - # Manually modify segmentation (without triggering an action) - # Remove half the pixels from node 3 (target of the edge) node_id = 3 - orig_pixels = tracks.get_pixels(node_id) - assert orig_pixels is not None - pixels_to_remove = tuple( - orig_pixels[d][: len(orig_pixels[d]) // 2] for d in range(len(orig_pixels)) - ) - 
tracks.set_pixels(pixels_to_remove, 0) + edge_id = (1, 3) + initial_iou = td_get_single_attr_from_edge(tracks.graph, edge_id, "iou") # If we recomputed IoU now, it would be different # But we won't - we'll just call UpdateTrackID on node 1 @@ -149,6 +156,6 @@ def test_ignores_irrelevant_actions(self, get_graph, get_segmentation, ndim): UpdateTrackID(tracks, node_id, new_track_id) # IoU should remain unchanged (no recomputation happened despite seg change) - assert tracks.graph.edges[edge_id]["iou"] == initial_iou + assert td_get_single_attr_from_edge(tracks.graph, edge_id, "iou") == initial_iou # But track_id should be updated assert tracks.get_track_id(node_id) == new_track_id diff --git a/tests/annotators/test_graph_annotator.py b/tests/annotators/test_graph_annotator.py index 34dd4404..77a294b2 100644 --- a/tests/annotators/test_graph_annotator.py +++ b/tests/annotators/test_graph_annotator.py @@ -8,8 +8,8 @@ track_attrs = {"time_attr": "t", "tracklet_attr": "track_id"} -def test_base_graph_annotator(graph_clean, segmentation_2d): - tracks = Tracks(graph_clean, segmentation=segmentation_2d, **track_attrs) +def test_base_graph_annotator(graph_2d_with_segmentation): + tracks = Tracks(graph_2d_with_segmentation, **track_attrs) ann = GraphAnnotator(tracks, {}) assert len(ann.features) == 0 diff --git a/tests/annotators/test_regionprops_annotator.py b/tests/annotators/test_regionprops_annotator.py index 807cee59..adb04076 100644 --- a/tests/annotators/test_regionprops_annotator.py +++ b/tests/annotators/test_regionprops_annotator.py @@ -1,18 +1,23 @@ +import numpy as np import pytest from funtracks.actions import UpdateNodeSeg, UpdateTrackID from funtracks.annotators import RegionpropsAnnotator from funtracks.data_model import SolutionTracks, Tracks +from funtracks.utils.tracksdata_utils import pixels_to_td_mask track_attrs = {"time_attr": "t", "tracklet_attr": "track_id"} @pytest.mark.parametrize("ndim", [3, 4]) class TestRegionpropsAnnotator: - def test_init(self, get_graph, get_segmentation, ndim): - graph = get_graph(ndim, with_features="clean") - seg = get_segmentation(ndim) - tracks = Tracks(graph, segmentation=seg, ndim=ndim, **track_attrs) + def test_init(self, get_graph, ndim): + graph = get_graph(ndim, with_features="segmentation") + tracks = Tracks( + graph, + ndim=ndim, + **track_attrs, + ) rp_ann = RegionpropsAnnotator(tracks) # Features start disabled by default assert len(rp_ann.all_features) == 5 @@ -23,24 +28,32 @@ def test_init(self, get_graph, get_segmentation, ndim): len(rp_ann.features) == 5 ) # pos, area, ellipse_axis_radii, circularity, perimeter - def test_compute_all(self, get_graph, get_segmentation, ndim): - graph = get_graph(ndim, with_features="clean") - seg = get_segmentation(ndim) - tracks = Tracks(graph, segmentation=seg, ndim=ndim, **track_attrs) + def test_compute_all(self, get_graph, ndim): + graph = get_graph(ndim, with_features="segmentation") + tracks = Tracks( + graph, + ndim=ndim, + **track_attrs, + ) rp_ann = RegionpropsAnnotator(tracks) # Enable features rp_ann.activate_features(list(rp_ann.all_features.keys())) # Compute values rp_ann.compute() - for node in tracks.nodes(): - for key in rp_ann.features: - assert key in tracks.graph.nodes[node] - - def test_update_all(self, get_graph, get_segmentation, ndim): - graph = get_graph(ndim, with_features="clean") - seg = get_segmentation(ndim) - tracks = Tracks(graph, segmentation=seg, ndim=ndim, **track_attrs) + for key in rp_ann.features: + assert key in tracks.graph.node_attr_keys() + for node_id in 
tracks.graph.node_ids(): + value = tracks.graph[node_id][key] + assert value is not None + + def test_update_all(self, get_graph, ndim): + graph = get_graph(ndim, with_features="segmentation") + tracks = Tracks( + graph, + ndim=ndim, + **track_attrs, + ) node_id = 3 # Get the RegionpropsAnnotator from the registry @@ -53,29 +66,38 @@ def test_update_all(self, get_graph, get_segmentation, ndim): orig_pixels = tracks.get_pixels(node_id) # remove all but one pixel pixels_to_remove = tuple(orig_pixels[d][1:] for d in range(len(orig_pixels))) + mask_to_remove = pixels_to_td_mask(pixels_to_remove, ndim=ndim) expected_area = 1 # Use UpdateNodeSeg action to modify segmentation and update features - UpdateNodeSeg(tracks, node_id, pixels_to_remove, added=False) + UpdateNodeSeg(tracks, node_id, mask_to_remove, added=False) assert tracks.get_node_attr(node_id, "area") == expected_area for key in rp_ann.features: - assert key in tracks.graph.nodes[node_id] + assert key in tracks.graph.node_attr_keys() # segmentation is fully erased and you try to update node_id = 1 - pixels = tracks.get_pixels(node_id) + mask = tracks.get_mask(node_id) with pytest.warns( match="Cannot find label 1 in frame .*: updating regionprops values to None" ): - UpdateNodeSeg(tracks, node_id, pixels, added=False) - + UpdateNodeSeg(tracks, node_id, mask, added=False) + # all regionprops features should be the defaults, because seg doesn't exist for key in rp_ann.features: - assert tracks.graph.nodes[node_id][key] is None - - def test_add_remove_feature(self, get_graph, get_segmentation, ndim): - graph = get_graph(ndim, with_features="clean") - seg = get_segmentation(ndim) - tracks = Tracks(graph, segmentation=seg, ndim=ndim, **track_attrs) + actual = tracks.graph[node_id][key] + expected = tracks.graph._node_attr_schemas()[key].default_value + # Convert to numpy arrays for comparison (handles both scalar and array types) + actual_np = np.asarray(actual) + expected_np = np.asarray(expected) + assert np.array_equal(actual_np, expected_np) + + def test_add_remove_feature(self, get_graph, ndim): + graph = get_graph(ndim, with_features="segmentation") + tracks = Tracks( + graph, + ndim=ndim, + **track_attrs, + ) # Get the RegionpropsAnnotator from the registry rp_ann = next( ann for ann in tracks.annotators if isinstance(ann, RegionpropsAnnotator) @@ -85,13 +107,10 @@ def test_add_remove_feature(self, get_graph, get_segmentation, ndim): rp_ann.deactivate_features([to_remove_key]) # Clear existing area attributes from graph (from fixture) - for node in tracks.nodes(): - if to_remove_key in tracks.graph.nodes[node]: - del tracks.graph.nodes[node][to_remove_key] + graph.remove_node_attr_key(to_remove_key) rp_ann.compute() - for node in tracks.nodes(): - assert to_remove_key not in tracks.graph.nodes[node] + assert to_remove_key not in tracks.graph.node_attr_keys() # add it back in rp_ann.activate_features([to_remove_key]) @@ -101,33 +120,34 @@ def test_add_remove_feature(self, get_graph, get_segmentation, ndim): # remove all but one pixel node_id = 3 - prev_value = tracks.get_node_attr(node_id, second_remove_key) orig_pixels = tracks.get_pixels(node_id) assert orig_pixels is not None pixels_to_remove = tuple(orig_pixels[d][1:] for d in range(len(orig_pixels))) + mask_to_remove = pixels_to_td_mask(pixels_to_remove, ndim=ndim) # Use UpdateNodeSeg action to modify segmentation and update features - UpdateNodeSeg(tracks, node_id, pixels_to_remove, added=False) - # the new one we removed is not updated - assert tracks.get_node_attr(node_id, 
second_remove_key) == prev_value + UpdateNodeSeg(tracks, node_id, mask_to_remove, added=False) # the one we added back in is now present assert tracks.get_node_attr(node_id, to_remove_key) is not None def test_missing_seg(self, get_graph, ndim): """Test that RegionpropsAnnotator gracefully handles missing segmentation.""" graph = get_graph(ndim, with_features="clean") - tracks = Tracks(graph, segmentation=None, ndim=ndim, **track_attrs) + tracks = Tracks(graph, ndim=ndim, **track_attrs) rp_ann = RegionpropsAnnotator(tracks) assert len(rp_ann.features) == 0 # Should not raise an error, just return silently rp_ann.compute() # No error expected - def test_ignores_irrelevant_actions(self, get_graph, get_segmentation, ndim): + def test_ignores_irrelevant_actions(self, get_graph, ndim): """Test that RegionpropsAnnotator ignores actions that don't affect segmentation. """ - graph = get_graph(ndim, with_features="clean") - seg = get_segmentation(ndim) - tracks = SolutionTracks(graph, segmentation=seg, ndim=ndim, **track_attrs) + graph = get_graph(ndim, with_features="segmentation") + tracks = SolutionTracks( + graph, + ndim=ndim, + **track_attrs, + ) tracks.enable_features(["area", "track_id"]) node_id = 1 @@ -137,10 +157,6 @@ def test_ignores_irrelevant_actions(self, get_graph, get_segmentation, ndim): # Remove half the pixels from node 1 orig_pixels = tracks.get_pixels(node_id) assert orig_pixels is not None - pixels_to_remove = tuple( - orig_pixels[d][: len(orig_pixels[d]) // 2] for d in range(len(orig_pixels)) - ) - tracks.set_pixels(pixels_to_remove, 0) # If we recomputed area now, it would be different # But we won't - we'll just call UpdateTrackID diff --git a/tests/annotators/test_track_annotator.py b/tests/annotators/test_track_annotator.py index 3d5d33c3..439cfbaa 100644 --- a/tests/annotators/test_track_annotator.py +++ b/tests/annotators/test_track_annotator.py @@ -2,6 +2,7 @@ from funtracks.actions import UpdateNodeSeg from funtracks.annotators import TrackAnnotator +from funtracks.utils.tracksdata_utils import pixels_to_td_mask @pytest.mark.parametrize("ndim", [3, 4]) @@ -26,16 +27,16 @@ def test_init(self, get_tracks, ndim, with_seg) -> None: def test_compute_all(self, get_tracks, ndim, with_seg) -> None: tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) - ann = TrackAnnotator(tracks) + ann = TrackAnnotator(tracks, tracklet_key=tracks.features.tracklet_key) # Enable features ann.activate_features(list(ann.all_features.keys())) all_features = ann.features # Compute values ann.compute() - for node in tracks.nodes(): + for node in tracks.graph.node_ids(): for key in all_features: - assert key in tracks.graph.nodes[node] + assert tracks.graph[node][key] is not None lineages = [ [1, 2, 3, 4, 5], @@ -62,7 +63,7 @@ def test_compute_all(self, get_tracks, ndim, with_seg) -> None: def test_add_remove_feature(self, get_tracks, ndim, with_seg): tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=True) - ann = TrackAnnotator(tracks) + ann = TrackAnnotator(tracks, tracklet_key=tracks.features.tracklet_key) # Enable features ann.activate_features(list(ann.all_features.keys())) # compute the original tracklet and lineage ids @@ -70,7 +71,8 @@ def test_add_remove_feature(self, get_tracks, ndim, with_seg): # add an edge node_id = 6 edge_id = (4, 6) - tracks.graph.add_edge(*edge_id) + attrs = {"iou": 0, "solution": 1} if with_seg else {"solution": 1} + tracks.graph.add_edge(source_id=edge_id[0], target_id=edge_id[1], attrs=attrs) to_remove_key = ann.lineage_key orig_lin = 
tracks.get_node_attr(node_id, ann.lineage_key, required=True) orig_tra = tracks.get_node_attr(node_id, ann.tracklet_key, required=True) @@ -112,9 +114,10 @@ def test_ignores_irrelevant_actions(self, get_tracks, ndim, with_seg): orig_pixels = tracks.get_pixels(node_id) assert orig_pixels is not None pixels_to_remove = tuple(orig_pixels[d][1:] for d in range(len(orig_pixels))) + mask_to_remove = pixels_to_td_mask(pixels_to_remove, ndim=ndim) # Perform UpdateNodeSeg action - UpdateNodeSeg(tracks, node_id, pixels_to_remove, added=False) + UpdateNodeSeg(tracks, node_id, mask_to_remove, added=False) # Track ID should remain unchanged (no track update happened) assert tracks.get_track_id(node_id) == initial_track_id diff --git a/tests/conftest.py b/tests/conftest.py index 0cd926a5..ced38560 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,51 +1,135 @@ -import copy from collections.abc import Callable from typing import TYPE_CHECKING -import networkx as nx import numpy as np +import polars as pl import pytest +import tracksdata as td from skimage.draw import disk +from tracksdata.nodes._mask import Mask + +from funtracks.utils.tracksdata_utils import ( + create_empty_graphview_graph, +) if TYPE_CHECKING: from typing import Any - from numpy.typing import NDArray - from funtracks.data_model import SolutionTracks, Tracks # Feature list constants for consistent test usage -FEATURES_WITH_SEG = ["pos", "area", "iou"] +# WITH_SEG means segmentation stored as mask/bbox node attributes +FEATURES_WITH_SEG = ["pos", "area", "iou", "mask", "bbox"] FEATURES_NO_SEG = ["pos"] -SOLUTION_FEATURES_WITH_SEG = ["pos", "area", "iou", "track_id"] +SOLUTION_FEATURES_WITH_SEG = ["pos", "area", "iou", "track_id", "mask", "bbox"] SOLUTION_FEATURES_NO_SEG = ["pos", "track_id"] -@pytest.fixture -def segmentation_2d() -> "NDArray[np.int32]": - frame_shape = (100, 100) - total_shape = (5, *frame_shape) - segmentation = np.zeros(total_shape, dtype="int32") - # make frame with one cell in center with label 1 - rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) - segmentation[0][rr, cc] = 1 +def make_2d_disk_mask(center=(50, 50), radius=20) -> Mask: + """Create a 2D disk mask with bounding box. + + Args: + center: Center coordinates (y, x) + radius: Radius of the disk + + Returns: + tracksdata Mask object with boolean mask and bbox + """ + radius_actual = radius - 1 + mask_shape = (2 * radius - 1, 2 * radius - 1) + rr, cc = disk(center=(radius_actual, radius_actual), radius=radius, shape=mask_shape) + mask_disk = np.zeros(mask_shape, dtype="bool") + mask_disk[rr, cc] = True + return Mask( + mask_disk, + bbox=np.array( + [ + center[0] - radius_actual, + center[1] - radius_actual, + center[0] + radius_actual + 1, + center[1] + radius_actual + 1, + ] + ), + ) + + +def make_3d_sphere_mask(center=(50, 50, 50), radius=20) -> Mask: + """Create a 3D sphere mask with bounding box. 
+ + Args: + center: Center coordinates (z, y, x) + radius: Radius of the sphere + + Returns: + tracksdata Mask object with boolean mask and bbox + """ + mask_shape = (2 * radius + 1, 2 * radius + 1, 2 * radius + 1) + mask_sphere = sphere(center=(radius, radius, radius), radius=radius, shape=mask_shape) + return Mask( + mask_sphere, + bbox=np.array( + [ + center[0] - radius, + center[1] - radius, + center[2] - radius, + center[0] + radius + 1, + center[1] + radius + 1, + center[2] + radius + 1, + ] + ), + ) + - # make frame with two cells - # first cell centered at (20, 80) with label 2 - # second cell centered at (60, 45) with label 3 - rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) - segmentation[1][rr, cc] = 2 - rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) - segmentation[1][rr, cc] = 3 +def make_2d_square_mask(start_corner=(0, 0), width=4) -> Mask: + """Create a 2D square mask with bounding box. + + Args: + start_corner: Top-left corner coordinates (y, x) + width: Width and height of the square + + Returns: + tracksdata Mask object with boolean mask and bbox + """ + mask_shape = (width, width) + mask_square = np.ones(mask_shape, dtype="bool") + return Mask( + mask_square, + bbox=np.array( + [ + start_corner[0], + start_corner[1], + start_corner[0] + width, + start_corner[1] + width, + ] + ), + ) - # continue track 3 with squares from 0 to 4 in x and y with label 3 - segmentation[2, 0:4, 0:4] = 4 - segmentation[4, 0:4, 0:4] = 5 - # unconnected node - segmentation[4, 96:100, 96:100] = 6 +def make_3d_cube_mask(start_corner=(0, 0, 0), width=4) -> Mask: + """Create a 3D cube mask with bounding box. - return segmentation + Args: + start_corner: Corner coordinates (z, y, x) + width: Width, height, and depth of the cube + + Returns: + tracksdata Mask object with boolean mask and bbox + """ + mask_shape = (width, width, width) + mask_cube = np.ones(mask_shape, dtype="bool") + return Mask( + mask_cube, + bbox=np.array( + [ + start_corner[0], + start_corner[1], + start_corner[2], + start_corner[0] + width, + start_corner[1] + width, + start_corner[2] + width, + ] + ), + ) def _make_graph( @@ -55,7 +139,9 @@ def _make_graph( with_track_id: bool = False, with_area: bool = False, with_iou: bool = False, -) -> nx.DiGraph: + with_masks: bool = False, + database: str | None = None, +) -> td.graph.GraphView: """Generate a test graph with configurable features. 
Args: @@ -64,11 +150,34 @@ def _make_graph( with_track_id: Include track_id attribute with_area: Include area attribute (requires with_pos=True) with_iou: Include iou edge attribute (requires with_area=True) + with_masks: Include mask and bbox node attributes + database: Database path for SQLGraph (if None, uses default) Returns: A graph with the requested features """ - graph = nx.DiGraph() + + node_attributes = [] + edge_attributes = [] + if with_pos: + node_attributes.append("pos") + if with_track_id: + node_attributes.append("track_id") + if with_area: + node_attributes.append("area") + if with_iou: + edge_attributes.append("iou") + if with_masks: + node_attributes.append(td.DEFAULT_ATTR_KEYS.MASK) + node_attributes.append(td.DEFAULT_ATTR_KEYS.BBOX) + + graph = create_empty_graphview_graph( + node_attributes=node_attributes, + edge_attributes=edge_attributes, + database=database, + position_attrs=["pos"] if with_pos else None, + ndim=ndim, + ) # Base node data (always has time) base_nodes = [ @@ -107,85 +216,145 @@ def _make_graph( # Track IDs track_ids = {1: 1, 2: 2, 3: 3, 4: 3, 5: 3, 6: 5} + # Mask data (matches segmentation structure) + segmentation_shape: tuple[int, ...] + if ndim == 3: # 2D spatial + masks = { + 1: make_2d_disk_mask(center=(50, 50), radius=20), + 2: make_2d_disk_mask(center=(20, 80), radius=10), + 3: make_2d_disk_mask(center=(60, 45), radius=15), + 4: make_2d_square_mask(start_corner=(0, 0), width=4), + 5: make_2d_square_mask(start_corner=(0, 0), width=4), + 6: make_2d_square_mask(start_corner=(96, 96), width=4), + } + segmentation_shape = (5, 100, 100) + else: # 3D spatial + masks = { + 1: make_3d_sphere_mask(center=(50, 50, 50), radius=20), + 2: make_3d_sphere_mask(center=(20, 50, 80), radius=10), + 3: make_3d_sphere_mask(center=(60, 50, 45), radius=15), + 4: make_3d_cube_mask(start_corner=(0, 0, 0), width=4), + 5: make_3d_cube_mask(start_corner=(0, 0, 0), width=4), + 6: make_3d_cube_mask(start_corner=(96, 96, 96), width=4), + } + segmentation_shape = (5, 100, 100, 100) + # Build nodes with requested features - nodes = [] + nodes_id_list = [] + nodes_attrs_list = [] for node_id, attrs in base_nodes: node_attrs: dict[str, Any] = dict(attrs) # Start with time + node_attrs["solution"] = 1 if with_pos: + # TODO: don't hardcode "pos" and other column names node_attrs["pos"] = positions[node_id] if with_track_id: node_attrs["track_id"] = track_ids[node_id] if with_area: - node_attrs["area"] = areas[node_id] - nodes.append((node_id, node_attrs)) - - edges = [(1, 2), (1, 3), (3, 4), (4, 5)] + node_attrs["area"] = float(areas[node_id]) + # I think this is necessary, to keep the dtype the same, + # in case the scale are not integers + if with_masks: + mask = masks[node_id] + node_attrs[td.DEFAULT_ATTR_KEYS.MASK] = mask + node_attrs[td.DEFAULT_ATTR_KEYS.BBOX] = mask.bbox + nodes_id_list.append(node_id) + nodes_attrs_list.append(node_attrs) + + edges = [ + {"source_id": 1, "target_id": 2, "solution": 1}, + {"source_id": 1, "target_id": 3, "solution": 1}, + {"source_id": 3, "target_id": 4, "solution": 1}, + {"source_id": 4, "target_id": 5, "solution": 1}, + ] - graph.add_nodes_from(nodes) - graph.add_edges_from(edges) + graph.bulk_add_nodes(nodes=nodes_attrs_list, indices=nodes_id_list) + graph.bulk_add_edges(edges) + if with_masks: + graph.update_metadata(segmentation_shape=segmentation_shape) # Add IOUs to edges if requested if with_iou: for edge, iou in ious.items(): - if edge in graph.edges: - graph.edges[edge]["iou"] = iou + if graph.has_edge(edge[0], edge[1]): + 
edge_id = graph.edge_id(edge[0], edge[1]) + graph.update_edge_attrs(attrs={"iou": iou}, edge_ids=[edge_id]) return graph @pytest.fixture -def graph_clean() -> nx.DiGraph: +def graph_clean(tmp_path) -> td.graph.GraphView: """Base graph with only time - no positions or computed features.""" - return _make_graph(ndim=3) + db_path = str(tmp_path / "graph_clean.db") + return _make_graph(ndim=3, database=db_path) @pytest.fixture -def graph_2d_with_position() -> nx.DiGraph: +def graph_2d_with_position(tmp_path) -> td.graph.GraphView: """Graph with 2D positions - for Tracks without segmentation.""" - return _make_graph(ndim=3, with_pos=True) + db_path = str(tmp_path / "graph_2d_position.db") + return _make_graph(ndim=3, with_pos=True, database=db_path) @pytest.fixture -def graph_2d_with_track_id() -> nx.DiGraph: +def graph_2d_with_track_id(tmp_path) -> td.graph.GraphView: """Graph with 2D positions and track_id - for SolutionTracks without segmentation.""" - return _make_graph(ndim=3, with_pos=True, with_track_id=True) + db_path = str(tmp_path / "graph_2d_track_id.db") + return _make_graph(ndim=3, with_pos=True, with_track_id=True, database=db_path) @pytest.fixture -def graph_2d_with_computed_features() -> nx.DiGraph: - """Graph with all computed features - for SolutionTracks with segmentation.""" +def graph_2d_with_segmentation(tmp_path) -> td.graph.GraphView: + """Graph with segmentation (masks/bboxes) and all computed features.""" + db_path = str(tmp_path / "graph_2d_segmentation.db") return _make_graph( - ndim=3, with_pos=True, with_track_id=True, with_area=True, with_iou=True + ndim=3, + with_pos=True, + with_track_id=True, + with_area=True, + with_iou=True, + with_masks=True, + database=db_path, ) @pytest.fixture -def graph_3d_with_position() -> nx.DiGraph: +def graph_3d_with_position(tmp_path) -> td.graph.GraphView: """Graph with 3D positions - for Tracks without segmentation.""" - return _make_graph(ndim=4, with_pos=True) + db_path = str(tmp_path / "graph_3d_position.db") + return _make_graph(ndim=4, with_pos=True, database=db_path) @pytest.fixture -def graph_3d_with_track_id() -> nx.DiGraph: +def graph_3d_with_track_id(tmp_path) -> td.graph.GraphView: """Graph with 3D positions and track_id - for SolutionTracks without segmentation.""" - return _make_graph(ndim=4, with_pos=True, with_track_id=True) + db_path = str(tmp_path / "graph_3d_track_id.db") + return _make_graph(ndim=4, with_pos=True, with_track_id=True, database=db_path) @pytest.fixture -def graph_3d_with_computed_features() -> nx.DiGraph: - """Graph with all computed features - for SolutionTracks with segmentation.""" +def graph_3d_with_segmentation(tmp_path) -> td.graph.GraphView: + """Graph with segmentation (masks/bboxes) and all computed features.""" + db_path = str(tmp_path / "graph_3d_segmentation.db") return _make_graph( - ndim=4, with_pos=True, with_track_id=True, with_area=True, with_iou=True + ndim=4, + with_pos=True, + with_track_id=True, + with_area=True, + with_iou=True, + with_masks=True, + database=db_path, ) @pytest.fixture -def get_tracks(get_graph, get_segmentation) -> Callable[..., "Tracks | SolutionTracks"]: +def get_tracks(get_graph) -> Callable[..., "Tracks | SolutionTracks"]: """Factory fixture to create Tracks or SolutionTracks instances. 
Returns a factory function that can be called with: ndim: 3 for 2D spatial + time, 4 for 3D spatial + time - with_seg: Whether to include segmentation + with_seg: Whether to include segmentation (mask/bbox as node attributes) is_solution: Whether to return SolutionTracks instead of Tracks Example: @@ -208,9 +377,9 @@ def _make_tracks( # Determine which graph to use based on requirements if with_seg: - # With segmentation: use fully computed features (pos + track_id + area + iou) - graph = get_graph(ndim=ndim, with_features="computed") - seg = get_segmentation(ndim=ndim) + # With segmentation: use graph with mask/bbox node attrs + # and all computed features + graph = get_graph(ndim=ndim, with_features="segmentation") else: # Without segmentation if is_solution: @@ -219,7 +388,6 @@ def _make_tracks( else: # Regular Tracks: use graph with just pos graph = get_graph(ndim=ndim, with_features="position") - seg = None # Build FeatureDict based on what exists in the graph features_dict = { @@ -228,12 +396,12 @@ def _make_tracks( } if with_seg: - # Graph has pre-computed features (area, iou, track_id) + # Graph has pre-computed features (area, iou, track_id, mask, bbox) features_dict["area"] = Area(ndim=ndim) features_dict["iou"] = IoU() features_dict["track_id"] = TrackletID() elif is_solution: - # SolutionTracks without seg: has track_id but not area/iou + # SolutionTracks without seg: has track_id but not area/iou/mask/bbox features_dict["track_id"] = TrackletID() feature_dict = FeatureDict( @@ -247,14 +415,12 @@ def _make_tracks( if is_solution: return SolutionTracks( graph, - segmentation=seg, ndim=ndim, features=feature_dict, ) else: return Tracks( graph, - segmentation=seg, ndim=ndim, features=feature_dict, ) @@ -263,31 +429,32 @@ def _make_tracks( @pytest.fixture -def graph_2d_list() -> nx.DiGraph: - graph = nx.DiGraph() +def graph_2d_list(tmp_path) -> td.graph.GraphView: + db_path = str(tmp_path / "graph_2d_list.db") + graph = create_empty_graphview_graph(database=db_path) + nodes = [ - ( - 1, - { - "y": 100, - "x": 50, - "t": 0, - "area": 1245, - "track_id": 1, - }, - ), - ( - 2, - { - "y": 20, - "x": 100, - "t": 1, - "area": 500, - "track_id": 2, - }, - ), + { + "y": 100, + "x": 50, + "t": 0, + "area": 1245, + "track_id": 1, + }, + { + "y": 20, + "x": 100, + "t": 1, + "area": 500, + "track_id": 2, + }, ] - graph.add_nodes_from(nodes) + graph.add_node_attr_key("y", default_value=0.0, dtype=pl.Float64) + graph.add_node_attr_key("x", default_value=0.0, dtype=pl.Float64) + graph.add_node_attr_key("area", default_value=0.0, dtype=pl.Float64) + graph.add_node_attr_key("track_id", default_value=0.0, dtype=pl.Float64) + + graph.bulk_add_nodes(nodes=nodes, indices=[1, 2]) return graph @@ -300,33 +467,7 @@ def sphere(center, radius, shape): @pytest.fixture -def segmentation_3d() -> "NDArray[np.int32]": - frame_shape = (100, 100, 100) - total_shape = (5, *frame_shape) - segmentation = np.zeros(total_shape, dtype="int32") - # make frame with one cell in center with label 1 - mask = sphere(center=(50, 50, 50), radius=20, shape=frame_shape) - segmentation[0][mask] = 1 - - # make frame with two cells - # first cell centered at (20, 50, 80) with label 2 - # second cell centered at (60, 50, 45) with label 3 - mask = sphere(center=(20, 50, 80), radius=10, shape=frame_shape) - segmentation[1][mask] = 2 - mask = sphere(center=(60, 50, 45), radius=15, shape=frame_shape) - segmentation[1][mask] = 3 - - # continue track 3 with squares from 0 to 4 in x and y with label 3 - segmentation[2, 0:4, 0:4, 0:4] = 4 - 
segmentation[4, 0:4, 0:4, 0:4] = 5 - - # unconnected node - segmentation[4, 96:100, 96:100, 96:100] = 6 - return segmentation - - -@pytest.fixture -def get_graph(request) -> Callable[..., nx.DiGraph]: +def get_graph(request) -> Callable[..., td.graph.GraphView]: """Factory fixture to get graph by ndim and feature level. Args: @@ -335,16 +476,16 @@ def get_graph(request) -> Callable[..., nx.DiGraph]: - "clean": time only - "position": time + pos - "track_id": time + pos + track_id (for SolutionTracks without seg) - - "computed": time + pos + track_id + area + iou (full features) + - "segmentation": time + pos + track_id + area + iou + mask + bbox Returns: A deep copy of the requested graph Example: - graph = get_graph(ndim=3, with_features="track_id") + graph = get_graph(ndim=3, with_features="segmentation") """ - def _get_graph(ndim: int, with_features: str = "clean") -> nx.DiGraph: + def _get_graph(ndim: int, with_features: str = "clean") -> td.graph.GraphView: if with_features == "clean": graph = request.getfixturevalue("graph_clean") elif with_features == "position": @@ -357,41 +498,18 @@ def _get_graph(ndim: int, with_features: str = "clean") -> nx.DiGraph: graph = request.getfixturevalue("graph_2d_with_track_id") else: # ndim == 4 graph = request.getfixturevalue("graph_3d_with_track_id") - elif with_features == "computed": + elif with_features == "segmentation": if ndim == 3: - graph = request.getfixturevalue("graph_2d_with_computed_features") + graph = request.getfixturevalue("graph_2d_with_segmentation") else: # ndim == 4 - graph = request.getfixturevalue("graph_3d_with_computed_features") + graph = request.getfixturevalue("graph_3d_with_segmentation") else: raise ValueError( - f"with_features must be 'clean', 'position', 'track_id', or 'computed', " - f"got {with_features}" + f"with_features must be 'clean', 'position', 'track_id', " + f"or 'segmentation', got {with_features}" ) - # Return a deep copy to avoid fixture pollution - return copy.deepcopy(graph) + # Deepcopy alternative for tracksdata graph + return graph.detach().filter().subgraph() return _get_graph - - -@pytest.fixture -def get_segmentation(request) -> Callable[..., "NDArray[np.int32]"]: - """Factory fixture to get segmentation by ndim. 
- - Args: - ndim: 3 for 2D spatial + time, 4 for 3D spatial + time - - Returns: - The segmentation array (not copied since it's not typically modified) - - Example: - seg = get_segmentation(ndim=3) - """ - - def _get_segmentation(ndim: int) -> "NDArray[np.int32]": - if ndim == 3: - return request.getfixturevalue("segmentation_2d") - else: # ndim == 4 - return request.getfixturevalue("segmentation_3d") - - return _get_segmentation diff --git a/tests/data/format_v1/test_save_load_False_3_False_0/attrs.json b/tests/data/format_v1/test_save_load_False_3_False_0/attrs.json new file mode 100644 index 00000000..00ca9cd1 --- /dev/null +++ b/tests/data/format_v1/test_save_load_False_3_False_0/attrs.json @@ -0,0 +1 @@ +{"scale": null, "ndim": 3, "features": {"FeatureDict": {"features": {"t": {"feature_type": "node", "value_type": "int", "num_values": 1, "display_name": "Time", "required": true, "default_value": null}, "pos": {"feature_type": "node", "value_type": "float", "num_values": 2, "display_name": "position", "value_names": ["y", "x"], "required": true, "default_value": null, "spatial_dims": true}}, "time_key": "t", "position_key": "pos", "tracklet_key": null}}} diff --git a/tests/data/format_v1/test_save_load_False_3_False_0/graph.json b/tests/data/format_v1/test_save_load_False_3_False_0/graph.json new file mode 100644 index 00000000..d0a025b5 --- /dev/null +++ b/tests/data/format_v1/test_save_load_False_3_False_0/graph.json @@ -0,0 +1 @@ +{"directed": true, "multigraph": false, "graph": {}, "nodes": [{"t": 0, "pos": [50, 50], "id": 1}, {"t": 1, "pos": [20, 80], "id": 2}, {"t": 1, "pos": [60, 45], "id": 3}, {"t": 2, "pos": [1.5, 1.5], "id": 4}, {"t": 4, "pos": [1.5, 1.5], "id": 5}, {"t": 4, "pos": [97.5, 97.5], "id": 6}], "links": [{"source": 1, "target": 2}, {"source": 1, "target": 3}, {"source": 3, "target": 4}, {"source": 4, "target": 5}]} diff --git a/tests/data/format_v1/test_save_load_False_3_True_0/attrs.json b/tests/data/format_v1/test_save_load_False_3_True_0/attrs.json new file mode 100644 index 00000000..8aa55337 --- /dev/null +++ b/tests/data/format_v1/test_save_load_False_3_True_0/attrs.json @@ -0,0 +1 @@ +{"scale": null, "ndim": 3, "features": {"FeatureDict": {"features": {"t": {"feature_type": "node", "value_type": "int", "num_values": 1, "display_name": "Time", "required": true, "default_value": null}, "pos": {"feature_type": "node", "value_type": "float", "num_values": 2, "display_name": "position", "value_names": ["y", "x"], "required": true, "default_value": null, "spatial_dims": true}, "area": {"feature_type": "node", "value_type": "float", "num_values": 1, "display_name": "Area", "required": true, "default_value": null}, "iou": {"feature_type": "edge", "value_type": "float", "num_values": 1, "display_name": "IoU", "required": true, "default_value": null}, "track_id": {"feature_type": "node", "value_type": "int", "num_values": 1, "display_name": "Tracklet ID", "required": true, "default_value": null}}, "time_key": "t", "position_key": "pos", "tracklet_key": "track_id"}}} diff --git a/tests/data/format_v1/test_save_load_False_3_True_0/graph.json b/tests/data/format_v1/test_save_load_False_3_True_0/graph.json new file mode 100644 index 00000000..024fe42b --- /dev/null +++ b/tests/data/format_v1/test_save_load_False_3_True_0/graph.json @@ -0,0 +1 @@ +{"directed": true, "multigraph": false, "graph": {}, "nodes": [{"t": 0, "pos": [50, 50], "track_id": 1, "area": 1245, "id": 1}, {"t": 1, "pos": [20, 80], "track_id": 2, "area": 305, "id": 2}, {"t": 1, "pos": [60, 45], 
"track_id": 3, "area": 697, "id": 3}, {"t": 2, "pos": [1.5, 1.5], "track_id": 3, "area": 16, "id": 4}, {"t": 4, "pos": [1.5, 1.5], "track_id": 3, "area": 16, "id": 5}, {"t": 4, "pos": [97.5, 97.5], "track_id": 5, "area": 16, "id": 6}], "links": [{"iou": 0.0, "source": 1, "target": 2}, {"iou": 0.395, "source": 1, "target": 3}, {"iou": 0.0, "source": 3, "target": 4}, {"iou": 1.0, "source": 4, "target": 5}]} diff --git a/tests/data/format_v1/test_save_load_False_3_True_0/seg.npy b/tests/data/format_v1/test_save_load_False_3_True_0/seg.npy new file mode 100644 index 00000000..64406223 Binary files /dev/null and b/tests/data/format_v1/test_save_load_False_3_True_0/seg.npy differ diff --git a/tests/data/format_v1/test_save_load_False_4_False_0/attrs.json b/tests/data/format_v1/test_save_load_False_4_False_0/attrs.json new file mode 100644 index 00000000..0d7d20ba --- /dev/null +++ b/tests/data/format_v1/test_save_load_False_4_False_0/attrs.json @@ -0,0 +1 @@ +{"scale": null, "ndim": 4, "features": {"FeatureDict": {"features": {"t": {"feature_type": "node", "value_type": "int", "num_values": 1, "display_name": "Time", "required": true, "default_value": null}, "pos": {"feature_type": "node", "value_type": "float", "num_values": 3, "display_name": "position", "value_names": ["z", "y", "x"], "required": true, "default_value": null, "spatial_dims": true}}, "time_key": "t", "position_key": "pos", "tracklet_key": null}}} diff --git a/tests/data/format_v1/test_save_load_False_4_False_0/graph.json b/tests/data/format_v1/test_save_load_False_4_False_0/graph.json new file mode 100644 index 00000000..4d000548 --- /dev/null +++ b/tests/data/format_v1/test_save_load_False_4_False_0/graph.json @@ -0,0 +1 @@ +{"directed": true, "multigraph": false, "graph": {}, "nodes": [{"t": 0, "pos": [50, 50, 50], "id": 1}, {"t": 1, "pos": [20, 50, 80], "id": 2}, {"t": 1, "pos": [60, 50, 45], "id": 3}, {"t": 2, "pos": [1.5, 1.5, 1.5], "id": 4}, {"t": 4, "pos": [1.5, 1.5, 1.5], "id": 5}, {"t": 4, "pos": [97.5, 97.5, 97.5], "id": 6}], "links": [{"source": 1, "target": 2}, {"source": 1, "target": 3}, {"source": 3, "target": 4}, {"source": 4, "target": 5}]} diff --git a/tests/data/format_v1/test_save_load_False_4_True_0/attrs.json b/tests/data/format_v1/test_save_load_False_4_True_0/attrs.json new file mode 100644 index 00000000..2325b0ec --- /dev/null +++ b/tests/data/format_v1/test_save_load_False_4_True_0/attrs.json @@ -0,0 +1 @@ +{"scale": null, "ndim": 4, "features": {"FeatureDict": {"features": {"t": {"feature_type": "node", "value_type": "int", "num_values": 1, "display_name": "Time", "required": true, "default_value": null}, "pos": {"feature_type": "node", "value_type": "float", "num_values": 3, "display_name": "position", "value_names": ["z", "y", "x"], "required": true, "default_value": null, "spatial_dims": true}, "area": {"feature_type": "node", "value_type": "float", "num_values": 1, "display_name": "Volume", "required": true, "default_value": null}, "iou": {"feature_type": "edge", "value_type": "float", "num_values": 1, "display_name": "IoU", "required": true, "default_value": null}, "track_id": {"feature_type": "node", "value_type": "int", "num_values": 1, "display_name": "Tracklet ID", "required": true, "default_value": null}}, "time_key": "t", "position_key": "pos", "tracklet_key": "track_id"}}} diff --git a/tests/data/format_v1/test_save_load_False_4_True_0/graph.json b/tests/data/format_v1/test_save_load_False_4_True_0/graph.json new file mode 100644 index 00000000..380d5887 --- /dev/null +++ 
b/tests/data/format_v1/test_save_load_False_4_True_0/graph.json @@ -0,0 +1 @@ +{"directed": true, "multigraph": false, "graph": {}, "nodes": [{"t": 0, "pos": [50, 50, 50], "track_id": 1, "area": 33401, "id": 1}, {"t": 1, "pos": [20, 50, 80], "track_id": 2, "area": 4169, "id": 2}, {"t": 1, "pos": [60, 50, 45], "track_id": 3, "area": 14147, "id": 3}, {"t": 2, "pos": [1.5, 1.5, 1.5], "track_id": 3, "area": 64, "id": 4}, {"t": 4, "pos": [1.5, 1.5, 1.5], "track_id": 3, "area": 64, "id": 5}, {"t": 4, "pos": [97.5, 97.5, 97.5], "track_id": 5, "area": 64, "id": 6}], "links": [{"iou": 0.0, "source": 1, "target": 2}, {"iou": 0.302, "source": 1, "target": 3}, {"iou": 0.0, "source": 3, "target": 4}, {"iou": 1.0, "source": 4, "target": 5}]} diff --git a/tests/data/format_v1/test_save_load_False_4_True_0/seg.npy b/tests/data/format_v1/test_save_load_False_4_True_0/seg.npy new file mode 100644 index 00000000..236d56ee Binary files /dev/null and b/tests/data/format_v1/test_save_load_False_4_True_0/seg.npy differ diff --git a/tests/data/format_v1/test_save_load_True_3_False_0/attrs.json b/tests/data/format_v1/test_save_load_True_3_False_0/attrs.json new file mode 100644 index 00000000..faee92fe --- /dev/null +++ b/tests/data/format_v1/test_save_load_True_3_False_0/attrs.json @@ -0,0 +1 @@ +{"scale": null, "ndim": 3, "features": {"FeatureDict": {"features": {"t": {"feature_type": "node", "value_type": "int", "num_values": 1, "display_name": "Time", "required": true, "default_value": null}, "pos": {"feature_type": "node", "value_type": "float", "num_values": 2, "display_name": "position", "value_names": ["y", "x"], "required": true, "default_value": null, "spatial_dims": true}, "track_id": {"feature_type": "node", "value_type": "int", "num_values": 1, "display_name": "Tracklet ID", "required": true, "default_value": null}}, "time_key": "t", "position_key": "pos", "tracklet_key": "track_id"}}} diff --git a/tests/data/format_v1/test_save_load_True_3_False_0/graph.json b/tests/data/format_v1/test_save_load_True_3_False_0/graph.json new file mode 100644 index 00000000..b878f1e3 --- /dev/null +++ b/tests/data/format_v1/test_save_load_True_3_False_0/graph.json @@ -0,0 +1 @@ +{"directed": true, "multigraph": false, "graph": {}, "nodes": [{"t": 0, "pos": [50, 50], "track_id": 1, "id": 1}, {"t": 1, "pos": [20, 80], "track_id": 2, "id": 2}, {"t": 1, "pos": [60, 45], "track_id": 3, "id": 3}, {"t": 2, "pos": [1.5, 1.5], "track_id": 3, "id": 4}, {"t": 4, "pos": [1.5, 1.5], "track_id": 3, "id": 5}, {"t": 4, "pos": [97.5, 97.5], "track_id": 5, "id": 6}], "links": [{"source": 1, "target": 2}, {"source": 1, "target": 3}, {"source": 3, "target": 4}, {"source": 4, "target": 5}]} diff --git a/tests/data/format_v1/test_save_load_True_3_True_0/attrs.json b/tests/data/format_v1/test_save_load_True_3_True_0/attrs.json new file mode 100644 index 00000000..8aa55337 --- /dev/null +++ b/tests/data/format_v1/test_save_load_True_3_True_0/attrs.json @@ -0,0 +1 @@ +{"scale": null, "ndim": 3, "features": {"FeatureDict": {"features": {"t": {"feature_type": "node", "value_type": "int", "num_values": 1, "display_name": "Time", "required": true, "default_value": null}, "pos": {"feature_type": "node", "value_type": "float", "num_values": 2, "display_name": "position", "value_names": ["y", "x"], "required": true, "default_value": null, "spatial_dims": true}, "area": {"feature_type": "node", "value_type": "float", "num_values": 1, "display_name": "Area", "required": true, "default_value": null}, "iou": {"feature_type": "edge", "value_type": 
"float", "num_values": 1, "display_name": "IoU", "required": true, "default_value": null}, "track_id": {"feature_type": "node", "value_type": "int", "num_values": 1, "display_name": "Tracklet ID", "required": true, "default_value": null}}, "time_key": "t", "position_key": "pos", "tracklet_key": "track_id"}}} diff --git a/tests/data/format_v1/test_save_load_True_3_True_0/graph.json b/tests/data/format_v1/test_save_load_True_3_True_0/graph.json new file mode 100644 index 00000000..024fe42b --- /dev/null +++ b/tests/data/format_v1/test_save_load_True_3_True_0/graph.json @@ -0,0 +1 @@ +{"directed": true, "multigraph": false, "graph": {}, "nodes": [{"t": 0, "pos": [50, 50], "track_id": 1, "area": 1245, "id": 1}, {"t": 1, "pos": [20, 80], "track_id": 2, "area": 305, "id": 2}, {"t": 1, "pos": [60, 45], "track_id": 3, "area": 697, "id": 3}, {"t": 2, "pos": [1.5, 1.5], "track_id": 3, "area": 16, "id": 4}, {"t": 4, "pos": [1.5, 1.5], "track_id": 3, "area": 16, "id": 5}, {"t": 4, "pos": [97.5, 97.5], "track_id": 5, "area": 16, "id": 6}], "links": [{"iou": 0.0, "source": 1, "target": 2}, {"iou": 0.395, "source": 1, "target": 3}, {"iou": 0.0, "source": 3, "target": 4}, {"iou": 1.0, "source": 4, "target": 5}]} diff --git a/tests/data/format_v1/test_save_load_True_3_True_0/seg.npy b/tests/data/format_v1/test_save_load_True_3_True_0/seg.npy new file mode 100644 index 00000000..64406223 Binary files /dev/null and b/tests/data/format_v1/test_save_load_True_3_True_0/seg.npy differ diff --git a/tests/data/format_v1/test_save_load_True_4_False_0/attrs.json b/tests/data/format_v1/test_save_load_True_4_False_0/attrs.json new file mode 100644 index 00000000..98c52606 --- /dev/null +++ b/tests/data/format_v1/test_save_load_True_4_False_0/attrs.json @@ -0,0 +1 @@ +{"scale": null, "ndim": 4, "features": {"FeatureDict": {"features": {"t": {"feature_type": "node", "value_type": "int", "num_values": 1, "display_name": "Time", "required": true, "default_value": null}, "pos": {"feature_type": "node", "value_type": "float", "num_values": 3, "display_name": "position", "value_names": ["z", "y", "x"], "required": true, "default_value": null, "spatial_dims": true}, "track_id": {"feature_type": "node", "value_type": "int", "num_values": 1, "display_name": "Tracklet ID", "required": true, "default_value": null}}, "time_key": "t", "position_key": "pos", "tracklet_key": "track_id"}}} diff --git a/tests/data/format_v1/test_save_load_True_4_False_0/graph.json b/tests/data/format_v1/test_save_load_True_4_False_0/graph.json new file mode 100644 index 00000000..a82a72f0 --- /dev/null +++ b/tests/data/format_v1/test_save_load_True_4_False_0/graph.json @@ -0,0 +1 @@ +{"directed": true, "multigraph": false, "graph": {}, "nodes": [{"t": 0, "pos": [50, 50, 50], "track_id": 1, "id": 1}, {"t": 1, "pos": [20, 50, 80], "track_id": 2, "id": 2}, {"t": 1, "pos": [60, 50, 45], "track_id": 3, "id": 3}, {"t": 2, "pos": [1.5, 1.5, 1.5], "track_id": 3, "id": 4}, {"t": 4, "pos": [1.5, 1.5, 1.5], "track_id": 3, "id": 5}, {"t": 4, "pos": [97.5, 97.5, 97.5], "track_id": 5, "id": 6}], "links": [{"source": 1, "target": 2}, {"source": 1, "target": 3}, {"source": 3, "target": 4}, {"source": 4, "target": 5}]} diff --git a/tests/data/format_v1/test_save_load_True_4_True_0/attrs.json b/tests/data/format_v1/test_save_load_True_4_True_0/attrs.json new file mode 100644 index 00000000..2325b0ec --- /dev/null +++ b/tests/data/format_v1/test_save_load_True_4_True_0/attrs.json @@ -0,0 +1 @@ +{"scale": null, "ndim": 4, "features": {"FeatureDict": {"features": {"t": 
{"feature_type": "node", "value_type": "int", "num_values": 1, "display_name": "Time", "required": true, "default_value": null}, "pos": {"feature_type": "node", "value_type": "float", "num_values": 3, "display_name": "position", "value_names": ["z", "y", "x"], "required": true, "default_value": null, "spatial_dims": true}, "area": {"feature_type": "node", "value_type": "float", "num_values": 1, "display_name": "Volume", "required": true, "default_value": null}, "iou": {"feature_type": "edge", "value_type": "float", "num_values": 1, "display_name": "IoU", "required": true, "default_value": null}, "track_id": {"feature_type": "node", "value_type": "int", "num_values": 1, "display_name": "Tracklet ID", "required": true, "default_value": null}}, "time_key": "t", "position_key": "pos", "tracklet_key": "track_id"}}} diff --git a/tests/data/format_v1/test_save_load_True_4_True_0/graph.json b/tests/data/format_v1/test_save_load_True_4_True_0/graph.json new file mode 100644 index 00000000..380d5887 --- /dev/null +++ b/tests/data/format_v1/test_save_load_True_4_True_0/graph.json @@ -0,0 +1 @@ +{"directed": true, "multigraph": false, "graph": {}, "nodes": [{"t": 0, "pos": [50, 50, 50], "track_id": 1, "area": 33401, "id": 1}, {"t": 1, "pos": [20, 50, 80], "track_id": 2, "area": 4169, "id": 2}, {"t": 1, "pos": [60, 50, 45], "track_id": 3, "area": 14147, "id": 3}, {"t": 2, "pos": [1.5, 1.5, 1.5], "track_id": 3, "area": 64, "id": 4}, {"t": 4, "pos": [1.5, 1.5, 1.5], "track_id": 3, "area": 64, "id": 5}, {"t": 4, "pos": [97.5, 97.5, 97.5], "track_id": 5, "area": 64, "id": 6}], "links": [{"iou": 0.0, "source": 1, "target": 2}, {"iou": 0.302, "source": 1, "target": 3}, {"iou": 0.0, "source": 3, "target": 4}, {"iou": 1.0, "source": 4, "target": 5}]} diff --git a/tests/data/format_v1/test_save_load_True_4_True_0/seg.npy b/tests/data/format_v1/test_save_load_True_4_True_0/seg.npy new file mode 100644 index 00000000..236d56ee Binary files /dev/null and b/tests/data/format_v1/test_save_load_True_4_True_0/seg.npy differ diff --git a/tests/data_model/__init__.py b/tests/data_model/__init__.py new file mode 100644 index 00000000..315c8552 --- /dev/null +++ b/tests/data_model/__init__.py @@ -0,0 +1,2 @@ +# This file makes the tests/data_model directory a Python package +# to support relative imports diff --git a/tests/data_model/test_solution_tracks.py b/tests/data_model/test_solution_tracks.py index 3994d329..6644a06c 100644 --- a/tests/data_model/test_solution_tracks.py +++ b/tests/data_model/test_solution_tracks.py @@ -1,35 +1,36 @@ -import networkx as nx import numpy as np from funtracks.actions import AddNode from funtracks.data_model import SolutionTracks, Tracks +from funtracks.user_actions import UserUpdateSegmentation +from funtracks.utils.tracksdata_utils import create_empty_graphview_graph track_attrs = {"time_attr": "t", "tracklet_attr": "track_id"} -def test_recompute_track_ids(graph_2d_with_position): +def test_recompute_track_ids(graph_2d_with_track_id): tracks = SolutionTracks( - graph_2d_with_position, + graph_2d_with_track_id, ndim=3, **track_attrs, ) - assert tracks.get_next_track_id() == 5 + assert tracks.get_next_track_id() == 6 -def test_next_track_id(graph_2d_with_computed_features): - tracks = SolutionTracks(graph_2d_with_computed_features, ndim=3, **track_attrs) +def test_next_track_id(graph_2d_with_segmentation): + tracks = SolutionTracks(graph_2d_with_segmentation, ndim=3, **track_attrs) assert tracks.get_next_track_id() == 6 AddNode( tracks, node=10, - attributes={"t": 3, "pos": [0, 0, 0, 
0], "track_id": 10}, + attributes={"t": 3, "pos": [0, 0], "track_id": 10}, ) assert tracks.get_next_track_id() == 11 -def test_from_tracks_cls(graph_2d_with_computed_features): +def test_from_tracks_cls(graph_2d_with_segmentation): tracks = Tracks( - graph_2d_with_computed_features, + graph_2d_with_segmentation, ndim=3, pos_attr="POSITION", time_attr="TIME", @@ -46,18 +47,18 @@ def test_from_tracks_cls(graph_2d_with_computed_features): assert solution_tracks.get_node_attr(6, tracks.features.tracklet_key) == 5 -def test_from_tracks_cls_recompute(graph_2d_with_computed_features): +def test_from_tracks_cls_recompute(graph_2d_with_segmentation): tracks = Tracks( - graph_2d_with_computed_features, + graph_2d_with_segmentation, ndim=3, pos_attr="POSITION", time_attr="TIME", tracklet_attr=track_attrs["tracklet_attr"], scale=(2, 2, 2), ) - # delete track id on one node triggers reassignment of track_ids even when recompute - # is False. - tracks.graph.nodes[1].pop(tracks.features.tracklet_key, None) + # delete track id (default value -1) on one node triggers reassignment of + # track_ids even when recompute is False. + tracks.graph[1][tracks.features.tracklet_key] = -1 solution_tracks = SolutionTracks.from_tracks(tracks) # should have reassigned new track_id to node 6 assert solution_tracks.get_node_attr(6, solution_tracks.features.tracklet_key) == 4 @@ -66,41 +67,60 @@ def test_from_tracks_cls_recompute(graph_2d_with_computed_features): ) # still 1 +def test_update_segmentation(graph_2d_with_segmentation): + tracks = SolutionTracks( + graph_2d_with_segmentation, + ndim=3, + **track_attrs, + ) + pix = tracks.get_pixels(1) + assert isinstance(pix, tuple) + UserUpdateSegmentation( + tracks, + new_value=99, + updated_pixels=[(pix, 0)], + current_track_id=6, + ) + assert np.asarray(tracks.segmentation)[0, 50, 50] == 99 + + def test_next_track_id_empty(): - graph = nx.DiGraph() - seg = np.zeros(shape=(10, 100, 100, 100), dtype=np.uint64) - tracks = SolutionTracks(graph, segmentation=seg, **track_attrs) + graph = create_empty_graphview_graph( + node_attributes=["pos", "track_id"], + edge_attributes=[], + ) + tracks = SolutionTracks(graph, ndim=4, **track_attrs) assert tracks.get_next_track_id() == 1 def test_export_to_csv_with_display_names( - graph_2d_with_computed_features, graph_3d_with_computed_features, tmp_path + graph_2d_with_segmentation, graph_3d_with_segmentation, tmp_path ): """Test CSV export with use_display_names=True option.""" from funtracks.import_export import export_to_csv # Test 2D with display names - tracks = SolutionTracks(graph_2d_with_computed_features, **track_attrs, ndim=3) + tracks = SolutionTracks(graph_2d_with_segmentation, **track_attrs, ndim=3) temp_file = tmp_path / "test_export_2d_display.csv" export_to_csv(tracks, temp_file, use_display_names=True) with open(temp_file) as f: lines = f.readlines() - assert len(lines) == tracks.graph.number_of_nodes() + 1 # add header + assert len(lines) == tracks.graph.num_nodes() + 1 # add header - # With display names: ID, Parent ID, Time, y, x, Tracklet ID - header = ["ID", "Parent ID", "Time", "y", "x", "Tracklet ID"] + # With display names: ID, Parent ID, Time, y, x, Area, Tracklet ID + header = ["ID", "Parent ID", "Time", "y", "x", "Area", "Tracklet ID"] assert lines[0].strip().split(",") == header # Test 3D with display names - tracks = SolutionTracks(graph_3d_with_computed_features, **track_attrs, ndim=4) + tracks = SolutionTracks(graph_3d_with_segmentation, **track_attrs, ndim=4) temp_file = tmp_path / 
"test_export_3d_display.csv" export_to_csv(tracks, temp_file, use_display_names=True) with open(temp_file) as f: lines = f.readlines() - assert len(lines) == tracks.graph.number_of_nodes() + 1 # add header + assert len(lines) == tracks.graph.num_nodes() + 1 # add header # With display names: ID, Parent ID, Time, z, y, x, Tracklet ID - header = ["ID", "Parent ID", "Time", "z", "y", "x", "Tracklet ID"] + header = ["ID", "Parent ID", "Time", "z", "y", "x", "Volume", "Tracklet ID"] assert lines[0].strip().split(",") == header diff --git a/tests/data_model/test_tracks.py b/tests/data_model/test_tracks.py index 3ba2e5a7..45b68158 100644 --- a/tests/data_model/test_tracks.py +++ b/tests/data_model/test_tracks.py @@ -1,15 +1,21 @@ -import networkx as nx import numpy as np +import polars as pl import pytest +import tracksdata as td from funtracks.data_model import Tracks +from funtracks.user_actions import UserUpdateSegmentation +from funtracks.utils.tracksdata_utils import ( + create_empty_graphview_graph, +) track_attrs = {"time_attr": "t", "tracklet_attr": "track_id"} -def test_create_tracks(graph_3d_with_computed_features: nx.DiGraph, segmentation_3d): +def test_create_tracks(graph_3d_with_segmentation: td.graph.GraphView): # create empty tracks - tracks = Tracks(graph=nx.DiGraph(), ndim=3, **track_attrs) # type: ignore[arg-type] + empty_graph = create_empty_graphview_graph() + tracks = Tracks(graph=empty_graph, ndim=3, **track_attrs) # type: ignore[arg-type] assert tracks.features.position_key == "pos" assert isinstance(tracks.features["pos"], dict) with pytest.raises(KeyError): @@ -17,7 +23,7 @@ def test_create_tracks(graph_3d_with_computed_features: nx.DiGraph, segmentation # create tracks with graph only tracks = Tracks( - graph=graph_3d_with_computed_features, + graph=graph_3d_with_segmentation, ndim=4, **track_attrs, # type: ignore[arg-type] ) @@ -32,8 +38,7 @@ def test_create_tracks(graph_3d_with_computed_features: nx.DiGraph, segmentation # create track with graph and seg tracks = Tracks( - graph=graph_3d_with_computed_features, - segmentation=segmentation_3d, + graph=graph_3d_with_segmentation, **track_attrs, # type: ignore[arg-type] ) pos_key = tracks.features.position_key @@ -43,35 +48,37 @@ def test_create_tracks(graph_3d_with_computed_features: nx.DiGraph, segmentation assert tracks.get_positions([1]).tolist() == [[50, 50, 50]] assert tracks.get_time(1) == 0 assert tracks.get_positions([1], incl_time=True).tolist() == [[0, 50, 50, 50]] - tracks._set_node_attr(1, tracks.features.time_key, 1) - assert tracks.get_positions([1], incl_time=True).tolist() == [[1, 50, 50, 50]] + # TODO: Explicitly block doing setting the time + # tracks._set_node_attr(1, tracks.features.time_key, 1) + # assert tracks.get_positions([1], incl_time=True).tolist() == [[1, 50, 50, 50]] tracks_wrong_attr = Tracks( - graph=graph_3d_with_computed_features, - segmentation=segmentation_3d, + graph=graph_3d_with_segmentation, time_attr="test", ) with pytest.raises(KeyError): # raises error at access if time is wrong tracks_wrong_attr.get_times([1]) - tracks_wrong_attr = Tracks( - graph=graph_3d_with_computed_features, pos_attr="test", ndim=3 - ) - with pytest.raises(KeyError): # raises error at access if pos is wrong - tracks_wrong_attr.get_positions([1]) + with pytest.raises(ValueError): + # Raise error is segmentation shape does not match provided ndim + tracks_wrong_attr = Tracks( + graph=graph_3d_with_segmentation, pos_attr="test", ndim=3 + ) # test multiple position attrs pos_attr = ("z", "y", "x") - for node 
in graph_3d_with_computed_features.nodes(): - pos = graph_3d_with_computed_features.nodes[node]["pos"] + graph_3d_with_segmentation.add_node_attr_key("z", default_value=0.0, dtype=pl.Float64) + graph_3d_with_segmentation.add_node_attr_key("y", default_value=0.0, dtype=pl.Float64) + graph_3d_with_segmentation.add_node_attr_key("x", default_value=0.0, dtype=pl.Float64) + for node in graph_3d_with_segmentation.node_ids(): + pos = graph_3d_with_segmentation[node]["pos"] z, y, x = pos - del graph_3d_with_computed_features.nodes[node]["pos"] - graph_3d_with_computed_features.nodes[node]["z"] = z - graph_3d_with_computed_features.nodes[node]["y"] = y - graph_3d_with_computed_features.nodes[node]["x"] = x + graph_3d_with_segmentation[node]["z"] = z + graph_3d_with_segmentation[node]["y"] = y + graph_3d_with_segmentation[node]["x"] = x tracks = Tracks( - graph=graph_3d_with_computed_features, + graph=graph_3d_with_segmentation, pos_attr=pos_attr, ndim=4, **track_attrs, # type: ignore[arg-type] @@ -81,77 +88,65 @@ def test_create_tracks(graph_3d_with_computed_features: nx.DiGraph, segmentation assert tracks.get_position(1) == [55, 56, 57] -def test_pixels_and_seg_id(graph_3d_with_computed_features, segmentation_3d): - # create track with graph and seg - tracks = Tracks( - graph=graph_3d_with_computed_features, segmentation=segmentation_3d, **track_attrs - ) - - # changing a segmentation id changes it in the mapping - pix = tracks.get_pixels(1) - new_seg_id = 10 - tracks.set_pixels(pix, new_seg_id) - - -def test_nodes_edges(graph_2d_with_computed_features): - tracks = Tracks(graph_2d_with_computed_features, ndim=3, **track_attrs) +def test_nodes_edges(graph_2d_with_segmentation): + tracks = Tracks(graph_2d_with_segmentation, ndim=3, **track_attrs) assert set(tracks.nodes()) == {1, 2, 3, 4, 5, 6} - assert set(map(tuple, tracks.edges())) == {(1, 2), (1, 3), (3, 4), (4, 5)} + assert set(tracks.edges()) == {1, 2, 3, 4} + assert set(map(tuple, tracks.graph.edge_list())) == { + (1, 2), + (1, 3), + (3, 4), + (4, 5), + } -def test_degrees(graph_2d_with_computed_features): - tracks = Tracks(graph_2d_with_computed_features, ndim=3, **track_attrs) +def test_degrees(graph_2d_with_segmentation): + tracks = Tracks(graph_2d_with_segmentation, ndim=3, **track_attrs) assert tracks.in_degree(np.array([1])) == 0 assert tracks.in_degree(np.array([4])) == 1 - assert np.array_equal( - tracks.in_degree(None), np.array([[1, 0], [2, 1], [3, 1], [4, 1], [5, 1], [6, 0]]) - ) + assert np.array_equal(tracks.in_degree(None), np.array([0, 1, 1, 1, 1, 0])) assert np.array_equal(tracks.out_degree(np.array([1, 4])), np.array([2, 1])) assert np.array_equal( tracks.out_degree(None), - np.array([[1, 2], [2, 0], [3, 1], [4, 1], [5, 0], [6, 0]]), + np.array([2, 0, 1, 1, 0, 0]), ) -def test_predecessors_successors(graph_2d_with_computed_features): - tracks = Tracks(graph_2d_with_computed_features, ndim=3, **track_attrs) +def test_predecessors_successors(graph_2d_with_segmentation): + tracks = Tracks(graph_2d_with_segmentation, ndim=3, **track_attrs) assert tracks.predecessors(2) == [1] - assert tracks.successors(1) == [2, 3] + assert set(tracks.successors(1)) == {2, 3} assert tracks.predecessors(1) == [] assert tracks.successors(2) == [] -def test_get_set_node_attr(graph_2d_with_computed_features): - tracks = Tracks(graph_2d_with_computed_features, ndim=3, **track_attrs) +def test_get_set_node_attr(graph_2d_with_segmentation): + tracks = Tracks(graph_2d_with_segmentation, ndim=3, **track_attrs) - tracks._set_node_attr(1, "a", 42) + 
tracks._set_node_attr(1, "area", 42) - tracks._set_nodes_attr([1, 2], "b", [7, 8]) - assert tracks.get_node_attr(1, "a", required=True) == 42 - assert tracks.get_nodes_attr([1, 2], "b", required=True) == [7, 8] - assert tracks.get_nodes_attr([1, 2], "b", required=False) == [7, 8] + tracks._set_nodes_attr([1, 2], "track_id", [7, 8]) + assert tracks.get_node_attr(1, "area", required=True) == 42 + assert tracks.get_nodes_attr([1, 2], "track_id", required=True) == [7, 8] + assert tracks.get_nodes_attr([1, 2], "track_id", required=False) == [7, 8] with pytest.raises(KeyError): tracks.get_node_attr(1, "not_present", required=True) - assert tracks.get_node_attr(1, "not_present", required=False) is None with pytest.raises(KeyError): tracks.get_nodes_attr([1, 2], "not_present", required=True) - assert all( - x is None for x in tracks.get_nodes_attr([1, 2], "not_present", required=False) - ) # test array attributes - tracks._set_node_attr(1, "array_attr", np.array([1, 2, 3])) - tracks._set_nodes_attr((1, 2), "array_attr2", np.array(([1, 2, 3], [4, 5, 6]))) - - -def test_get_set_edge_attr(graph_2d_with_computed_features): - tracks = Tracks(graph_2d_with_computed_features, ndim=3, **track_attrs) - tracks._set_edge_attr((1, 2), "c", 99) - assert tracks.get_edge_attr((1, 2), "c") == 99 - assert tracks.get_edge_attr((1, 2), "iou", required=True) == 0.0 - tracks._set_edges_attr([(1, 2), (1, 3)], "d", [123, 5]) - assert tracks.get_edges_attr([(1, 2), (1, 3)], "d", required=True) == [123, 5] - assert tracks.get_edges_attr([(1, 2), (1, 3)], "d", required=False) == [123, 5] + tracks._set_node_attr(1, "pos", np.array([1, 2])) + tracks._set_nodes_attr((1, 2), "pos", np.array(([1, 2], [4, 5]))) + + +def test_get_set_edge_attr(graph_2d_with_segmentation): + tracks = Tracks(graph_2d_with_segmentation, ndim=3, **track_attrs) + tracks._set_edge_attr((1, 2), "iou", 99) + assert tracks.get_edge_attr((1, 2), "iou") == 99 + assert tracks.get_edge_attr((1, 2), "iou", required=True) == 99 + tracks._set_edges_attr([(1, 2), (1, 3)], "iou", [123, 5]) + assert tracks.get_edges_attr([(1, 2), (1, 3)], "iou", required=True) == [123, 5] + assert tracks.get_edges_attr([(1, 2), (1, 3)], "iou", required=False) == [123, 5] with pytest.raises(KeyError): tracks.get_edge_attr((1, 2), "not_present", required=True) assert tracks.get_edge_attr((1, 2), "not_present", required=False) is None @@ -163,8 +158,8 @@ def test_get_set_edge_attr(graph_2d_with_computed_features): ) -def test_set_positions_str(graph_2d_with_computed_features): - tracks = Tracks(graph_2d_with_computed_features, ndim=3, **track_attrs) +def test_set_positions_str(graph_2d_with_segmentation): + tracks = Tracks(graph_2d_with_segmentation, ndim=3, **track_attrs) tracks.set_positions((1, 2), [(1, 2), (3, 4)]) assert np.array_equal( tracks.get_positions((1, 2), incl_time=False), np.array([[1, 2], [3, 4]]) @@ -189,39 +184,28 @@ def test_set_positions_list(graph_2d_list): ) -def test_get_pixels_and_set_pixels(graph_2d_with_computed_features, segmentation_2d): - tracks = Tracks( - graph_2d_with_computed_features, segmentation_2d, ndim=3, **track_attrs - ) - pix = tracks.get_pixels(1) - assert isinstance(pix, tuple) - tracks.set_pixels(pix, 99) - assert tracks.segmentation[0, 50, 50] == 99 - - -def test_get_pixels_none(graph_2d_with_computed_features): - tracks = Tracks( - graph_2d_with_computed_features, segmentation=None, ndim=3, **track_attrs - ) - assert tracks.get_pixels([1]) is None +def test_get_pixels_none(graph_2d_with_track_id): + tracks = Tracks(graph_2d_with_track_id, 
ndim=3, **track_attrs) + assert tracks.get_pixels(1) is None -def test_set_pixels_no_segmentation(graph_2d_with_computed_features): - tracks = Tracks( - graph_2d_with_computed_features, segmentation=None, ndim=3, **track_attrs - ) +def test_set_pixels_no_segmentation(graph_2d_with_track_id): + tracks = Tracks(graph_2d_with_track_id, ndim=3, **track_attrs) pix = [(np.array([0]), np.array([10]), np.array([20]))] + # set_pixels no longer exists, so we use UserUpdateSegmentation with pytest.raises(ValueError): - tracks.set_pixels(pix, [1]) + UserUpdateSegmentation( + tracks, + new_value=1, + updated_pixels=[(pix, 1)], + current_track_id=1, + ) def test_compute_ndim_errors(): - g = nx.DiGraph() - g.add_node(1, time=0, pos=[0, 0, 0]) - # seg ndim = 3, scale ndim = 2, provided ndim = 4 -> mismatch - seg = np.zeros((2, 2, 2)) - with pytest.raises(ValueError, match="Dimensions from segmentation"): - Tracks(g, segmentation=seg, scale=[1, 2], ndim=4) + g = create_empty_graphview_graph() + g.add_node_attr_key("pos", default_value=[0, 0], dtype=pl.List(pl.Int64)) + g.add_node(index=1, attrs={"t": 0, "pos": [0, 0, 0], "solution": True}) with pytest.raises( ValueError, match="Cannot compute dimensions from segmentation or scale" diff --git a/tests/features/__init__.py b/tests/features/__init__.py new file mode 100644 index 00000000..8a122a91 --- /dev/null +++ b/tests/features/__init__.py @@ -0,0 +1,2 @@ +# This file makes the tests/features directory a Python package +# to support relative imports diff --git a/tests/import_export/__init__.py b/tests/import_export/__init__.py new file mode 100644 index 00000000..7b33f386 --- /dev/null +++ b/tests/import_export/__init__.py @@ -0,0 +1,2 @@ +# This file makes the tests/import_export directory a Python package +# to support relative imports diff --git a/tests/import_export/test_csv_export.py b/tests/import_export/test_csv_export.py index 9786f476..5f3db961 100644 --- a/tests/import_export/test_csv_export.py +++ b/tests/import_export/test_csv_export.py @@ -20,14 +20,14 @@ def test_export_solution_to_csv(get_tracks, tmp_path, ndim, expected_header): with open(temp_file) as f: lines = f.readlines() - assert len(lines) == tracks.graph.number_of_nodes() + 1 # add header + assert len(lines) == tracks.graph.num_nodes() + 1 # add header assert lines[0].strip().split(",") == expected_header # Check first data line (node 1: t=0, pos=[50, 50] or [50, 50, 50], track_id=1) if ndim == 3: - expected_line1 = ["0", "50", "50", "1", "", "1"] + expected_line1 = ["0", "50.0", "50.0", "1", "", "1"] else: - expected_line1 = ["0", "50", "50", "50", "1", "", "1"] + expected_line1 = ["0", "50.0", "50.0", "50.0", "1", "", "1"] assert lines[1].strip().split(",") == expected_line1 diff --git a/tests/import_export/test_csv_import.py b/tests/import_export/test_csv_import.py index f2988f60..241948f9 100644 --- a/tests/import_export/test_csv_import.py +++ b/tests/import_export/test_csv_import.py @@ -43,8 +43,8 @@ def test_import_2d(self, simple_df_2d): tracks = tracks_from_df(simple_df_2d) assert isinstance(tracks, SolutionTracks) - assert tracks.graph.number_of_nodes() == 4 - assert tracks.graph.number_of_edges() == 3 + assert tracks.graph.num_nodes() == 4 + assert tracks.graph.num_edges() == 3 assert tracks.ndim == 3 def test_import_3d(self, df_3d): @@ -52,7 +52,7 @@ def test_import_3d(self, df_3d): tracks = tracks_from_df(df_3d) assert tracks.ndim == 4 - assert tracks.graph.number_of_nodes() == 3 + assert tracks.graph.num_nodes() == 3 # Check z coordinate pos = tracks.get_position(1)
assert len(pos) == 3 # z, y, x @@ -123,7 +123,7 @@ def test_seg_id_matches_id(self, simple_df_2d): tracks = tracks_from_df(df, seg) assert tracks.segmentation is not None # Segmentation should not be relabeled - assert tracks.segmentation[0, 10, 15] == 1 + assert np.asarray(tracks.segmentation)[0, 10, 15] == 1 class TestEdgeCases: @@ -143,8 +143,8 @@ def test_single_node(self): tracks = tracks_from_df(df) - assert tracks.graph.number_of_nodes() == 1 - assert tracks.graph.number_of_edges() == 0 + assert tracks.graph.num_nodes() == 1 + assert tracks.graph.num_edges() == 0 def test_multiple_roots(self): """Test multiple independent lineages.""" @@ -160,11 +160,11 @@ def test_multiple_roots(self): tracks = tracks_from_df(df) - assert tracks.graph.number_of_nodes() == 4 - assert tracks.graph.number_of_edges() == 2 + assert tracks.graph.num_nodes() == 4 + assert tracks.graph.num_edges() == 2 # Should have two root nodes - roots = [n for n in tracks.graph.nodes() if tracks.graph.in_degree(n) == 0] + roots = [n for n in tracks.graph.node_ids() if tracks.graph.in_degree(n) == 0] assert len(roots) == 2 def test_division_event(self): @@ -181,8 +181,8 @@ def test_division_event(self): tracks = tracks_from_df(df) - assert tracks.graph.number_of_nodes() == 3 - assert tracks.graph.number_of_edges() == 2 + assert tracks.graph.num_nodes() == 3 + assert tracks.graph.num_edges() == 2 # Node 1 should have two children children = list(tracks.graph.successors(1)) @@ -203,15 +203,17 @@ def test_long_track(self): tracks = tracks_from_df(df) - assert tracks.graph.number_of_nodes() == 10 - assert tracks.graph.number_of_edges() == 9 + assert tracks.graph.num_nodes() == 10 + assert tracks.graph.num_edges() == 9 # Should form a single linear chain - roots = [n for n in tracks.graph.nodes() if tracks.graph.in_degree(n) == 0] + roots = [n for n in tracks.graph.node_ids() if tracks.graph.in_degree(n) == 0] assert len(roots) == 1 # Each non-leaf node should have exactly one child - non_leaves = [n for n in tracks.graph.nodes() if tracks.graph.out_degree(n) > 0] + non_leaves = [ + n for n in tracks.graph.node_ids() if tracks.graph.out_degree(n) > 0 + ] for node in non_leaves: assert tracks.graph.out_degree(node) == 1 @@ -330,8 +332,8 @@ def test_seg_id_same_as_id(self, simple_df_2d): tracks = tracks_from_df(simple_df_2d, node_name_map=name_map) # Both id and seg_id should be present with same values - assert tracks.graph.number_of_nodes() == 4 - for node_id in tracks.graph.nodes(): + assert tracks.graph.num_nodes() == 4 + for node_id in tracks.graph.node_ids(): assert tracks.get_node_attr(node_id, "seg_id") == node_id def test_duplicate_mapping_with_segmentation(self, simple_df_2d): @@ -354,7 +356,7 @@ def test_duplicate_mapping_with_segmentation(self, simple_df_2d): assert tracks.segmentation is not None # Segmentation should not be relabeled since seg_id == id - assert tracks.segmentation[0, 10, 15] == 1 + assert np.asarray(tracks.segmentation)[0, 10, 15] == 1 class TestValidationErrors: @@ -623,7 +625,7 @@ def test_empty_list_in_name_map_removed(self): tracks = tracks_from_df(df, node_name_map=name_map) assert tracks is not None # The empty mapping should not result in a feature being added - assert not tracks.graph.nodes[1].get("ellipse_axis_radii") + assert "ellipse_axis_radii" not in tracks.graph.node_attr_keys() def test_import_without_position_with_segmentation(self): """Test that position can be omitted when segmentation is provided. 
@@ -656,8 +658,8 @@ def test_import_without_position_with_segmentation(self): assert tracks is not None # Position should be computed from segmentation centroids - assert "pos" in tracks.graph.nodes[1] - pos_1 = tracks.graph.nodes[1]["pos"] + assert "pos" in tracks.graph.node_attr_keys() + pos_1 = tracks.graph[1]["pos"] # Centroid of 3x3 region at [2:5, 2:5] is approximately [3, 3] np.testing.assert_array_almost_equal(pos_1, [3.0, 3.0], decimal=0) diff --git a/tests/import_export/test_export_to_geff.py b/tests/import_export/test_export_to_geff.py index cfc860e1..cb91e735 100644 --- a/tests/import_export/test_export_to_geff.py +++ b/tests/import_export/test_export_to_geff.py @@ -1,4 +1,5 @@ import numpy as np +import polars as pl import pytest import zarr @@ -13,7 +14,6 @@ def test_export_to_geff( get_tracks, get_graph, - get_segmentation, ndim, with_seg, is_solution, @@ -33,23 +33,24 @@ def test_export_to_geff( if pos_attr_type is list: # For split pos, we need to manually create tracks since get_tracks # doesn't support this - graph = get_graph(ndim, with_features="computed") - segmentation = get_segmentation(ndim) if with_seg else None + graph_type = "segmentation" if with_seg else "position" + graph = get_graph(ndim, with_features=graph_type) # Determine position attribute keys based on dimensions pos_keys = ["y", "x"] if ndim == 3 else ["z", "y", "x"] # Split the composite position attribute into separate attributes - for node in graph.nodes(): - pos = graph.nodes[node]["pos"] + for key in pos_keys: + graph.add_node_attr_key(key, default_value=0.0, dtype=pl.Float64) + for node in graph.node_ids(): + pos = graph[node]["pos"] for i, key in enumerate(pos_keys): - graph.nodes[node][key] = pos[i] - del graph.nodes[node]["pos"] + graph[node][key] = pos[i] + graph.remove_node_attr_key("pos") # Create Tracks with split position attributes # Features like area, track_id will be auto-detected from the graph tracks_cls = SolutionTracks if is_solution else Tracks tracks = tracks_cls( graph, - segmentation=segmentation, time_attr="t", pos_attr=pos_keys, tracklet_attr="track_id", @@ -58,13 +59,17 @@ def test_export_to_geff( else: # Use get_tracks fixture for the simple case tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=is_solution) - export_to_geff(tracks, tmp_path) - z = zarr.open((tmp_path / "tracks").as_posix(), mode="r") + + # Export to subdirectory to avoid conflicts with database files in tmp_path + export_dir = tmp_path / "export" + export_dir.mkdir() + export_to_geff(tracks, export_dir) + z = zarr.open((export_dir / "tracks").as_posix(), mode="r") assert isinstance(z, zarr.Group) # Check that segmentation was saved (only when using segmentation) if with_seg: - seg_path = tmp_path / "segmentation" + seg_path = export_dir / "segmentation" seg_zarr = zarr.open(str(seg_path), mode="r") assert isinstance(seg_zarr, zarr.Array) np.testing.assert_array_equal(seg_zarr[:], tracks.segmentation) @@ -122,7 +127,7 @@ def test_export_to_geff( seg_zarr = zarr.open(str(seg_path), mode="r") assert isinstance(seg_zarr, zarr.Array) - filtered_seg = tracks.segmentation.copy() + filtered_seg = np.asarray(tracks.segmentation).copy() mask = np.isin(filtered_seg, [1, 3, 4, 6]) filtered_seg[~mask] = 0 np.testing.assert_array_equal(seg_zarr[:], filtered_seg) diff --git a/tests/import_export/test_import_from_geff.py b/tests/import_export/test_import_from_geff.py index 35a713f4..f0b17a75 100644 --- a/tests/import_export/test_import_from_geff.py +++ b/tests/import_export/test_import_from_geff.py @@ 
-253,15 +253,15 @@ def test_duplicate_values_in_name_map(valid_geff): store, _ = valid_geff # Duplicate values should be allowed - each standard key gets a copy of the data - name_map = {"time": "t", "pos": ["y", "x"], "seg_id": "t"} + node_name_map = {"time": "t", "pos": ["y", "x"], "seg_id": "t"} # Should not raise - seg_id maps to same source as time - tracks = import_from_geff(store, name_map) + tracks = import_from_geff(store, node_name_map) # Both time and seg_id should be present with same values - for node_id in tracks.graph.nodes(): + for node_id in tracks.graph.node_ids(): assert tracks.get_node_attr(node_id, "seg_id") == tracks.get_node_attr( - node_id, "time" + node_id, "t" ) @@ -313,11 +313,11 @@ def test_tracks_with_segmentation(valid_geff, invalid_geff, valid_segmentation, assert hasattr(tracks, "segmentation") assert tracks.segmentation.shape == valid_segmentation.shape # Get last node by ID (don't rely on iteration order) - last_node = max(tracks.graph.nodes()) + last_node = max(tracks.graph.node_ids()) # With composite pos, position is stored as an array - pos = tracks.graph.nodes[last_node]["pos"] + pos = tracks.graph[last_node]["pos"] coords = [ - tracks.graph.nodes[last_node]["time"], + tracks.graph[last_node]["t"], pos[0], # y pos[1], # x ] @@ -326,14 +326,14 @@ def test_tracks_with_segmentation(valid_geff, invalid_geff, valid_segmentation, valid_segmentation[tuple(coords)] == 50 ) # in original segmentation, the pixel value is equal to seg_id assert ( - tracks.segmentation[tuple(coords)] == last_node + np.asarray(tracks.segmentation)[tuple(coords)] == last_node ) # test that the seg id has been relabeled # Check that only required/requested features are present, and that area is recomputed - data = tracks.graph.nodes[last_node] - assert "random_feature" in data - assert "random_feature2" not in data - assert "area" in data + data = tracks.graph[last_node] + assert "random_feature" in tracks.graph.node_attr_keys() + assert "random_feature2" not in tracks.graph.node_attr_keys() + assert "area" in tracks.graph.node_attr_keys() assert ( data["area"] == 0.01 ) # recomputed area values should be 1 pixel, so 0.01 after applying the scaling. @@ -352,9 +352,9 @@ def test_tracks_with_segmentation(valid_geff, invalid_geff, valid_segmentation, node_features=node_features, ) # Get last node by ID (don't rely on iteration order) - last_node = max(tracks.graph.nodes()) - data = tracks.graph.nodes[last_node] - assert "area" in data + last_node = max(tracks.graph.node_ids()) + data = tracks.graph[last_node] + assert "area" in tracks.graph.node_attr_keys() assert data["area"] == 21 # Test that import fails with ValueError when invalid seg_ids are provided. @@ -417,7 +417,7 @@ def test_node_features_compute_vs_load(valid_geff, valid_segmentation, tmp_path) Features not in the geff can still be computed if marked True. 
""" store, _ = valid_geff - name_map = { + node_name_map = { "time": "t", "pos": ["y", "x"], # Composite position mapping "seg_id": "seg_id", @@ -437,7 +437,7 @@ def test_node_features_compute_vs_load(valid_geff, valid_segmentation, tmp_path) tracks = import_from_geff( store, - name_map, + node_name_map, segmentation_path=valid_segmentation_path, scale=scale, node_features=node_features, @@ -448,12 +448,12 @@ def test_node_features_compute_vs_load(valid_geff, valid_segmentation, tmp_path) assert key in tracks.features # Get last node by ID (don't rely on iteration order) - max_node_id = max(tracks.graph.nodes()) - data = tracks.graph.nodes[max_node_id] + max_node_id = max(tracks.graph.node_ids()) + data = tracks.graph[max_node_id] # All requested features should be present for key in feature_keys: - assert key in data + assert data[key] is not None # Verify computed values (1 pixel = 0.01 after scaling) # Original geff had area=21 for last node diff --git a/tests/import_export/test_import_segmentation.py b/tests/import_export/test_import_segmentation.py index 6c1f51fb..8f0999c5 100644 --- a/tests/import_export/test_import_segmentation.py +++ b/tests/import_export/test_import_segmentation.py @@ -1,6 +1,5 @@ """Tests for _import_segmentation module.""" -import networkx as nx import numpy as np import tifffile @@ -8,6 +7,7 @@ load_segmentation, relabel_segmentation, ) +from funtracks.utils.tracksdata_utils import create_empty_graphview_graph class TestLoadSegmentation: @@ -46,15 +46,15 @@ def test_basic_relabeling(self): seg[1, 2, 2] = 20 # seg_id 20 at t=1 # Create graph with node_ids 1, 2 - graph = nx.DiGraph() - graph.add_node(1) - graph.add_node(2) + graph = create_empty_graphview_graph() + graph.add_node(index=1, attrs={"t": 0, "solution": 1}) + graph.add_node(index=2, attrs={"t": 1, "solution": 1}) node_ids = np.array([1, 2]) seg_ids = np.array([10, 20]) time_values = np.array([0, 1]) - result = relabel_segmentation(seg, graph, node_ids, seg_ids, time_values) + result, graph = relabel_segmentation(seg, graph, node_ids, seg_ids, time_values) # seg_id 10 -> node_id 1, seg_id 20 -> node_id 2 assert result[0, 1, 1] == 1 @@ -70,15 +70,15 @@ def test_relabeling_with_node_id_zero(self): seg[1, 2, 2] = 20 # seg_id 20 at t=1 # Create graph with node_ids 0, 1 (includes 0!) 
- graph = nx.DiGraph() - graph.add_node(0) - graph.add_node(1) + graph = create_empty_graphview_graph() + graph.add_node(index=0, attrs={"t": 0, "solution": 1}) + graph.add_node(index=1, attrs={"t": 1, "solution": 1}) node_ids = np.array([0, 1]) seg_ids = np.array([10, 20]) time_values = np.array([0, 1]) - result = relabel_segmentation(seg, graph, node_ids, seg_ids, time_values) + result, graph = relabel_segmentation(seg, graph, node_ids, seg_ids, time_values) # node_ids should be offset by 1: 0->1, 1->2 # seg_id 10 -> node_id 1 (was 0), seg_id 20 -> node_id 2 (was 1) @@ -86,9 +86,9 @@ def test_relabeling_with_node_id_zero(self): assert result[1, 2, 2] == 2 # Graph should also be relabeled - assert 1 in graph.nodes() - assert 2 in graph.nodes() - assert 0 not in graph.nodes() + assert graph.has_node(1) + assert graph.has_node(2) + assert not graph.has_node(0) def test_no_relabeling_needed_same_ids(self): """Test when seg_ids equal node_ids (relabeling still applies mapping).""" @@ -97,15 +97,15 @@ def test_no_relabeling_needed_same_ids(self): seg[0, 1, 1] = 1 seg[1, 2, 2] = 2 - graph = nx.DiGraph() - graph.add_node(1) - graph.add_node(2) + graph = create_empty_graphview_graph() + graph.add_node(index=1, attrs={"t": 0, "solution": 1}) + graph.add_node(index=2, attrs={"t": 1, "solution": 1}) node_ids = np.array([1, 2]) seg_ids = np.array([1, 2]) # Same as node_ids time_values = np.array([0, 1]) - result = relabel_segmentation(seg, graph, node_ids, seg_ids, time_values) + result, graph = relabel_segmentation(seg, graph, node_ids, seg_ids, time_values) # Should still produce valid output (identity mapping) assert result[0, 1, 1] == 1 @@ -119,14 +119,16 @@ def test_multiple_nodes_same_timepoint(self): seg[0, 2, 2] = 20 seg[0, 3, 3] = 30 - graph = nx.DiGraph() - graph.add_nodes_from([1, 2, 3]) + graph = create_empty_graphview_graph() + graph.add_node(index=1, attrs={"t": 0, "solution": 1}) + graph.add_node(index=2, attrs={"t": 0, "solution": 1}) + graph.add_node(index=3, attrs={"t": 0, "solution": 1}) node_ids = np.array([1, 2, 3]) seg_ids = np.array([10, 20, 30]) time_values = np.array([0, 0, 0]) - result = relabel_segmentation(seg, graph, node_ids, seg_ids, time_values) + result, graph = relabel_segmentation(seg, graph, node_ids, seg_ids, time_values) assert result[0, 1, 1] == 1 assert result[0, 2, 2] == 2 diff --git a/tests/import_export/test_internal_format.py b/tests/import_export/test_internal_format.py index b48e6666..8112cc60 100644 --- a/tests/import_export/test_internal_format.py +++ b/tests/import_export/test_internal_format.py @@ -1,13 +1,12 @@ import json +import shutil from collections.abc import Sequence +from pathlib import Path import pytest -from networkx.utils import graphs_equal from numpy.testing import assert_array_almost_equal -from funtracks.data_model import Tracks from funtracks.import_export._v1_format import ( - _save_v1_tracks, delete_tracks, load_v1_tracks, ) @@ -21,13 +20,15 @@ def test_save_load( with_seg, ndim, is_solution, - tmp_path, ): tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=is_solution) - _save_v1_tracks(tracks, tmp_path) - loaded = load_v1_tracks(tmp_path, solution=is_solution) - assert loaded.ndim == tracks.ndim + data_path = Path( + f"tests/data/format_v1/test_save_load_{is_solution}_{ndim}_{with_seg}_0" + ) + + loaded = load_v1_tracks(data_path, solution=is_solution) + assert loaded.ndim == ndim # Check feature keys and important properties match (allow tuple vs list diff) assert loaded.features.time_key == tracks.features.time_key 
assert loaded.features.position_key == tracks.features.position_key @@ -71,7 +72,14 @@ def test_save_load( else: assert loaded.segmentation is None - assert graphs_equal(loaded.graph, tracks.graph) + # graphs_equal doesn't exist for TracksData, so we check properties + assert set(loaded.graph.node_attr_keys()) == set(tracks.graph.node_attr_keys()) + assert set(loaded.graph.edge_attr_keys()) == set(tracks.graph.edge_attr_keys()) + assert loaded.graph.num_nodes() == tracks.graph.num_nodes() + assert loaded.graph.num_edges() == tracks.graph.num_edges() + assert set(loaded.graph.node_ids()) == set(tracks.graph.node_ids()) + # edge_ids don't matter, only the actual edges: + assert sorted(loaded.graph.edge_list()) == sorted(tracks.graph.edge_list()) @@ -84,19 +92,32 @@ def test_delete( is_solution, tmp_path, ): + reference_path = Path( + f"tests/data/format_v1/test_save_load_{is_solution}_{ndim}_{with_seg}_0" + ) + + # Copy reference data to temporary location tracks_path = tmp_path / "test_tracks" - tracks = get_tracks(ndim=ndim, with_seg=with_seg, is_solution=is_solution) - _save_v1_tracks(tracks, tracks_path) + shutil.copytree(reference_path, tracks_path) + + # Delete the copy delete_tracks(tracks_path) with pytest.raises(StopIteration): next(tmp_path.iterdir()) # for backward compatibility -def test_load_without_features(tmp_path, graph_2d_with_computed_features): - tracks = Tracks(graph_2d_with_computed_features, ndim=3) +def test_load_without_features(tmp_path, graph_2d_with_segmentation): + reference_path = Path(f"tests/data/format_v1/test_save_load_{True}_{3}_{True}_0") + + # Copy reference data to temporary location tracks_path = tmp_path / "test_tracks" - _save_v1_tracks(tracks, tracks_path) + shutil.copytree(reference_path, tracks_path) + + # Load the original data first to verify it loads correctly + load_v1_tracks(tracks_path, solution=True) + + # Modify the copy to test backward compatibility attrs_path = tracks_path / "attrs.json" with open(attrs_path) as f: attrs = json.load(f) @@ -107,6 +128,7 @@ def test_load_without_features(tmp_path, graph_2d_with_computed_features): with open(attrs_path, "w") as f: json.dump(attrs, f) + # Load the modified data to test old format compatibility imported_tracks = load_v1_tracks(tracks_path) assert imported_tracks.features.time_key == "time" assert imported_tracks.features.position_key == "pos" diff --git a/tests/import_export/test_name_mapping.py b/tests/import_export/test_name_mapping.py index 3d7d3b46..fe436ed3 100644 --- a/tests/import_export/test_name_mapping.py +++ b/tests/import_export/test_name_mapping.py @@ -148,7 +148,7 @@ def test_empty_available_props(self): class TestMatchDisplayNamesExact: """Test exact matching between properties and feature display names.""" - def test_exact_display_name_match(self): + def test_exact_display_name_match(self) -> None: """Test exact matching with display names.""" available_props = ["Area", "Circularity", "time"] display_name_to_key = { @@ -164,7 +164,7 @@ def test_exact_display_name_match(self): assert mapping == {"area": "Area", "circularity": "Circularity"} assert remaining == ["time"] - def test_no_matches(self): + def test_no_matches(self) -> None: """Test when no properties match display names.""" available_props = ["t", "x", "y"] display_name_to_key = { @@ -180,14 +180,14 @@ def test_no_matches(self): assert mapping == {} assert remaining == ["t", "x", "y"] - def test_empty_inputs(self): + def test_empty_inputs(self) -> None: """Test with
empty inputs.""" mapping: dict = {} remaining = _match_display_names_exact([], {}, mapping) assert mapping == {} assert remaining == [] - def test_case_sensitive(self): + def test_case_sensitive(self) -> None: """Test that exact matching is case-sensitive.""" available_props = ["area", "AREA"] display_name_to_key = {"Area": ("area", 0)} @@ -200,7 +200,7 @@ def test_case_sensitive(self): assert mapping == {} # Neither "area" nor "AREA" matches "Area" exactly assert set(remaining) == {"area", "AREA"} - def test_multi_value_feature(self): + def test_multi_value_feature(self) -> None: """Test matching multi-value features by value_names.""" available_props = ["major_axis", "minor_axis", "Area"] display_name_to_key = { @@ -225,7 +225,7 @@ def test_multi_value_feature(self): class TestMatchDisplayNamesFuzzy: """Test fuzzy matching between properties and feature display names.""" - def test_case_insensitive_match(self): + def test_case_insensitive_match(self) -> None: """Test case-insensitive fuzzy matching.""" available_props = ["area", "CIRC"] display_name_to_key = { @@ -239,7 +239,7 @@ def test_case_insensitive_match(self): assert "area" in mapping assert "circularity" in mapping - def test_abbreviation_match(self): + def test_abbreviation_match(self) -> None: """Test matching abbreviations to display names.""" available_props = ["Circ", "Ecc"] display_name_to_key = { @@ -253,7 +253,7 @@ def test_abbreviation_match(self): assert "circularity" in mapping assert "eccentricity" in mapping - def test_no_matches(self): + def test_no_matches(self) -> None: """Test when no fuzzy matches found.""" available_props = ["xyz", "abc"] display_name_to_key = {"Area": ("area", 0)} @@ -266,7 +266,7 @@ def test_no_matches(self): assert mapping == {} assert set(remaining) == {"xyz", "abc"} - def test_empty_available_props(self): + def test_empty_available_props(self) -> None: """Test with empty available properties.""" mapping: dict = {} remaining = _match_display_names_fuzzy([], {"Area": ("area", 0)}, mapping) @@ -274,7 +274,7 @@ def test_empty_available_props(self): assert mapping == {} assert remaining == [] - def test_custom_cutoff(self): + def test_custom_cutoff(self) -> None: """Test with custom cutoff value.""" available_props = ["Ar"] display_name_to_key = {"Area": ("area", 0)} @@ -293,7 +293,7 @@ def test_custom_cutoff(self): ) assert mapping_high == {} - def test_multi_value_feature(self): + def test_multi_value_feature(self) -> None: """Test fuzzy matching multi-value features by value_names.""" available_props = ["Major_Axis", "Minor_Axis", "area"] display_name_to_key = { diff --git a/tests/user_actions/__init__.py b/tests/user_actions/__init__.py new file mode 100644 index 00000000..5c7fde60 --- /dev/null +++ b/tests/user_actions/__init__.py @@ -0,0 +1,2 @@ +# This file makes the tests/user_actions directory a Python package +# to support relative imports diff --git a/tests/user_actions/test_user_actions_force.py b/tests/user_actions/test_user_actions_force.py index bae5d5f4..1096b958 100644 --- a/tests/user_actions/test_user_actions_force.py +++ b/tests/user_actions/test_user_actions_force.py @@ -13,9 +13,9 @@ def test_user_force_add_downstream(get_tracks): attrs = {"t": 2, "track_id": 1, "pos": [3, 4]} UserAddNode(tracks, node=7, attributes=attrs, force=True) assert tracks.get_track_id(7) == 1 - assert (1, 2) not in tracks.graph.edges - assert (1, 3) not in tracks.graph.edges - assert (1, 7) in tracks.graph.edges + assert [1, 2] not in tracks.graph.edge_list() + assert [1, 3] not in 
tracks.graph.edge_list() + assert [1, 7] in tracks.graph.edge_list() def test_user_force_add_upstream(get_tracks): @@ -28,9 +28,9 @@ def test_user_force_add_upstream(get_tracks): attrs = {"t": 0, "track_id": 3, "pos": [3, 4]} UserAddNode(tracks, node=7, attributes=attrs, force=True) assert tracks.get_track_id(7) == 3 - assert (1, 2) in tracks.graph.edges # still there - assert (1, 3) not in tracks.graph.edges # should be removed - assert (7, 3) in tracks.graph.edges # new forced edge + assert [1, 2] in tracks.graph.edge_list() # still there + assert [1, 3] not in tracks.graph.edge_list() # should be removed + assert [7, 3] in tracks.graph.edge_list() # new forced edge def test_auto_assign_new_track_id(get_tracks): @@ -44,5 +44,5 @@ def test_auto_assign_new_track_id(get_tracks): attrs = {"t": 1, "track_id": 2, "pos": [3, 4]} # combination exists already UserAddNode(tracks, node=7, attributes=attrs) - assert 7 in tracks.graph.nodes + assert tracks.graph.has_node(7) assert tracks.get_track_id(7) == 6 # new assigned track id diff --git a/tests/user_actions/test_user_add_delete_edge.py b/tests/user_actions/test_user_add_delete_edge.py index 852558e0..84f4c5e1 100644 --- a/tests/user_actions/test_user_add_delete_edge.py +++ b/tests/user_actions/test_user_add_delete_edge.py @@ -1,4 +1,5 @@ import pytest +import tracksdata as td from funtracks.exceptions import InvalidActionError from funtracks.user_actions import UserAddEdge, UserDeleteEdge @@ -123,8 +124,16 @@ def test_delete_missing_edge(get_tracks): def test_delete_edge_triple_div(get_tracks): tracks = get_tracks(ndim=3, with_seg=True, is_solution=True) - tracks.graph.add_edge(1, 6) + attrs = {} + attrs[td.DEFAULT_ATTR_KEYS.SOLUTION] = 1 + attrs["iou"] = 0.9 + + tracks.graph.add_edge( + source_id=1, + target_id=6, + attrs=attrs, + ) with pytest.raises( - InvalidActionError, match="Expected degree of 0 or 1 after removing edge" + InvalidActionError, match="Expected degree of 0 or 1 after removing edge, got 2" ): UserDeleteEdge(tracks, (1, 6)) diff --git a/tests/user_actions/test_user_add_delete_node.py b/tests/user_actions/test_user_add_delete_node.py index 55f1385d..82d65d42 100644 --- a/tests/user_actions/test_user_add_delete_node.py +++ b/tests/user_actions/test_user_add_delete_node.py @@ -46,7 +46,7 @@ def test_user_add_node(self, get_tracks, ndim, with_seg): "t": time, } if with_seg: - seg_copy = tracks.segmentation.copy() + seg_copy = np.asarray(tracks.segmentation).copy() if ndim == 3: seg_copy[time, position[0], position[1]] = node_id else: diff --git a/tests/user_actions/test_user_update_segmentation.py b/tests/user_actions/test_user_update_segmentation.py index 0d228a14..479f0fd9 100644 --- a/tests/user_actions/test_user_update_segmentation.py +++ b/tests/user_actions/test_user_update_segmentation.py @@ -114,9 +114,7 @@ def test_user_erase_seg(self, get_tracks, ndim): # remove all pixels pixels_to_remove = orig_pixels - # set the pixels in the array first - # (to reflect that the user directly changes the segmentation array) - tracks.set_pixels(pixels_to_remove, 0) + # setting of pixels no longer necessary, done in UpdateNodeSeg action = UserUpdateSegmentation( tracks, new_value=0, @@ -125,7 +123,6 @@ def test_user_erase_seg(self, get_tracks, ndim): ) assert not tracks.graph.has_node(node_id) - tracks.set_pixels(pixels_to_remove, node_id) inverse = action.inverse() assert tracks.graph.has_node(node_id) self.pixel_equals(tracks.get_pixels(node_id), orig_pixels) @@ -133,7 +130,6 @@ def test_user_erase_seg(self, get_tracks, ndim): assert 
tracks.get_node_attr(node_id, "area") == orig_area assert tracks.get_edge_attr(edge, iou_key) == pytest.approx(orig_iou, abs=0.01) - tracks.set_pixels(pixels_to_remove, 0) inverse.inverse() assert not tracks.graph.has_node(node_id) @@ -155,14 +151,13 @@ def test_user_add_seg(self, get_tracks, ndim): assert not tracks.graph.has_node(node_id) assert np.sum(tracks.segmentation == node_id) == 0 - tracks.set_pixels(pixels_to_add, node_id) action = UserUpdateSegmentation( tracks, new_value=node_id, updated_pixels=[(pixels_to_add, 0)], current_track_id=10, ) - assert np.sum(tracks.segmentation == node_id) == len(pixels_to_add[0]) + assert np.sum(np.asarray(tracks.segmentation) == node_id) == len(pixels_to_add[0]) assert tracks.graph.has_node(node_id) assert tracks.get_position(node_id) == position assert tracks.get_node_attr(node_id, "area") == area diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 00000000..9782ccf2 --- /dev/null +++ b/tests/utils/__init__.py @@ -0,0 +1,2 @@ +# This file makes the tests/utils directory a Python package +# to support relative imports diff --git a/tests/utils/test_tracksdata_utils.py b/tests/utils/test_tracksdata_utils.py new file mode 100644 index 00000000..6f36ffd2 --- /dev/null +++ b/tests/utils/test_tracksdata_utils.py @@ -0,0 +1,135 @@ +"""Tests for tracksdata utility functions.""" + +import numpy as np +import pytest + +from funtracks.utils.tracksdata_utils import pixels_to_td_mask, td_mask_to_pixels + +# Import from conftest +from ..conftest import ( + make_2d_disk_mask, + make_2d_square_mask, + make_3d_cube_mask, + make_3d_sphere_mask, +) + + +@pytest.mark.parametrize( + "mask_func,ndim", + [ + (lambda: make_2d_disk_mask(center=(50, 50), radius=20), 3), + (lambda: make_2d_disk_mask(center=(25, 75), radius=10), 3), + (lambda: make_2d_square_mask(start_corner=(10, 10), width=5), 3), + (lambda: make_3d_sphere_mask(center=(50, 50, 50), radius=20), 4), + (lambda: make_3d_sphere_mask(center=(25, 75, 30), radius=15), 4), + (lambda: make_3d_cube_mask(start_corner=(10, 10, 10), width=5), 4), + ], +) +def test_mask_pixels_roundtrip(mask_func, ndim): + """Test that mask -> pixels -> mask roundtrip preserves the mask.""" + # Create original mask + original_mask = mask_func() + time = 5 # Arbitrary time point + + # Convert mask to pixels + pixels = td_mask_to_pixels(original_mask, time=time, ndim=ndim) + + # Verify pixel format + assert len(pixels) == ndim # Should have ndim arrays + assert len(pixels[0]) == len(pixels[1]) # All arrays same length + assert np.all(pixels[0] == time) # Time should be constant + + # Convert pixels back to mask + reconstructed_mask, area = pixels_to_td_mask( + pixels, ndim=ndim, scale=[1 for _ in range(ndim)], include_area=True + ) + + # Verify the reconstructed mask matches the original + assert np.array_equal(reconstructed_mask.bbox, original_mask.bbox), ( + "Bounding boxes should match" + ) + assert np.array_equal(reconstructed_mask.mask, original_mask.mask), ( + "Mask arrays should match" + ) + assert area == np.sum(original_mask.mask), "Area should match pixel count" + + +@pytest.mark.parametrize("ndim", [3, 4]) +def test_mask_pixels_roundtrip_with_scale(ndim): + """Test mask->pixels->mask roundtrip with scale factors.""" + # Create mask + if ndim == 3: + mask = make_2d_disk_mask(center=(40, 60), radius=15) + scale = [1.0, 2.0, 3.0] # time, y, x scales + else: + mask = make_3d_sphere_mask(center=(40, 60, 30), radius=12) + scale = [1.0, 2.0, 3.0, 4.0] # time, z, y, x scales + + time = 3 + + # 
Convert mask to pixels + pixels = td_mask_to_pixels(mask, time=time, ndim=ndim) + + # Convert back with scale + reconstructed_mask, scaled_area = pixels_to_td_mask( + pixels, ndim=ndim, scale=scale, include_area=True + ) + + # Verify mask structure is preserved + assert np.array_equal(reconstructed_mask.bbox, mask.bbox) + assert np.array_equal(reconstructed_mask.mask, mask.mask) + + # Verify area is scaled correctly + expected_area = np.sum(mask.mask) * np.prod(scale[1:]) + assert np.isclose(scaled_area, expected_area), ( + f"Scaled area {scaled_area} should match expected {expected_area}" + ) + + +def test_td_mask_to_pixels_empty_mask(): + """Test converting an empty mask to pixels.""" + from tracksdata.nodes._mask import Mask + + # Create a truly empty mask (all False) + empty_mask_array = np.zeros((2, 2), dtype=bool) + empty_bbox = np.array([10, 10, 12, 12]) + empty_mask = Mask(empty_mask_array, bbox=empty_bbox) + + pixels = td_mask_to_pixels(empty_mask, time=1, ndim=3) + + # Should return empty arrays + assert len(pixels) == 3 + assert len(pixels[0]) == 0 # No pixels + assert len(pixels[1]) == 0 + assert len(pixels[2]) == 0 + + +@pytest.mark.parametrize("ndim", [3, 4]) +def test_pixels_coordinate_offset(ndim): + """Test that bbox offset is correctly applied in pixel coordinates.""" + # Create a mask at a non-zero position + if ndim == 3: + mask = make_2d_square_mask(start_corner=(20, 30), width=3) + expected_bbox = np.array([20, 30, 23, 33]) + else: + mask = make_3d_cube_mask(start_corner=(20, 30, 40), width=3) + expected_bbox = np.array([20, 30, 40, 23, 33, 43]) + + assert np.array_equal(mask.bbox, expected_bbox) + + # Convert to pixels + pixels = td_mask_to_pixels(mask, time=7, ndim=ndim) + + # Verify pixel coordinates are in global space (not local) + if ndim == 3: + assert np.min(pixels[1]) == 20 # min y + assert np.max(pixels[1]) == 22 # max y + assert np.min(pixels[2]) == 30 # min x + assert np.max(pixels[2]) == 32 # max x + else: + assert np.min(pixels[1]) == 20 # min z + assert np.max(pixels[1]) == 22 # max z + assert np.min(pixels[2]) == 30 # min y + assert np.max(pixels[2]) == 32 # max y + assert np.min(pixels[3]) == 40 # min x + assert np.max(pixels[3]) == 42 # max x