Skip to content

Commit d0e9458

Browse files
committed
namedtuple to rescue numba jit
1 parent dcd875b commit d0e9458

File tree

9 files changed

+100
-78
lines changed

9 files changed

+100
-78
lines changed

openptv_python/parameters.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Parameters for OpenPTV-Python."""
2+
from collections import namedtuple
23
from dataclasses import asdict, dataclass, field
34
from pathlib import Path
45
from typing import List, Tuple
@@ -7,6 +8,21 @@
78

89
from openptv_python.constants import TR_MAX_CAMS
910

11+
TrackParTuple = namedtuple('TrackParTuple',
12+
['dvxmin',
13+
'dvxmax',
14+
'dvymin',
15+
'dvymax',
16+
'dvzmin',
17+
'dvzmax',
18+
'dangle',
19+
'dacc',
20+
'add',
21+
'dsumg',
22+
'dn',
23+
'dnx',
24+
'dny'])
25+
1026

1127
@dataclass
1228
class Parameters:
@@ -193,24 +209,6 @@ class TrackPar(Parameters):
193209
dny: float = 0.0
194210

195211

196-
# def to_dict(self):
197-
# """Convert TrackPar instance to a dictionary."""
198-
# return {
199-
# 'dvxmax': self.dvxmax,
200-
# 'dvxmin': self.dvxmin,
201-
# 'dvymax': self.dvymax,
202-
# 'dvymin': self.dvymin,
203-
# 'dvzmax': self.dvzmax,
204-
# 'dvzmin': self.dvzmin,
205-
# 'dangle': self.dangle,
206-
# 'dacc': self.dacc,
207-
# 'add': self.add,
208-
# 'dsumg': self.dsumg,
209-
# 'dn': self.dn,
210-
# 'dnx': self.dnx,
211-
# 'dny': self.dny,
212-
# }
213-
214212
@classmethod
215213
def from_file(cls, filename: Path):
216214
"""Read tracking parameters from file and return TrackPar object.
@@ -315,6 +313,24 @@ def compare_track_par(t1: TrackPar, t2: TrackPar) -> bool:
315313
return all(getattr(t1, field) == getattr(t2, field) for field in t1.__annotations__)
316314

317315

316+
def convert_track_par_to_tuple(track_par: TrackPar) -> TrackParTuple:
317+
"""Convert TrackPar object to TrackParTuple object."""
318+
return TrackParTuple(track_par.dvxmin,
319+
track_par.dvxmax,
320+
track_par.dvymin,
321+
track_par.dvymax,
322+
track_par.dvzmin,
323+
track_par.dvzmax,
324+
track_par.dangle,
325+
track_par.dacc,
326+
track_par.add,
327+
track_par.dsumg,
328+
track_par.dn,
329+
track_par.dnx,
330+
track_par.dny)
331+
332+
333+
318334
@dataclass
319335
class VolumePar(Parameters):
320336
"""Volume parameters."""

openptv_python/track.py

Lines changed: 5 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
)
2222
from .imgcoord import img_coord
2323
from .orientation import point_position
24-
from .parameters import ControlPar, TrackPar
24+
from .parameters import ControlPar, TrackParTuple, convert_track_par_to_tuple
2525
from .tracking_frame_buf import Frame, Pathinfo, Target
2626
from .tracking_run import TrackingRun
2727
from .trafo import dist_to_flat, metric_to_pixel, pixel_to_metric
@@ -188,8 +188,8 @@ def predict(prev_pos, curr_pos, output):
188188
output[0] = 2 * curr_pos[0] - prev_pos[0]
189189
output[1] = 2 * curr_pos[1] - prev_pos[1]
190190

