|
4 | 4 | import uuid |
5 | 5 |
|
6 | 6 | import numpy as np |
| 7 | +import pandas as pd |
| 8 | +import pyarrow |
7 | 9 | import pytest |
8 | 10 | import tifffile |
| 11 | +from pandas.testing import assert_frame_equal |
9 | 12 | from starlette.testclient import WebSocketDenialResponse |
10 | 13 |
|
11 | 14 | from ..client import from_context |
@@ -304,7 +307,8 @@ def callback(sub): |
304 | 307 | assert event.wait(timeout=5.0), "Timeout waiting for messages" |
305 | 308 |
|
306 | 309 |
|
307 | | -def test_subscribe_to_array_registered(tiled_websocket_context, tmp_path): |
| 310 | +def test_subscribe_to_array_registered_with_patch(tiled_websocket_context, tmp_path): |
| 311 | + "Writer specifies the region of the update (patch)." |
308 | 312 | context = tiled_websocket_context |
309 | 313 | client = from_context(context) |
310 | 314 | container_sub = client.subscribe() |
@@ -356,7 +360,7 @@ def on_child_created(update): |
356 | 360 | data_sources=[data_source], |
357 | 361 | metadata={}, |
358 | 362 | specs=[], |
359 | | - key="test_subscribe_to_array_registered", |
| 363 | + key="test_subscribe_to_array_registered_with_patch", |
360 | 364 | ) |
361 | 365 | actual = x.read() # smoke test |
362 | 366 | np.testing.assert_array_equal(actual, arr[:2]) |
@@ -394,3 +398,162 @@ def on_child_created(update): |
394 | 398 | assert update.patch.extend |
395 | 399 | actual_streamed = update.data() |
396 | 400 | np.testing.assert_array_equal(actual_streamed, arr[2:]) |
| 401 | + |
| 402 | + |
| 403 | +def test_subscribe_to_array_registered_without_patch(tiled_websocket_context, tmp_path): |
| 404 | + "Writer does not specify the region of the update (patch)." |
| 405 | + context = tiled_websocket_context |
| 406 | + client = from_context(context) |
| 407 | + container_sub = client.subscribe() |
| 408 | + |
| 409 | + updates = [] |
| 410 | + event = threading.Event() |
| 411 | + |
| 412 | + def on_array_updated(update): |
| 413 | + updates.append(update) |
| 414 | + event.set() |
| 415 | + |
| 416 | + def on_child_created(update): |
| 417 | + array_sub = update.child().subscribe() |
| 418 | + array_sub.new_data.add_callback(on_array_updated) |
| 419 | + array_sub.start_in_thread(1) |
| 420 | + |
| 421 | + container_sub.child_created.add_callback(on_child_created) |
| 422 | + |
| 423 | + arr = np.random.random((3, 7, 13)) |
| 424 | + tifffile.imwrite(tmp_path / "image1.tiff", arr[0]) |
| 425 | + tifffile.imwrite(tmp_path / "image2.tiff", arr[1]) |
| 426 | + |
| 427 | + # Register just the first two images. |
| 428 | + structure = ArrayStructure.from_array(arr[:2]) |
| 429 | + data_source = DataSource( |
| 430 | + management=Management.external, |
| 431 | + mimetype="multipart/related;type=image/tiff", |
| 432 | + structure_family=StructureFamily.array, |
| 433 | + structure=structure, |
| 434 | + assets=[ |
| 435 | + Asset( |
| 436 | + data_uri=f"file://{tmp_path}/image1.tiff", |
| 437 | + is_directory=False, |
| 438 | + parameter="data_uris", |
| 439 | + num=1, |
| 440 | + ), |
| 441 | + Asset( |
| 442 | + data_uri=f"file://{tmp_path}/image2.tiff", |
| 443 | + is_directory=False, |
| 444 | + parameter="data_uris", |
| 445 | + num=2, |
| 446 | + ), |
| 447 | + ], |
| 448 | + ) |
| 449 | + |
| 450 | + with container_sub.start_in_thread(1): |
| 451 | + x = client.new( |
| 452 | + structure_family=StructureFamily.array, |
| 453 | + data_sources=[data_source], |
| 454 | + metadata={}, |
| 455 | + specs=[], |
| 456 | + key="test_subscribe_to_array_registered_without_patch", |
| 457 | + ) |
| 458 | + actual = x.read() # smoke test |
| 459 | + np.testing.assert_array_equal(actual, arr[:2]) |
| 460 | + # Add the third image. |
| 461 | + tifffile.imwrite(tmp_path / "image3.tiff", arr[2]) |
| 462 | + updated_structure = ArrayStructure.from_array(arr[:]) |
| 463 | + updated_data_source = copy.deepcopy(x.data_sources()[0]) |
| 464 | + updated_data_source.structure = updated_structure |
| 465 | + updated_data_source.assets.append( |
| 466 | + Asset( |
| 467 | + data_uri=f"file://{tmp_path}/image3.tiff", |
| 468 | + is_directory=False, |
| 469 | + parameter="data_uris", |
| 470 | + num=3, |
| 471 | + ), |
| 472 | + ) |
| 473 | + x.context.http_client.put( |
| 474 | + x.uri.replace("/metadata/", "/data_source/", 1), |
| 475 | + content=safe_json_dump( |
| 476 | + { |
| 477 | + "data_source": updated_data_source, |
| 478 | + } |
| 479 | + ), |
| 480 | + ).raise_for_status() |
| 481 | + assert event.wait(timeout=5.0), "Timeout waiting for messages" |
| 482 | + x.close_stream() |
| 483 | + client.close_stream() |
| 484 | + x.refresh() |
| 485 | + actual_updated = x.read() |
| 486 | + np.testing.assert_array_equal(actual_updated, arr[:]) |
| 487 | + (update,) = updates |
| 488 | + assert update.patch is None |
| 489 | + actual_streamed = update.data() |
| 490 | + np.testing.assert_array_equal(actual_streamed, arr[:]) |
| 491 | + |
| 492 | + |
| 493 | +def test_streaming_table_write(tiled_websocket_context): |
| 494 | + context = tiled_websocket_context |
| 495 | + client = from_context(context) |
| 496 | + updates = [] |
| 497 | + event = threading.Event() |
| 498 | + key = "test_streaming_table_write" |
| 499 | + |
| 500 | + def collect(update): |
| 501 | + updates.append(update) |
| 502 | + event.set() |
| 503 | + |
| 504 | + df1 = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) |
| 505 | + df2 = pd.DataFrame({"a": [7, 8, 9], "b": [10, 11, 12]}) |
| 506 | + x = client.write_table(df1, key=key) |
| 507 | + |
| 508 | + sub = client[key].subscribe() |
| 509 | + sub.new_data.add_callback(collect) |
| 510 | + with sub.start_in_thread(1): |
| 511 | + assert event.wait(timeout=5.0), "Timeout waiting for messages" |
| 512 | + actual = updates[0].data() |
| 513 | + assert_frame_equal(actual, df1) |
| 514 | + event.clear() |
| 515 | + x.write(df2) |
| 516 | + assert event.wait(timeout=5.0), "Timeout waiting for messages" |
| 517 | + assert not updates[1].append |
| 518 | + actual_updated = updates[1].data() |
| 519 | + assert_frame_equal(actual_updated, df2) |
| 520 | + |
| 521 | + |
| 522 | +def test_streaming_table_appends(tiled_websocket_context): |
| 523 | + context = tiled_websocket_context |
| 524 | + client = from_context(context) |
| 525 | + updates = [] |
| 526 | + event = threading.Event() |
| 527 | + key = "test_streaming_table_append" |
| 528 | + |
| 529 | + def collect(update): |
| 530 | + updates.append(update) |
| 531 | + event.set() |
| 532 | + |
| 533 | + df1 = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) |
| 534 | + df2 = pd.DataFrame({"a": [7, 8, 9], "b": [10, 11, 12]}) |
| 535 | + table1 = pyarrow.Table.from_pandas(df1, preserve_index=False) |
| 536 | + table2 = pyarrow.Table.from_pandas(df2, preserve_index=False) |
| 537 | + x = client.create_appendable_table(table1.schema, key=key) |
| 538 | + |
| 539 | + sub = client[key].subscribe() |
| 540 | + sub.new_data.add_callback(collect) |
| 541 | + with sub.start_in_thread(1): |
| 542 | + x.append_partition(0, table1) |
| 543 | + assert event.wait(timeout=5.0), "Timeout waiting for messages" |
| 544 | + assert updates[0].append |
| 545 | + streamed1 = updates[0].data() |
| 546 | + streamed1_pyarrow = pyarrow.Table.from_pandas(streamed1, preserve_index=False) |
| 547 | + assert streamed1_pyarrow == table1 |
| 548 | + event.clear() |
| 549 | + x.append_partition(0, table2) |
| 550 | + assert event.wait(timeout=5.0), "Timeout waiting for messages" |
| 551 | + assert updates[1].append |
| 552 | + streamed2 = updates[1].data() |
| 553 | + streamed2_pyarrow = pyarrow.Table.from_pandas(streamed2, preserve_index=False) |
| 554 | + assert streamed2_pyarrow == table2 |
| 555 | + streaming_combined = pyarrow.concat_tables( |
| 556 | + [streamed1_pyarrow, streamed2_pyarrow] |
| 557 | + ) |
| 558 | + expected_combined = pyarrow.concat_tables([table1, table2]) |
| 559 | + assert streaming_combined == expected_combined |
0 commit comments