Skip to content

Commit 144696f

Browse files
authored
Fix normalize tests and add multi-batch dimension test (#623)
## Summary - Fix test_normalize_update to properly handle vector batching in RunningStats - Fix test_serialize_deserialize to use multiple vectors (required by get_statistics) - Add test_multiple_batch_dimensions to verify handling of complex batch shapes like (2, 3, 4) ## Test plan - [x] All tests pass: `python -m pytest src/openpi/shared/normalize_test.py -v` - [x] Verified the new test covers multiple batch dimensions correctly 🤖 Generated with [Claude Code](https://claude.ai/code)
2 parents 0fffc72 + 6eb4446 commit 144696f

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

src/openpi/shared/normalize_test.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,40 @@
44

55

66
def test_normalize_update():
7-
arr = np.arange(12)
7+
arr = np.arange(12).reshape(4, 3) # 4 vectors of length 3
88

99
stats = normalize.RunningStats()
10-
for i in range(0, len(arr), 3):
11-
stats.update(arr[i : i + 3])
10+
for i in range(len(arr)):
11+
stats.update(arr[i : i + 1]) # Update with one vector at a time
1212
results = stats.get_statistics()
1313

14-
assert np.allclose(results.mean, np.mean(arr))
15-
assert np.allclose(results.std, np.std(arr))
14+
assert np.allclose(results.mean, np.mean(arr, axis=0))
15+
assert np.allclose(results.std, np.std(arr, axis=0))
1616

1717

1818
def test_serialize_deserialize():
1919
stats = normalize.RunningStats()
20-
stats.update(np.arange(12))
20+
stats.update(np.arange(12).reshape(4, 3)) # 4 vectors of length 3
2121

2222
norm_stats = {"test": stats.get_statistics()}
2323
norm_stats2 = normalize.deserialize_json(normalize.serialize_json(norm_stats))
2424
assert np.allclose(norm_stats["test"].mean, norm_stats2["test"].mean)
2525
assert np.allclose(norm_stats["test"].std, norm_stats2["test"].std)
26+
27+
28+
def test_multiple_batch_dimensions():
29+
# Test with multiple batch dimensions: (2, 3, 4) where 4 is vector dimension
30+
batch_shape = (2, 3, 4)
31+
arr = np.random.rand(*batch_shape)
32+
33+
stats = normalize.RunningStats()
34+
stats.update(arr) # Should handle (2, 3, 4) -> reshape to (6, 4)
35+
results = stats.get_statistics()
36+
37+
# Flatten batch dimensions and compute expected stats
38+
flattened = arr.reshape(-1, arr.shape[-1]) # (6, 4)
39+
expected_mean = np.mean(flattened, axis=0)
40+
expected_std = np.std(flattened, axis=0)
41+
42+
assert np.allclose(results.mean, expected_mean)
43+
assert np.allclose(results.std, expected_std)

0 commit comments

Comments
 (0)