191-
192-
def pos3d_in_bounds(pos: np.ndarray, bounds: TrackPar) -> bool:
191+
@njit(cache=True, fastmath=True, nogil=True)
192+
def pos3d_in_bounds(pos: np.ndarray, bounds: TrackParTuple) -> bool:
193193
"""Check that all components of a pos3d are in their respective bounds.
194194
195195
taken from a track_par object.
@@ -211,44 +211,6 @@ def pos3d_in_bounds(pos: np.ndarray, bounds: TrackPar) -> bool:
211211
)
212212

213213

214-
# def angle_acc(
215-
# start: np.ndarray, pred: np.ndarray, cand: np.ndarray
216-
# ) -> Tuple[float, float]:
217-
# """Calculate the angle between the (1st order) numerical velocity vectors.
218-
219-
# to the predicted next_frame position and to the candidate actual position. The
220-
# angle is calculated in [gon], see [1]. The predicted position is the
221-
# position if the particle continued at current velocity.
222-
223-
# Arguments:
224-
# ---------
225-
# start -- vec3d, the particle start position
226-
# pred -- vec3d, predicted position
227-
# cand -- vec3d, possible actual position
228-
229-
# Returns:
230-
# -------
231-
# angle -- float, the angle between the two velocity vectors, [gon]
232-
# acc -- float, the 1st-order numerical acceleration embodied in the deviation from prediction.
233-
# """
234-
# v0 = pred - start
235-
# v1 = cand - start
236-
237-
# acc = math.dist(v0, v1)
238-
# # acc = np.linalg.norm(v0 - v1)
239-
240-
# if np.all(v0 == -v1):
241-
# angle = 200
242-
# elif np.all(v0 == v1):
243-
# angle = 0
244-
# else:
245-
# angle = float((200.0 / math.pi) * math.acos(
246-
# math.fsum([v0[i] * v1[i] for i in range(3)])
247-
# / (math.dist(start, pred) * math.dist(start, cand)))
248-
# )
249-
250-
# return angle, acc
251-
252214
@njit(float64[:](float64[:], float64[:], float64[:]), cache=True, fastmath=True, nogil=True, parallel=True)
253215
def angle_acc(
254216
start: np.ndarray, pred: np.ndarray, cand: np.ndarray
@@ -441,7 +403,7 @@ def candsearch_in_pix_rest(
441403

442404

443405
def searchquader(
444-
point: np.ndarray, tpar: TrackPar, cpar: ControlPar, cal: List[Calibration]
406+
point: np.ndarray, tpar: TrackParTuple, cpar: ControlPar, cal: List[Calibration]
445407
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
446408
"""Calculate the search volume in image space."""
447409
mins = np.array([tpar.dvxmin, tpar.dvymin, tpar.dvzmin])
@@ -819,7 +781,7 @@ def trackcorr_c_loop(run_info, step):
819781

820782
fb = run_info.fb
821783
cal = run_info.cal
822-
tpar = run_info.tpar
784+
tpar = convert_track_par_to_tuple(run_info.tpar)
823785
vpar = run_info.vpar
824786
cpar = run_info.cpar
825787
curr_targets = fb.buf[1].targets

openptv_python/tracking_run.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
from .parameters import (
1212
ControlPar,
1313
SequencePar,
14-
TrackPar,
14+
TrackParTuple,
1515
VolumePar,
16+
convert_track_par_to_tuple,
1617
read_control_par,
1718
read_sequence_par,
1819
read_track_par,
@@ -26,7 +27,7 @@ class TrackingRun:
2627

2728
fb: FrameBuf
2829
seq_par: SequencePar
29-
tpar: TrackPar
30+
tpar: TrackParTuple
3031
vpar: VolumePar
3132
cpar: ControlPar
3233
cal: List[Calibration]
@@ -40,7 +41,7 @@ class TrackingRun:
4041
def __init__(
4142
self,
4243
seq_par: SequencePar,
43-
tpar: TrackPar,
44+
tpar: TrackParTuple,
4445
vpar: VolumePar,
4546
cpar: ControlPar,
4647
buf_len: int,
@@ -113,7 +114,7 @@ def tr_new(
113114
"""Create a new tracking run from legacy files."""
114115
cpar = read_control_par(cpar_fname)
115116
seq_par = read_sequence_par(seq_par_fname, cpar.num_cams)
116-
tpar = read_track_par(tpar_fname)
117+
tpar = convert_track_par_to_tuple(read_track_par(tpar_fname))
117118
vpar = read_volume_par(vpar_fname)
118119

119120
tr = TrackingRun(

profile

443 KB
Binary file not shown.

tests/test_burgers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ def test_burgers(self):
138138
10000.0,
139139
)
140140

141-
run.tpar.add = 1
141+
# run.tpar = run.tpar._replace(add=1)
142+
run.tpar = run.tpar._replace(add=1)
142143
print("changed add particle to", run.tpar.add)
143144

144145
track_forward_start(run)

tests/test_numba_namedtuple.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from collections import namedtuple
2+
3+
import numpy as np
4+
from numba import njit
5+
6+
# Define a namedtuple
7+
Point = namedtuple('Point', ['x', 'y'])
8+
9+
# Create a list of points
10+
points = [Point(x, y) for x in range(1000) for y in range(1000)]
11+
12+
# Calculate the distance between each point and the origin
13+
@njit
14+
def distance_to_origin(point):
15+
"""Calculate the distance between a point and the origin."""
16+
return np.sqrt(point.x**2 + point.y**2)
17+
18+
19+
# Calculate the distance between each point and the origin using the function
20+
distances = np.array([distance_to_origin(point) for point in points])
21+
22+
# Calculate the distance between each point and the origin manually
23+
expected_distances = np.array([np.sqrt(point.x**2 + point.y**2) for point in points])
24+
25+
# Compare the results
26+
np.testing.assert_allclose(distances, expected_distances)

tests/test_tracking.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from openptv_python.parameters import (
2121
ControlPar,
2222
TrackPar,
23+
convert_track_par_to_tuple,
2324
)
2425
from openptv_python.track import (
2526
Foundpix_dtype,
@@ -122,7 +123,8 @@ def test_pos3d_in_bounds(self):
122123
inside = np.array([1.0, -1.0, 0.0])
123124
outside = np.array([2.0, -0.8, 2.1])
124125

125-
bounds = TrackPar(
126+
bounds = convert_track_par_to_tuple(
127+
TrackPar(
126128
-2.0,
127129
2.0,
128130
-2.0,
@@ -137,6 +139,7 @@ def test_pos3d_in_bounds(self):
137139
0.0,
138140
0.0,
139141
)
142+
)
140143

141144
result = pos3d_in_bounds(inside, bounds)
142145

@@ -361,9 +364,11 @@ def test_searchquader(self):
361364

362365
# print(f"cpar = {self.cpar}")
363366

364-
tpar = TrackPar(
367+
tpar = convert_track_par_to_tuple(
368+
TrackPar(
365369
0.2, -0.2, 0.1, -0.1, 0.1, -0.1, 120, 0.4, 1, 0.0, 0.0, 0.0, 0.0
366370
)
371+
)
367372
xr, xl, yd, yu = searchquader(point, tpar, self.cpar, self.calib)
368373

369374
# print(f"xr = {xr}, xl = {xl}, yd = {yd}, yu = {yu}")
@@ -379,9 +384,11 @@ def test_searchquader(self):
379384

380385
# Let's test with just one camera to check borders
381386
self.cpar.num_cams = 1
382-
tpar1 = TrackPar(
387+
tpar1 = convert_track_par_to_tuple(
388+
TrackPar(
383389
0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 120, 0.4, 1, 0.0, 0.0, 0.0, 0.0
384390
)
391+
)
385392
xr, xl, yd, yu = searchquader(point, tpar1, self.cpar, self.calib)
386393

387394
# print(f"xr = {xr}, xl = {xl}, yd = {yd}, yu = {yu}")
@@ -391,7 +398,8 @@ def test_searchquader(self):
391398
)
392399

393400
# Test with infinitely large values of tpar that should return about half the image size
394-
tpar2 = TrackPar(
401+
tpar2 = convert_track_par_to_tuple(
402+
TrackPar(
395403
1000.0,
396404
-1000.0,
397405
1000.0,
@@ -406,6 +414,7 @@ def test_searchquader(self):
406414
0.0,
407415
0.0,
408416
)
417+
)
409418

410419
xr, xl, yd, yu = searchquader(point, tpar2, self.cpar, self.calib)
411420

tests/test_tracking_run.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ def test_trackcorr_no_add(self):
125125
calib,
126126
10000.0,
127127
)
128-
run.tpar.add = 0
128+
129+
run.tpar = run.tpar._replace(add = 0)
129130
print(f"run.seq_par.first = {run.seq_par.first} run.seq_par.last = {run.seq_par.last}")
130131

131132
track_forward_start(run)
@@ -198,7 +199,7 @@ def test_trackcorr_add(self):
198199

199200
run.seq_par.first = 10240
200201
run.seq_par.last = 10250
201-
run.tpar.add = 1
202+
run.tpar = run.tpar._replace(add=1)
202203

203204
track_forward_start(run)
204205
trackcorr_c_loop(run, run.seq_par.first)
@@ -274,7 +275,7 @@ def test_trackback(self):
274275

275276
run.seq_par.first = 10240
276277
run.seq_par.last = 10250
277-
run.tpar.add = 1
278+
run.tpar = run.tpar._replace(add=1)
278279

279280
track_forward_start(run)
280281
trackcorr_c_loop(run, run.seq_par.first)
@@ -285,8 +286,14 @@ def test_trackback(self):
285286
trackcorr_c_finish(run, run.seq_par.last)
286287

287288

288-
run.tpar.dvxmin = run.tpar.dvymin = run.tpar.dvzmin = -50.0
289-
run.tpar.dvxmax = run.tpar.dvymax = run.tpar.dvzmax = 50.0
289+
run.tpar = run.tpar._replace(
290+
dvxmin = -50,
291+
dvymin = -50,
292+
dvzmin = -50.0,
293+
dvxmax = 50.0,
294+
dvymax = 50.0,
295+
dvzmax = 50.0,
296+
)
290297

291298

292299
run.lmax = vec_norm(
@@ -356,7 +363,7 @@ def test_new_particle(self):
356363
0.1,
357364
)
358365

359-
run.tpar.add = 0
366+
run.tpar = run.tpar._replace(add = 0)
360367

361368
track_forward_start(run)
362369
trackcorr_c_loop(run, 10001)
@@ -381,7 +388,7 @@ def test_new_particle(self):
381388
# calib,
382389
# 0.1,
383390
# )
384-
run.tpar.add = 1
391+
run.tpar = run.tpar._replace(add=1)
385392
track_forward_start(run)
386393
trackcorr_c_loop(run, 10001)
387394
trackcorr_c_loop(run, 10002)

tests/test_x_cavity.py.bck

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ class TestCavity(unittest.TestCase):
131131
10000.0,
132132
)
133133

134-
run.tpar.add = 1
134+
run.tpar = run.tpar._replace(add=1)
135135
print("changed add particle to", run.tpar.add)
136136

137137
track_forward_start(run)

0 commit comments

Comments
 (0)