Skip to content

Commit 7edb969

Browse files
authored
Stream tabular data (#1201)
* Streaming table writes works * Test streaming appends * array-ref needs a shape for the first update * Fix typing on tuples of ints * Test update that does not specify patch. * If no patch is given, slice to the current shape.
1 parent c4b1693 commit 7edb969

File tree

7 files changed

+321
-26
lines changed

7 files changed

+321
-26
lines changed

tiled/_tests/test_subscription.py

Lines changed: 165 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
import uuid
55

66
import numpy as np
7+
import pandas as pd
8+
import pyarrow
79
import pytest
810
import tifffile
11+
from pandas.testing import assert_frame_equal
912
from starlette.testclient import WebSocketDenialResponse
1013

1114
from ..client import from_context
@@ -304,7 +307,8 @@ def callback(sub):
304307
assert event.wait(timeout=5.0), "Timeout waiting for messages"
305308

306309

307-
def test_subscribe_to_array_registered(tiled_websocket_context, tmp_path):
310+
def test_subscribe_to_array_registered_with_patch(tiled_websocket_context, tmp_path):
311+
"Writer specifies the region of the update (patch)."
308312
context = tiled_websocket_context
309313
client = from_context(context)
310314
container_sub = client.subscribe()
@@ -356,7 +360,7 @@ def on_child_created(update):
356360
data_sources=[data_source],
357361
metadata={},
358362
specs=[],
359-
key="test_subscribe_to_array_registered",
363+
key="test_subscribe_to_array_registered_with_patch",
360364
)
361365
actual = x.read() # smoke test
362366
np.testing.assert_array_equal(actual, arr[:2])
@@ -394,3 +398,162 @@ def on_child_created(update):
394398
assert update.patch.extend
395399
actual_streamed = update.data()
396400
np.testing.assert_array_equal(actual_streamed, arr[2:])
401+
402+
403+
def test_subscribe_to_array_registered_without_patch(tiled_websocket_context, tmp_path):
404+
"Writer does not specify the region of the update (patch)."
405+
context = tiled_websocket_context
406+
client = from_context(context)
407+
container_sub = client.subscribe()
408+
409+
updates = []
410+
event = threading.Event()
411+
412+
def on_array_updated(update):
413+
updates.append(update)
414+
event.set()
415+
416+
def on_child_created(update):
417+
array_sub = update.child().subscribe()
418+
array_sub.new_data.add_callback(on_array_updated)
419+
array_sub.start_in_thread(1)
420+
421+
container_sub.child_created.add_callback(on_child_created)
422+
423+
arr = np.random.random((3, 7, 13))
424+
tifffile.imwrite(tmp_path / "image1.tiff", arr[0])
425+
tifffile.imwrite(tmp_path / "image2.tiff", arr[1])
426+
427+
# Register just the first two images.
428+
structure = ArrayStructure.from_array(arr[:2])
429+
data_source = DataSource(
430+
management=Management.external,
431+
mimetype="multipart/related;type=image/tiff",
432+
structure_family=StructureFamily.array,
433+
structure=structure,
434+
assets=[
435+
Asset(
436+
data_uri=f"file://{tmp_path}/image1.tiff",
437+
is_directory=False,
438+
parameter="data_uris",
439+
num=1,
440+
),
441+
Asset(
442+
data_uri=f"file://{tmp_path}/image2.tiff",
443+
is_directory=False,
444+
parameter="data_uris",
445+
num=2,
446+
),
447+
],
448+
)
449+
450+
with container_sub.start_in_thread(1):
451+
x = client.new(
452+
structure_family=StructureFamily.array,
453+
data_sources=[data_source],
454+
metadata={},
455+
specs=[],
456+
key="test_subscribe_to_array_registered_without_patch",
457+
)
458+
actual = x.read() # smoke test
459+
np.testing.assert_array_equal(actual, arr[:2])
460+
# Add the third image.
461+
tifffile.imwrite(tmp_path / "image3.tiff", arr[2])
462+
updated_structure = ArrayStructure.from_array(arr[:])
463+
updated_data_source = copy.deepcopy(x.data_sources()[0])
464+
updated_data_source.structure = updated_structure
465+
updated_data_source.assets.append(
466+
Asset(
467+
data_uri=f"file://{tmp_path}/image3.tiff",
468+
is_directory=False,
469+
parameter="data_uris",
470+
num=3,
471+
),
472+
)
473+
x.context.http_client.put(
474+
x.uri.replace("/metadata/", "/data_source/", 1),
475+
content=safe_json_dump(
476+
{
477+
"data_source": updated_data_source,
478+
}
479+
),
480+
).raise_for_status()
481+
assert event.wait(timeout=5.0), "Timeout waiting for messages"
482+
x.close_stream()
483+
client.close_stream()
484+
x.refresh()
485+
actual_updated = x.read()
486+
np.testing.assert_array_equal(actual_updated, arr[:])
487+
(update,) = updates
488+
assert update.patch is None
489+
actual_streamed = update.data()
490+
np.testing.assert_array_equal(actual_streamed, arr[:])
491+
492+
493+
def test_streaming_table_write(tiled_websocket_context):
494+
context = tiled_websocket_context
495+
client = from_context(context)
496+
updates = []
497+
event = threading.Event()
498+
key = "test_streaming_table_write"
499+
500+
def collect(update):
501+
updates.append(update)
502+
event.set()
503+
504+
df1 = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
505+
df2 = pd.DataFrame({"a": [7, 8, 9], "b": [10, 11, 12]})
506+
x = client.write_table(df1, key=key)
507+
508+
sub = client[key].subscribe()
509+
sub.new_data.add_callback(collect)
510+
with sub.start_in_thread(1):
511+
assert event.wait(timeout=5.0), "Timeout waiting for messages"
512+
actual = updates[0].data()
513+
assert_frame_equal(actual, df1)
514+
event.clear()
515+
x.write(df2)
516+
assert event.wait(timeout=5.0), "Timeout waiting for messages"
517+
assert not updates[1].append
518+
actual_updated = updates[1].data()
519+
assert_frame_equal(actual_updated, df2)
520+
521+
522+
def test_streaming_table_appends(tiled_websocket_context):
523+
context = tiled_websocket_context
524+
client = from_context(context)
525+
updates = []
526+
event = threading.Event()
527+
key = "test_streaming_table_append"
528+
529+
def collect(update):
530+
updates.append(update)
531+
event.set()
532+
533+
df1 = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
534+
df2 = pd.DataFrame({"a": [7, 8, 9], "b": [10, 11, 12]})
535+
table1 = pyarrow.Table.from_pandas(df1, preserve_index=False)
536+
table2 = pyarrow.Table.from_pandas(df2, preserve_index=False)
537+
x = client.create_appendable_table(table1.schema, key=key)
538+
539+
sub = client[key].subscribe()
540+
sub.new_data.add_callback(collect)
541+
with sub.start_in_thread(1):
542+
x.append_partition(0, table1)
543+
assert event.wait(timeout=5.0), "Timeout waiting for messages"
544+
assert updates[0].append
545+
streamed1 = updates[0].data()
546+
streamed1_pyarrow = pyarrow.Table.from_pandas(streamed1, preserve_index=False)
547+
assert streamed1_pyarrow == table1
548+
event.clear()
549+
x.append_partition(0, table2)
550+
assert event.wait(timeout=5.0), "Timeout waiting for messages"
551+
assert updates[1].append
552+
streamed2 = updates[1].data()
553+
streamed2_pyarrow = pyarrow.Table.from_pandas(streamed2, preserve_index=False)
554+
assert streamed2_pyarrow == table2
555+
streaming_combined = pyarrow.concat_tables(
556+
[streamed1_pyarrow, streamed2_pyarrow]
557+
)
558+
expected_combined = pyarrow.concat_tables([table1, table2])
559+
assert streaming_combined == expected_combined

