Skip to content

Commit 8fe1f47

Browse files
Merge pull request #165 from softmatterlab/jp/magik-update
Update graph generators
2 parents 566c9b1 + 10b850e commit 8fe1f47

File tree

2 files changed

+36
-3
lines changed

2 files changed

+36
-3
lines changed

deeptrack/models/gnns/graphs.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def GetEdge(
1616
end: int,
1717
radius: int,
1818
parenthood: pd.DataFrame,
19+
columns = [],
1920
**kwargs,
2021
):
2122
"""
@@ -79,7 +80,7 @@ def GetEdge(
7980
edges.append(combdf)
8081
# Concatenate the dataframes in a single
8182
# dataframe for the whole set of edges
82-
edgedf = pd.concat(edges)
83+
edgedf = pd.concat(edges) if len(edges) > 0 else pd.DataFrame(columns=columns)
8384

8485
# Merge columns contaning the labels into a single column
8586
# of numpy arrays, i.e., label = [label_x, label_y]
@@ -120,6 +121,7 @@ def EdgeExtractor(nodesdf, nofframes=3, **kwargs):
120121
"""
121122
# Create a copy of the dataframe to avoid overwriting
122123
df = nodesdf.copy()
124+
columns = df.columns
123125

124126
edgedfs = []
125127
sets = np.unique(df["set"])
@@ -140,7 +142,7 @@ def EdgeExtractor(nodesdf, nofframes=3, **kwargs):
140142
window = [elem for elem in window if elem <= df_set["frame"].max()]
141143

142144
# Compute the edges for each frames window
143-
edgedf = GetEdge(df_set, start=window[0], end=window[-1], **kwargs)
145+
edgedf = GetEdge(df_set, start=window[0], end=window[-1], columns=columns, **kwargs)
144146
edgedf["set"] = setid
145147
edgedfs.append(edgedf)
146148

deeptrack/test/test_generators.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
from .. import generators
99
from ..optics import Fluorescence
1010
from ..scatterers import PointParticle
11+
from ..models import gnns
1112
import numpy as np
12-
13+
import pandas as pd
1314

1415
class TestGenerators(unittest.TestCase):
1516
def test_Generator(self):
@@ -154,7 +155,37 @@ def get_particle_position(result):
154155
# a = generator[idx]
155156

156157
# [self.assertLess(d[-1], 8) for d in generator.data]
158+
157159

160+
def test_GraphGenerator(self):
161+
frame = np.arange(10)
162+
centroid = np.random.normal(0.5, 0.1, (10, 2))
163+
164+
df = pd.DataFrame(
165+
{
166+
'frame': frame,
167+
'centroid-0': centroid[:, 0],
168+
'centroid-1': centroid[:, 1],
169+
'label': 0,
170+
'set': 0,
171+
'solution': 0.0
172+
}
173+
)
174+
# remove consecutive frames
175+
df = df[~df["frame"].isin([3, 4, 5])]
176+
177+
generator = gnns.generators.GraphGenerator(
178+
nodesdf=df,
179+
properties=["centroid"],
180+
min_data_size=8,
181+
max_data_size=9,
182+
batch_size=8,
183+
feature_function=gnns.augmentations.GetGlobalFeature,
184+
radius=0.2,
185+
nofframes=3,
186+
output_type="edges"
187+
)
188+
self.assertIsInstance(generator, gnns.generators.ContinuousGraphGenerator)
158189

159190
if __name__ == "__main__":
160191
unittest.main()

0 commit comments

Comments
 (0)