Skip to content

Commit 92a0f35

Browse files
committed
Merge branch 'idw_interpolation'
2 parents 373e787 + be03754 commit 92a0f35

File tree

7 files changed

+414
-11
lines changed

7 files changed

+414
-11
lines changed

bump_version.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,31 @@ def bump_version(version_type):
5353
new_version = f"{major}.{minor}.{patch}"
5454
print(f"New version: {new_version}")
5555

56-
# Replace the version in the file
56+
# Replace the version in __init__.py
5757
new_content = re.sub(
5858
r"__version__\s*=\s*['\"]([^'\"]*)['\"]",
5959
f"__version__ = '{new_version}'",
6060
content
6161
)
62-
63-
# Write the updated content back to the file
6462
init_file.write_text(new_content)
65-
print(f"Version bumped to {new_version}")
63+
print(f"Updated flowtracks/__init__.py to version {new_version}")
64+
65+
# Also update version in pyproject.toml
66+
pyproject_file = Path('pyproject.toml')
67+
if pyproject_file.exists():
68+
pyproject_content = pyproject_file.read_text()
69+
if '[project]' in pyproject_content:
70+
pyproject_new = re.sub(
71+
r'version\s*=\s*["\"][^"\"]*["\"]',
72+
f'version = "{new_version}"',
73+
pyproject_content
74+
)
75+
pyproject_file.write_text(pyproject_new)
76+
print(f"Updated pyproject.toml to version {new_version}")
77+
else:
78+
print("Warning: No [project] table found in pyproject.toml, version not updated there.")
79+
else:
80+
print("Warning: pyproject.toml not found, version not updated there.")
6681

6782
if __name__ == "__main__":
6883
if len(sys.argv) != 2:

flowtracks/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# -*- coding: utf-8 -*-
22