tiled/catalog/adapter.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -879,6 +879,7 @@ async def put_data_source(self, data_source, patch):
879879
"timestamp": datetime.now().isoformat(),
880880
"data_source": data_source.model_dump(),
881881
"patch": patch.model_dump() if patch else None,
882+
"shape": structure["shape"],
882883
}
883884

884885
# Cache data in Redis with a TTL, and publish
@@ -1279,28 +1280,59 @@ class CatalogSparseAdapter(CatalogArrayAdapter):
12791280

12801281

12811282
class CatalogTableAdapter(CatalogNodeAdapter):
1283+
def make_ws_schema(self):
1284+
return {
1285+
"type": "table-schema",
1286+
"version": 1,
1287+
"arrow_schema": self.structure().arrow_schema,
1288+
}
1289+
1290+
async def _stream(self, media_type, entry, body, partition, append):
1291+
sequence = await self.context.streaming_cache.incr_seq(self.node.id)
1292+
metadata = {
1293+
"type": "table-data",
1294+
"sequence": sequence,
1295+
"timestamp": datetime.now().isoformat(),
1296+
"mimetype": media_type,
1297+
"partition": partition,
1298+
"append": append,
1299+
}
1300+
1301+
await self.context.streaming_cache.set(
1302+
self.node.id, sequence, metadata, payload=body
1303+
)
1304+
12821305
async def get(self, *args, **kwargs):
12831306
return (await self.get_adapter()).get(*args, **kwargs)
12841307

