Skip to content

Commit 35b4a45

Browse files
laramielcopybara-github
authored andcommitted
add a test demonstrating ts.Batch
PiperOrigin-RevId: 826537364 Change-Id: I59509574e999571278f6f556e6f9c9bf2d9effc4
1 parent 6d4d140 commit 35b4a45

File tree

2 files changed

+85
-0
lines changed

2 files changed

+85
-0
lines changed

python/tensorstore/BUILD

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,17 @@ tensorstore_pytest_test(
797797
],
798798
)
799799

800+
tensorstore_pytest_test(
801+
name = "batch_test",
802+
size = "small",
803+
srcs = ["tests/batch_test.py"],
804+
deps = [
805+
":conftest",
806+
":tensorstore",
807+
"@pypa_numpy//:numpy",
808+
],
809+
)
810+
800811
tensorstore_pytest_test(
801812
name = "future_test",
802813
size = "medium",
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright 2025 The TensorStore Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for tensorstore.Batch."""
15+
16+
import numpy as np
17+
import pytest
18+
import tensorstore as ts
19+
20+
21+
async def _get_store() -> ts.TensorStore:
22+
store = await ts.open(
23+
{
24+
'driver': 'zarr',
25+
'kvstore': 'memory://',
26+
},
27+
dtype=ts.uint32,
28+
shape=[10, 20],
29+
chunk_layout=ts.ChunkLayout(read_chunk_shape=[5, 10]),
30+
create=True,
31+
delete_existing=True,
32+
)
33+
await store.write(np.arange(200, dtype=np.uint32).reshape((10, 20)))
34+
return store
35+
36+
37+
async def test_batch_submit():
38+
store = await _get_store()
39+
batch = ts.Batch()
40+
f1 = store[:5, :10].read(batch=batch)
41+
f2 = store[5:, 10:].read(batch=batch)
42+
batch.submit()
43+
a1 = await f1
44+
a2 = await f2
45+
np.testing.assert_array_equal(
46+
np.arange(200, dtype=np.uint32).reshape((10, 20))[:5, :10], a1
47+
)
48+
np.testing.assert_array_equal(
49+
np.arange(200, dtype=np.uint32).reshape((10, 20))[5:, 10:], a2
50+
)
51+
52+
53+
async def test_batch_del():
54+
store = await _get_store()
55+
batch = ts.Batch()
56+
f1 = store[:5, :10].read(batch=batch)
57+
f2 = store[5:, 10:].read(batch=batch)
58+
del batch
59+
a1 = await f1
60+
a2 = await f2
61+
np.testing.assert_array_equal(
62+
np.arange(200, dtype=np.uint32).reshape((10, 20))[:5, :10], a1
63+
)
64+
np.testing.assert_array_equal(
65+
np.arange(200, dtype=np.uint32).reshape((10, 20))[5:, 10:], a2
66+
)
67+
68+
69+
async def test_batch_submitted_error():
70+
store = await _get_store()
71+
batch = ts.Batch()
72+
batch.submit()
73+
with pytest.raises(ValueError, match='batch was already submitted'):
74+
store[:5, :10].read(batch=batch)

0 commit comments

Comments
 (0)