Skip to content

Commit 1abe173

Browse files
committed
feat: allow skipping video stats when saving ep
1 parent 0d359bf commit 1abe173

File tree

2 files changed

+47
-20
lines changed

2 files changed

+47
-20
lines changed

src/opentau/datasets/compute_stats.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -192,40 +192,61 @@ def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[st
192192
}
193193

194194

195-
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
195+
def compute_episode_stats(
196+
episode_data: dict[str, list[str] | np.ndarray],
197+
features: dict,
198+
skip_video_stats: bool = False,
199+
) -> dict:
196200
"""Compute statistics for a single episode.
197201
198-
For image/video features, samples and downsamples images before computing stats.
202+
For image/video features, samples and downsamples images before computing stats
203+
(unless skip_video_stats is True, in which case placeholder stats are used).
199204
For other features, computes stats directly on the array data.
200205
201206
Args:
202207
episode_data: Dictionary mapping feature names to their data (arrays or image paths).
203208
features: Dictionary of feature specifications with 'dtype' keys.
209+
skip_video_stats: If True, do not compute real stats for image/video features;
210+
instead use placeholder stats (min=0, max=1, mean=0.5, std=0.5, count from data)
211+
so the output format remains valid.
204212
205213
Returns:
206214
Dictionary mapping feature names to their statistics (min, max, mean, std, count).
207-
Image statistics are normalized to [0, 1] range.
215+
Image statistics are normalized to [0, 1] range (or placeholders when skip_video_stats).
208216
"""
209217
ep_stats = {}
210218
for key, data in episode_data.items():
211219
if features[key]["dtype"] == "string":
212220
continue # HACK: we should receive np.arrays of strings
213221
elif features[key]["dtype"] in ["image", "video"]:
214-
ep_ft_array = sample_images(data) # data is a list of image paths
215-
axes_to_reduce = (0, 2, 3) # keep channel dim
216-
keepdims = True
222+
if skip_video_stats:
223+
# Placeholder stats: shape (3, 1, 1) for min/max/mean/std, count from length
224+
n_frames = len(data) if isinstance(data, list) else data.shape[0]
225+
shape = features[key]["shape"]
226+
# Expected shape for video is (C, H, W) e.g. (3, H, W)
227+
c = shape[0] if len(shape) >= 3 else 3
228+
ep_stats[key] = {
229+
"min": np.zeros((c, 1, 1), dtype=np.float64),
230+
"max": np.ones((c, 1, 1), dtype=np.float64),
231+
"mean": np.full((c, 1, 1), 0.5, dtype=np.float64),
232+
"std": np.full((c, 1, 1), 0.5, dtype=np.float64),
233+
"count": np.array([n_frames]),
234+
}
235+
else:
236+
image_paths = data.tolist() if isinstance(data, np.ndarray) else data
237+
ep_ft_array = sample_images(image_paths) # image_paths is list[str]
238+
axes_to_reduce = (0, 2, 3) # keep channel dim
239+
keepdims = True
240+
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
241+
# normalize and remove batch dim for images
242+
ep_stats[key] = {
243+
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
244+
}
217245
else:
218-
ep_ft_array = data # data is already a np.ndarray
219-
axes_to_reduce = 0 # compute stats over the first axis
220-
keepdims = data.ndim == 1 # keep as np.array
221-
222-
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
223-
224-
# finally, we normalize and remove batch dim for images
225-
if features[key]["dtype"] in ["image", "video"]:
226-
ep_stats[key] = {
227-
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
228-
}
246+
ep_ft_array = data if isinstance(data, np.ndarray) else np.asarray(data)
247+
axes_to_reduce = (0,) # compute stats over the first axis
248+
keepdims = ep_ft_array.ndim == 1 # keep as np.array
249+
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
229250

230251
return ep_stats
231252

src/opentau/datasets/lerobot_dataset.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1670,7 +1670,9 @@ def save_episode(self, episode_data: dict | None = None) -> None:
16701670

16711671
self._wait_image_writer()
16721672
self._save_episode_table(episode_buffer, episode_index)
1673-
ep_stats = compute_episode_stats(episode_buffer, self.features)
1673+
ep_stats = compute_episode_stats(
1674+
episode_buffer, self.features, skip_video_stats=getattr(self, "skip_video_stats", False)
1675+
)
16741676

16751677
if len(self.meta.video_keys) > 0:
16761678
video_paths = self.encode_episode_videos(episode_index)
@@ -1682,9 +1684,11 @@ def save_episode(self, episode_data: dict | None = None) -> None:
16821684

16831685
ep_data_index, _ = get_episode_data_index(self.meta.episodes, [episode_index])
16841686
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
1687+
timestamps = np.asarray(episode_buffer["timestamp"]).reshape(-1)
1688+
episode_indices = np.full(episode_length, episode_index)
16851689
check_timestamps_sync(
1686-
episode_buffer["timestamp"],
1687-
episode_buffer["episode_index"],
1690+
timestamps,
1691+
episode_indices,
16881692
ep_data_index_np,
16891693
self.fps,
16901694
self.tolerance_s,
@@ -1870,6 +1874,7 @@ def create(
18701874
image_resample_strategy: str = "nearest",
18711875
vector_resample_strategy: str = "nearest",
18721876
standardize: bool = True,
1877+
skip_video_stats: bool = False,
18731878
) -> "LeRobotDataset":
18741879
"""Create a LeRobot Dataset from scratch in order to record data."""
18751880
obj = cls.__new__(cls)
@@ -1903,5 +1908,6 @@ def create(
19031908
obj.image_resample_strategy = image_resample_strategy
19041909
obj.vector_resample_strategy = vector_resample_strategy
19051910
obj.standardize = standardize
1911+
obj.skip_video_stats = skip_video_stats
19061912
obj.episode_data_index, obj.epi2idx = get_episode_data_index(obj.meta.episodes, obj.episodes)
19071913
return obj

0 commit comments

Comments
 (0)