12851308
async def read(self, *args, **kwargs):
12861309
return await ensure_awaitable((await self.get_adapter()).read, *args, **kwargs)
12871310

1288-
async def write(self, *args, **kwargs):
1289-
return await ensure_awaitable((await self.get_adapter()).write, *args, **kwargs)
1311+
async def write(self, media_type, deserializer, entry, body):
1312+
if self.context.streaming_cache:
1313+
await self._stream(media_type, entry, body, None, False)
1314+
data = await ensure_awaitable(deserializer, body)
1315+
return await ensure_awaitable((await self.get_adapter()).write, data)
12901316

12911317
async def read_partition(self, *args, **kwargs):
12921318
return await ensure_awaitable(
12931319
(await self.get_adapter()).read_partition, *args, **kwargs
12941320
)
12951321

1296-
async def write_partition(self, *args, **kwargs):
1322+
async def write_partition(self, media_type, deserializer, entry, body, partition):
1323+
if self.context.streaming_cache:
1324+
await self._stream(media_type, entry, body, partition, False)
1325+
data = await ensure_awaitable(deserializer, body)
12971326
return await ensure_awaitable(
1298-
(await self.get_adapter()).write_partition, *args, **kwargs
1327+
(await self.get_adapter()).write_partition, partition, data
12991328
)
13001329

1301-
async def append_partition(self, *args, **kwargs):
1330+
async def append_partition(self, media_type, deserializer, entry, body, partition):
1331+
if self.context.streaming_cache:
1332+
await self._stream(media_type, entry, body, partition, True)
1333+
data = await ensure_awaitable(deserializer, body)
13021334
return await ensure_awaitable(
1303-
(await self.get_adapter()).append_partition, *args, **kwargs
1335+
(await self.get_adapter()).append_partition, partition, data
13041336
)
13051337

13061338

tiled/client/dataframe.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import concurrent.futures
12
import functools
23
import warnings
4+
from typing import TYPE_CHECKING, Optional
35
from urllib.parse import parse_qs, urlparse
46

57
import dask
@@ -18,6 +20,9 @@
1820
retry_context,
1921
)
2022

23+
if TYPE_CHECKING:
24+
from .stream import TableSubscription
25+
2126
_EXTRA_CHARS_PER_ITEM = len("&column=")
2227