3-
__version__ = '1.1.0'
3+
__version__ = '1.1.1'
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Boilerplate for 3D-PTV trajectory post-processing
4+
using xarray + dask + zarr
5+
6+
Features:
7+
- Ragged array encoding for trajectories of different lengths
8+
- Vector-style storage (obs, component)
9+
- Spline smoothing and derivative calculation (velocity, acceleration)
10+
- Resampling onto a uniform time base
11+
- Streaming mode: append processed trajectories to a Zarr store
12+
"""
13+
14+
import numpy as np
15+
import xarray as xr
16+
from scipy.interpolate import UnivariateSpline
17+
import zarr
18+
19+
# ---------------------------------------------------------------------
20+
# 1. Example: Ragged array encoding for variable-length trajectories
21+
# ---------------------------------------------------------------------
22+
23+
def build_ragged_example():
    """Create a demo ragged dataset with two trajectories of unequal length.

    Both trajectories are concatenated along a shared "obs" dimension and
    labelled by a per-observation integer trajectory id, which is the
    standard ragged-array encoding for variable-length track data.
    """
    # Trajectory 0: three samples.
    t0 = np.array([0, 1, 2])
    p0 = np.stack([np.array([0, 1, 2]),
                   np.array([0, 1, 4]),
                   np.array([0, 0, 0])], axis=-1)

    # Trajectory 1: five samples.
    t1 = np.array([0, 2, 4, 6, 8])
    p1 = np.stack([np.array([0, 2, 4, 6, 8]),
                   np.array([0, -1, -2, -3, -4]),
                   np.array([0, 1, 0, -1, 0])], axis=-1)

    # Concatenate into the ragged structure.
    times = np.concatenate([t0, t1])
    positions = np.vstack([p0, p1])
    trajectory_id = np.concatenate([np.full(t0.shape, 0),
                                    np.full(t1.shape, 1)])

    return xr.Dataset(
        {
            "t": ("obs", times),
            "pos": (("obs", "component"), positions),
            "trajectory": ("obs", trajectory_id),
        },
        coords={"component": ["x", "y", "z"], "obs": np.arange(len(times))},
    )
55+
56+
57+
# ---------------------------------------------------------------------
58+
# 2. Compute derivatives (velocity, acceleration) in ragged array
59+
# ---------------------------------------------------------------------
60+
61+
def compute_derivatives_ragged(ds):
    """Add velocity and acceleration to a ragged trajectory dataset.

    Arguments:
    ds - xarray.Dataset with variables "t" (obs,), "pos" (obs, component)
        and a per-observation "trajectory" id, as built by
        build_ragged_example(). Each trajectory is differentiated
        independently.

    Returns:
    A merged dataset with additional "vel" and "acc" variables of shape
    (obs, component).
    """
    def _derivs(sub):
        t = sub.t.values
        # Pass the sample times directly to np.gradient so trajectories
        # with non-uniform time steps get the correct second-order finite
        # differences. (Computing a unit-spacing gradient and dividing by
        # np.gradient(t) is only valid when the time steps are uniform.)
        vel = np.gradient(sub.pos.values, t, axis=0)
        acc = np.gradient(vel, t, axis=0)
        return xr.Dataset({
            "vel": (("obs", "component"), vel),
            "acc": (("obs", "component"), acc),
        })

    derivs = ds.groupby("trajectory").map(_derivs)
    return xr.merge([ds, derivs])
74+
75+
76+
# ---------------------------------------------------------------------
77+
# 3. Spline smoothing + resampling on uniform time base
78+
# ---------------------------------------------------------------------
79+
80+
def smooth_and_resample(t, pos, t_uniform, s=0.0):
    """Smooth a trajectory with a spline and resample to a uniform time base.

    Arguments:
    t - (n,) array of sample times, strictly increasing.
    pos - (n, d) array of positions, one column per spatial component.
    t_uniform - (m,) array, the uniform time base to resample onto.
    s - spline smoothing factor passed to UnivariateSpline (0 = interpolate).

    Returns:
    pos_u, vel_u, acc_u - (m, d) arrays with the resampled position and its
    first and second time derivatives.

    Raises:
    ValueError - if fewer than 2 observations are supplied.
    """
    # UnivariateSpline requires more data points than the spline degree
    # (n > k), so degrade gracefully from the default cubic for short
    # trajectories instead of raising on 3-point (or shorter) tracks.
    k = min(3, len(t) - 1)
    if k < 1:
        raise ValueError("smooth_and_resample needs at least 2 observations")

    comps, vels, accs = [], [], []
    for d in range(pos.shape[1]):  # loop over x,y,z
        spline = UnivariateSpline(t, pos[:, d], s=s, k=k)
        comps.append(spline(t_uniform))
        vels.append(spline.derivative(1)(t_uniform))
        if k >= 2:
            accs.append(spline.derivative(2)(t_uniform))
        else:
            # A degree-1 spline has no second derivative; report zero
            # acceleration for 2-point trajectories.
            accs.append(np.zeros_like(t_uniform, dtype=float))

    pos_u = np.stack(comps, axis=-1)
    vel_u = np.stack(vels, axis=-1)
    acc_u = np.stack(accs, axis=-1)
    return pos_u, vel_u, acc_u
100+
101+
102+
# ---------------------------------------------------------------------
103+
# 4. Streaming mode: append processed trajectories to Zarr
104+
# ---------------------------------------------------------------------
105+
106+
def init_zarr_store(store_path, t_uniform):
    """Create an empty Zarr store laid out for streaming trajectories.

    The store holds position/velocity/acceleration cubes indexed by
    (trajectory, time, component). The trajectory axis starts with
    length zero and is grown by append_to_zarr().
    """
    components = ["x", "y", "z"]
    dims = ("trajectory", "time", "component")

    def _empty():
        # Zero trajectories to start with; time/component sized for good.
        return np.empty((0, len(t_uniform), len(components)))

    template = xr.Dataset(
        data_vars={name: (dims, _empty())
                   for name in ("position", "velocity", "acceleration")},
        coords={
            "time": t_uniform,
            "component": components,
            "trajectory": [],
        },
    )
    template.to_zarr(store_path, mode="w")
125+
126+
127+
def append_to_zarr(store_path, traj_id, t, pos, t_uniform, s=0.0):
    """Smooth and resample one trajectory, then append it to the Zarr store.

    The processed trajectory is written as a length-1 slab along the
    "trajectory" dimension of a store prepared by init_zarr_store().
    """
    pos_u, vel_u, acc_u = smooth_and_resample(t, pos, t_uniform, s=s)
    dims = ("trajectory", "time", "component")

    batch = xr.Dataset(
        {
            "position": (dims, pos_u[np.newaxis, ...]),
            "velocity": (dims, vel_u[np.newaxis, ...]),
            "acceleration": (dims, acc_u[np.newaxis, ...]),
        },
        coords={
            "trajectory": [traj_id],
            "time": t_uniform,
            "component": ["x", "y", "z"],
        },
    )
    batch.to_zarr(store_path, mode="a", append_dim="trajectory")
143+
144+
145+
# ---------------------------------------------------------------------
146+
# 5. Example usage
147+
# ---------------------------------------------------------------------
148+
149+
if __name__ == "__main__":
    # 1. Build the ragged demo dataset and show it.
    ds_ragged = build_ragged_example()
    print("Ragged dataset:")
    print(ds_ragged)

    # 2. Per-trajectory velocity/acceleration on the ragged encoding.
    ds_with_derivs = compute_derivatives_ragged(ds_ragged)
    print("\nRagged dataset with velocity and acceleration:")
    print(ds_with_derivs)

    # 3+4. Stream both trajectories, smoothed onto a uniform time base,
    # into a Zarr store.
    t_uniform = np.linspace(0, 8, 81)  # uniform time base (0.1s step)
    store = "trajectories.zarr"
    init_zarr_store(store, t_uniform)

    for tid in (0, 1):
        obs = ds_ragged.where(ds_ragged.trajectory == tid, drop=True)
        append_to_zarr(store, traj_id=tid, t=obs.t.values,
                       pos=obs.pos.values, t_uniform=t_uniform, s=0.1)

    # Open the finished store lazily with xarray (dask-backed).
    ds_zarr = xr.open_zarr(store)
    print("\nZarr dataset (streamed trajectories):")
    print(ds_zarr)

flowtracks/interpolation.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -640,13 +640,17 @@ def weights(self, dists, use_parts, unused_marker=None):
640640
interpolation point 1..m.
641641
642642
Returns:
643-
weights - an (m,n) array.
643+
weights - an (m,k) array.
644644
"""
645-
if unused_marker is None:
646-
unused_marker = self.field_positions().shape[0]
647-
648645
weights = dists**-self._par
649-
weights[use_parts == unused_marker] = 0.
646+
# If use_parts is boolean, just mask out non-neighbors
647+
if use_parts.dtype == bool:
648+
weights[~use_parts] = 0.
649+
else:
650+
# If use_parts is indices, mask out invalid indices
651+
if unused_marker is None:
652+
unused_marker = dists.shape[1]
653+
weights[use_parts == unused_marker] = 0.
650654
return weights
651655

652656
def set_scene(self, tracer_pos, interp_points,
@@ -712,7 +716,42 @@ def __call__(self, tracer_pos, interp_points, data, companionship=None):
712716
vel_interp - an (m,3) array with the interpolated value at the position
713717
of each particle, [m/s].
714718
"""
715-
raise NotImplementedError()
719+
if len(tracer_pos) == 0:
720+
warnings.warn("No tracers in frame, interpolation returned zeros.")
721+
ret_shape = data.shape[-1] if data.ndim > 1 else 1
722+
return np.zeros((interp_points.shape[0], ret_shape))
723+
724+
dists, use_parts = select_neighbs(tracer_pos, interp_points,
725+
self._radius, self._neighbs,
726+
companionship)
727+
728+
m, n = dists.shape
729+
if data.ndim == 1:
730+
data = data[:, None]
731+
matched_data = np.zeros((m, n, data.shape[1]))
732+
for i in range(m):
733+
matched_data[i] = data
734+
735+
# Handle exact matches: if any distance is zero, set result to that data value
736+
exact_match = (dists == 0)
737+
has_exact = exact_match.any(axis=1)
738+
vel_interp = np.empty((m, data.shape[1]), dtype=data.dtype)
739+
weights = self.weights(dists, use_parts)
740+
for i in range(m):
741+
if has_exact[i]:
742+
# Use the first exact match (should only be one)
743+
idx = np.where(exact_match[i])[0][0]
744+
vel_interp[i] = matched_data[i, idx]
745+
else:
746+
sum_weights = weights[i].sum()
747+
if sum_weights == 0:
748+
vel_interp[i] = 0
749+
else:
750+
vel_interp[i] = (weights[i][:, None] * matched_data[i]).sum(axis=0) / sum_weights
751+
# Always return a 2D array (m, d) even for single-point or 1D data
752+
if vel_interp.ndim == 1:
753+
vel_interp = vel_interp[:, None]
754+
return vel_interp
716755

717756
def _meth_interp(self, act_neighbs=None):
718757
"""

pyproject.toml

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,33 @@
11
[build-system]
22
requires = ["setuptools>=42", "wheel"]
33
build-backend = "setuptools.build_meta"
4+
5+
[project]
6+
name = "flowtracks"
7+
version = "1.1.1"
8+
description = "Library for handling of PTV trajectory database."
9+
readme = "README.md"
10+
requires-python = ">=3.9"
11+
authors = [{ name = "Yosef Meller", email = "yosefm@gmail.com" }]
12+
license = { file = "LICENSE.txt" }
13+
dependencies = [
14+
"numpy",
15+
"scipy",
16+
"tables"
17+
]
18+
classifiers = [
19+
"Development Status :: 5 - Production/Stable",
20+
"Intended Audience :: Science/Research",
21+
"License :: OSI Approved :: BSD License",
22+
"Programming Language :: Python :: 3",
23+
"Programming Language :: Python :: 3.12",
24+
"Programming Language :: Python :: 3.9",
25+
"Programming Language :: Python :: 3.10",
26+
"Programming Language :: Python :: 3.11",
27+
"Topic :: Scientific/Engineering"
28+
]
29+
30+
[dependency-groups]
31+
dev = [
32+
"pytest>=8.4.2",
33+
]

tests/test_idw_call.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import unittest
2+
import numpy as np
3+
from flowtracks import interpolation
4+
5+
class TestIDWCall(unittest.TestCase):
    """Exercise InverseDistanceWeighter.__call__ on a small 3D scene."""

    def setUp(self):
        # Four tracers on the corners of the unit square in the z = 0 plane,
        # so a point at the centre is equidistant from all of them.
        self.tracer_pos = np.array([
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [1.0, 1.0, 0.0],
        ])
        self.data = np.array([
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 1.0, 1.0],
        ])
        self.interp_points = np.array([
            [0.5, 0.5, 0.0],
        ])

    def test_basic_idw(self):
        idw = interpolation.InverseDistanceWeighter(num_neighbs=4, param=1)
        result = idw(self.tracer_pos, self.interp_points, self.data)
        self.assertEqual(result.shape, (1, 3))
        # A weighted average of the data must stay within the data's range.
        self.assertTrue(np.all(result >= 0))
        self.assertTrue(np.all(result <= 1))

    def test_empty_tracers(self):
        # An empty frame should warn and yield zeros of the right shape.
        idw = interpolation.InverseDistanceWeighter(num_neighbs=2, param=1)
        with self.assertWarns(UserWarning):
            result = idw(np.empty((0, 3)), self.interp_points,
                         np.empty((0, 3)))
        self.assertEqual(result.shape, (1, 3))
        self.assertTrue((result == 0).all())

    def test_1d_data(self):
        # Scalar per-tracer data comes back as a (m, 1) column.
        idw = interpolation.InverseDistanceWeighter(num_neighbs=4, param=1)
        result = idw(self.tracer_pos, self.interp_points,
                     np.array([1.0, 2.0, 3.0, 4.0]))
        self.assertEqual(result.shape, (1, 1))
        self.assertTrue(np.all(result >= 1))
        self.assertTrue(np.all(result <= 4))

    def test_companionship(self):
        # Excluding the first tracer via companionship must change the answer.
        idw = interpolation.InverseDistanceWeighter(num_neighbs=4, param=1)
        with self.assertWarns(RuntimeWarning):
            result = idw(self.tracer_pos, self.interp_points, self.data,
                         companionship=np.array([0]))
        self.assertEqual(result.shape, (1, 3))
        idw_all = interpolation.InverseDistanceWeighter(num_neighbs=4, param=1)
        result_all = idw_all(self.tracer_pos, self.interp_points, self.data)
        self.assertFalse(np.allclose(result, result_all))

0 commit comments

Comments
 (0)