2328

@@ -302,6 +307,25 @@ def export(self, filepath, columns=None, *, format=None):
302307
params=params,
303308
)
304309

310+
def subscribe(
311+
self,
312+
executor: Optional[concurrent.futures.Executor] = None,
313+
) -> "TableSubscription":
314+
"""
315+
Subscribe to streaming updates about this table.
316+
317+
Returns
318+
-------
319+
subscription : Subscription
320+
executor : concurrent.futures.Executor, optional
321+
Launches tasks asynchronously, in response to updates. By default,
322+
a concurrent.futures.ThreadPoolExecutor is used.
323+
"""
324+
# Keep this import here to defer the websockets import until/unless needed.
325+
from .stream import TableSubscription
326+
327+
return TableSubscription(self.context, self.path_parts, executor)
328+
305329

306330
# Subclass with a public class that adds the dask-specific methods.
307331

tiled/client/stream.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
ChildMetadataUpdated,
3131
ContainerSchema,
3232
Schema,
33+
TableData,
34+
TableSchema,
3335
Update,
3436
)
3537
from ..structures.core import STRUCTURE_TYPES, StructureFamily
@@ -497,6 +499,36 @@ def process(self, update: Update):
497499
self.new_data.process(update)
498500

499501

502+
class TableSubscription(Subscription):
503+
"""
504+
Subscribe to streaming updates from a table.
505+
506+
Parameters
507+
----------
508+
context : tiled.client.Context
509+
Provides connection to Tiled server
510+
segments : list[str]
511+
Path to node of interest, given as a list of path segments
512+
executor : concurrent.futures.Executor, optional
513+
Launches tasks asynchronously, in response to updates. By default,
514+
a concurrent.futures.ThreadPoolExecutor is used.
515+
"""
516+
517+
def __init__(
518+
self,
519+
context: Context,
520+
segments: List[str] = None,
521+
executor: Optional[concurrent.futures.Executor] = None,
522+
):
523+
super().__init__(context, segments, executor)
524+
self.new_data: CallbackRegistry["LiveTableData"] = CallbackRegistry(
525+
self.executor
526+
)
527+
528+
def process(self, update: Update):
529+
self.new_data.process(update)
530+
531+
500532
class UnparseableMessage(RuntimeError):
501533
"Message can be decoded but cannot be interpreted by the application"
502534
pass
@@ -609,16 +641,38 @@ def data(self):
609641
).read()
610642
# Decode payload (bytes) into array.
611643
numpy_dtype = self.data_type.to_numpy_dtype()
612-
return numpy.frombuffer(content, dtype=numpy_dtype).reshape(self.patch.shape)
644+
if self.patch:
645+
shape = self.patch.shape
646+
else:
647+
shape = self.shape
648+
return numpy.frombuffer(content, dtype=numpy_dtype).reshape(shape)
649+
650+
651+
class LiveTableData(TableData):
652+
model_config = ConfigDict(arbitrary_types_allowed=True)
653+
subscription: TableSubscription
654+
655+
def data(self):
656+
"Get table"
657+
# Registration occurs on import. Ensure this is imported.
658+
from ..serialization import array
659+
660+
del array
661+
662+
# Decode payload (bytes) into array.
663+
deserializer = default_deserialization_registry.dispatch("table", self.mimetype)
664+
return deserializer(self.payload)
613665

614666

615667
SCHEMA_MESSAGE_TYPES = {
616668
"array-schema": ArraySchema,
617669
"container-schema": ContainerSchema,
670+
"table-schema": TableSchema,
618671
}
619672
UPDATE_MESSAGE_TYPES = {
620673
"container-child-created": LiveChildCreated,
621674
"container-child-metadata-updated": LiveChildMetadataUpdated,
622675
"array-data": LiveArrayData,
623676
"array-ref": LiveArrayRef,
677+
"table-data": LiveTableData,
624678
}

0 commit comments

Comments
 (0)