Skip to content

Commit 5eb05c9

Browse files
Fix formatting and tests for 3.10 (hopefully)
1 parent 221dede commit 5eb05c9

File tree

2 files changed

+107 -62 lines changed

2 files changed

+107 -62 lines changed

nisystemlink/clients/dataframe/_data_frame_client.py

Lines changed: 22 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
import pyarrow as pa # type: ignore
44
from collections.abc import Iterable
55
from io import BytesIO
6-
from typing import List, Optional, Union
6+
from typing import Any, List, Optional, Union
77

88
from nisystemlink.clients import core
99
from nisystemlink.clients.core._uplink._base_client import BaseClient
@@ -252,24 +252,19 @@ def get_table_data(
252252
"""
253253
...
254254

255-
@post(
256-
"tables/{id}/data",
257-
args=[Path, Body]
258-
)
255+
@post("tables/{id}/data", args=[Path, Body])
259256
def _append_table_data_json(
260257
self, id: str, data: models.AppendTableDataRequest
261-
) -> None:
262-
...
258+
) -> None: ...
263259

264260
@post(
265261
"tables/{id}/data",
266262
args=[Path, Body, Query("endOfData")],
267263
content_type="application/vnd.apache.arrow.stream",
268264
)
269265
def _append_table_data_arrow(
270-
self, id: str, data: Iterable[bytes], end_of_data: Optional[bool] = None
271-
) -> None:
272-
...
266+
self, id: str, data: bytes, end_of_data: Optional[bool] = None
267+
) -> None: ...
273268

274269
def append_table_data(
275270
self,
@@ -278,7 +273,7 @@ def append_table_data(
278273
Union[
279274
models.AppendTableDataRequest,
280275
models.DataFrame,
281-
Iterable["pa.RecordBatch"], # type: ignore[name-defined]
276+
Iterable["pa.RecordBatch"], # type: ignore[name-defined]
282277
]
283278
],
284279
*,
@@ -343,44 +338,34 @@ def append_table_data(
343338
"Iterable provided to data must yield pyarrow.RecordBatch objects."
344339
)
345340

346-
def _generate_body() -> Iterable[memoryview]:
347-
data_iter = iter(data)
348-
try:
349-
batch = next(data_iter)
350-
except StopIteration:
351-
return
341+
def _build_body() -> bytes:
352342
with BytesIO() as buf:
353343
options = pa.ipc.IpcWriteOptions(compression="zstd")
354-
writer = pa.ipc.new_stream(buf, batch.schema, options=options)
355-
356-
while True:
357-
writer.write_batch(batch)
358-
with buf.getbuffer() as view, view[0 : buf.tell()] as slice:
359-
yield slice
360-
buf.seek(0)
361-
try:
362-
batch = next(data_iter)
363-
except StopIteration:
364-
break
365-
366-
writer.close()
367-
with buf.getbuffer() as view:
368-
with view[0 : buf.tell()] as slice:
369-
yield slice
344+
with pa.ipc.new_stream(
345+
buf, first_batch.schema, options=options
346+
) as writer:
347+
writer.write_batch(first_batch)
348+
for batch in iterator:
349+
writer.write_batch(batch)
350+
return buf.getvalue()
370351

371352
try:
372353
self._append_table_data_arrow(
373354
id,
374-
_generate_body(),
355+
_build_body(),
375356
(end_of_data if end_of_data is not None else None),
376357
)
377358
except core.ApiException as ex:
378359
if ex.http_status_code == 400:
379360
wrap = True
380361
try:
381-
write_op = getattr(self.api_info().operations, "write_data", None)
382-
if write_op is not None and getattr(write_op, "version", 0) >= 2:
383-
# Service claims Arrow-capable write version; re-raise original exception
362+
write_op = getattr(
363+
self.api_info().operations, "write_data", None
364+
)
365+
if (
366+
write_op is not None
367+
and getattr(write_op, "version", 0) >= 2
368+
):
384369
wrap = False
385370
except Exception:
386371
pass

tests/integration/dataframe/test_dataframe.py

Lines changed: 85 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -89,8 +89,17 @@ def test_tables(create_table):
8989
class TestDataFrame:
9090
def _new_single_int_table(self, create_table, column_name: str = "a") -> str:
9191
return create_table(
92-
CreateTableRequest(columns=[Column(name=column_name, data_type=DataType.Int64, column_type=ColumnType.Index)])
92+
CreateTableRequest(
93+
columns=[
94+
Column(
95+
name=column_name,
96+
data_type=DataType.Int64,
97+
column_type=ColumnType.Index,
98+
)
99+
]
100+
)
93101
)
102+
94103
def test__api_info__returns(self, client):
95104
response = client.api_info()
96105

@@ -620,78 +629,123 @@ def test__export_table_data__works(self, client: DataFrameClient, create_table):
620629
== b'"col1","col2","col3"\r\n1,2.5,6.5\r\n2,1.5,5.5\r\n3,2.5,7.5'
621630
)
622631

623-
def test__append_table_data__append_request_success(self, client: DataFrameClient, create_table):
632+
def test__append_table_data__append_request_success(
633+
self, client: DataFrameClient, create_table
634+
):
624635
table_id = self._new_single_int_table(create_table)
625636
frame = DataFrame(columns=["a"], data=[["1"], ["2"]])
626-
client.append_table_data(table_id, AppendTableDataRequest(frame=frame, end_of_data=True))
637+
client.append_table_data(
638+
table_id, AppendTableDataRequest(frame=frame, end_of_data=True)
639+
)
627640

628-
def test__append_table_data__append_request_with_end_of_data_argument_disallowed(self, client: DataFrameClient, create_table):
641+
def test__append_table_data__append_request_with_end_of_data_argument_disallowed(
642+
self, client: DataFrameClient, create_table
643+
):
629644
request = AppendTableDataRequest(end_of_data=True)
630-
with pytest.raises(ValueError, match="end_of_data must not be provided separately when passing an AppendTableDataRequest."):
631-
client.append_table_data(self._new_single_int_table(create_table), request, end_of_data=True)
645+
with pytest.raises(
646+
ValueError,
647+
match="end_of_data must not be provided separately when passing an AppendTableDataRequest.",
648+
):
649+
client.append_table_data(
650+
self._new_single_int_table(create_table), request, end_of_data=True
651+
)
632652

633-
def test__append_table_data__append_request_without_end_of_data_success(self, client: DataFrameClient, create_table):
653+
def test__append_table_data__append_request_without_end_of_data_success(
654+
self, client: DataFrameClient, create_table
655+
):
634656
table_id = self._new_single_int_table(create_table)
635657
frame = DataFrame(columns=["a"], data=[["7"], ["8"]])
636658
client.append_table_data(table_id, AppendTableDataRequest(frame=frame))
637659

638-
def test__append_table_data__accepts_dataframe_model(self, client: DataFrameClient, create_table):
660+
def test__append_table_data__accepts_dataframe_model(
661+
self, client: DataFrameClient, create_table
662+
):
639663
table_id = self._new_single_int_table(create_table)
640664
frame = DataFrame(columns=["a"], data=[["1"], ["2"]])
641665
client.append_table_data(table_id, frame, end_of_data=True)
642666

643-
def test__append_table_data__dataframe_without_end_of_data_success(self, client: DataFrameClient, create_table):
667+
def test__append_table_data__dataframe_without_end_of_data_success(
668+
self, client: DataFrameClient, create_table
669+
):
644670
table_id = self._new_single_int_table(create_table)
645671
frame = DataFrame(columns=["a"], data=[["10"], ["11"]])
646672
client.append_table_data(table_id, frame)
647673

648-
def test__append_table_data__none_without_end_of_data_raises(self, client: DataFrameClient, create_table):
674+
def test__append_table_data__none_without_end_of_data_raises(
675+
self, client: DataFrameClient, create_table
676+
):
649677
table_id = create_table(basic_table_model)
650-
with pytest.raises(ValueError, match="end_of_data must be provided when data is None"):
678+
with pytest.raises(
679+
ValueError, match="end_of_data must be provided when data is None"
680+
):
651681
client.append_table_data(table_id, None)
652682

653-
def test__append_table_data__flush_only_with_none(self, client: DataFrameClient, create_table):
683+
def test__append_table_data__flush_only_with_none(
684+
self, client: DataFrameClient, create_table
685+
):
654686
table_id = self._new_single_int_table(create_table)
655687
frame = DataFrame(columns=["a"], data=[["1"]])
656688
client.append_table_data(table_id, frame)
657689
client.append_table_data(table_id, None, end_of_data=True)
658690

659-
def test__append_table_data__arrow_ingestion_success(self, client: DataFrameClient, create_table):
691+
def test__append_table_data__arrow_ingestion_success(
692+
self, client: DataFrameClient, create_table
693+
):
660694
pa = pytest.importorskip("pyarrow")
661695
table_id = self._new_single_int_table(create_table)
662696
batch = pa.record_batch([pa.array([10, 11, 12])], names=["a"])
663697
client.append_table_data(table_id, [batch], end_of_data=True)
664698
with pytest.raises(ApiException):
665699
client.append_table_data(table_id, None, end_of_data=True)
666700

667-
def test__append_table_data__arrow_ingestion_with_end_of_data_query_param_false(self, client: DataFrameClient, create_table):
701+
def test__append_table_data__arrow_ingestion_with_end_of_data_query_param_false(
702+
self, client: DataFrameClient, create_table
703+
):
668704
pa = pytest.importorskip("pyarrow")
669705
table_id = self._new_single_int_table(create_table)
670706
batch1 = pa.record_batch([pa.array([4, 5, 6])], names=["a"])
671707
client.append_table_data(table_id, [batch1], end_of_data=False)
672708
batch2 = pa.record_batch([pa.array([7, 8])], names=["a"])
673709
client.append_table_data(table_id, [batch2], end_of_data=True)
674710

675-
def test__append_table_data__empty_iterator_requires_end_of_data(self, client: DataFrameClient, create_table):
711+
def test__append_table_data__empty_iterator_requires_end_of_data(
712+
self, client: DataFrameClient, create_table
713+
):
676714
table_id = create_table(basic_table_model)
677-
with pytest.raises(ValueError, match="end_of_data must be provided when data iterator is empty."):
715+
with pytest.raises(
716+
ValueError,
717+
match="end_of_data must be provided when data iterator is empty.",
718+
):
678719
client.append_table_data(table_id, [])
679720
client.append_table_data(table_id, [], end_of_data=True)
680721

681-
def test__append_table_data__arrow_iterable_with_non_recordbatch_elements_raises(self, client: DataFrameClient, create_table):
722+
def test__append_table_data__arrow_iterable_with_non_recordbatch_elements_raises(
723+
self, client: DataFrameClient, create_table
724+
):
682725
pytest.importorskip("pyarrow")
683726
table_id = create_table(basic_table_model)
684-
with pytest.raises(ValueError, match="Iterable provided to data must yield pyarrow.RecordBatch objects."):
727+
with pytest.raises(
728+
ValueError,
729+
match="Iterable provided to data must yield pyarrow.RecordBatch objects.",
730+
):
685731
client.append_table_data(table_id, [1, 2, 3])
686732

687-
def test__append_table_data__arrow_iterable_without_pyarrow_raises_runtime_error(self, client: DataFrameClient, create_table, monkeypatch):
733+
def test__append_table_data__arrow_iterable_without_pyarrow_raises_runtime_error(
734+
self, client: DataFrameClient, create_table, monkeypatch
735+
):
688736
import nisystemlink.clients.dataframe._data_frame_client as df_module
737+
689738
monkeypatch.setattr(df_module, "pa", None)
690739
table_id = create_table(basic_table_model)
691-
with pytest.raises(RuntimeError, match="pyarrow is not installed. Install to stream RecordBatches."):
740+
with pytest.raises(
741+
RuntimeError,
742+
match="pyarrow is not installed. Install to stream RecordBatches.",
743+
):
692744
client.append_table_data(table_id, [object()])
693745

694-
def test__append_table_data__arrow_ingestion_400_unsupported(self, client: DataFrameClient):
746+
def test__append_table_data__arrow_ingestion_400_unsupported(
747+
self, client: DataFrameClient
748+
):
695749
pa = pytest.importorskip("pyarrow")
696750
table_id = "mock_table_id"
697751
bad_batch = pa.record_batch([pa.array([1, 2, 3])], names=["b"])
@@ -722,7 +776,9 @@ def test__append_table_data__arrow_ingestion_400_unsupported(self, client: DataF
722776

723777
assert "Arrow ingestion request was rejected" in str(excinfo.value)
724778

725-
def test__append_table_data__arrow_ingestion_400_supported_passthrough(self, client: DataFrameClient):
779+
def test__append_table_data__arrow_ingestion_400_supported_passthrough(
780+
self, client: DataFrameClient
781+
):
726782
pa = pytest.importorskip("pyarrow")
727783
table_id = "mock_table_id"
728784
bad_batch = pa.record_batch([pa.array([1, 2, 3])], names=["b"])
@@ -753,21 +809,25 @@ def test__append_table_data__arrow_ingestion_400_supported_passthrough(self, cli
753809

754810
assert "Arrow ingestion request was rejected" not in str(excinfo.value)
755811

756-
def test__append_table_data__arrow_ingestion_non_400_passthrough(self, client: DataFrameClient):
812+
def test__append_table_data__arrow_ingestion_non_400_passthrough(
813+
self, client: DataFrameClient
814+
):
757815
pa = pytest.importorskip("pyarrow")
758816
table_id = "mock_table_id"
759817
batch = pa.record_batch([pa.array([1, 2, 3])], names=["a"])
760818
with responses.RequestsMock() as rsps:
761819
rsps.add(
762820
responses.POST,
763821
f"{client.session.base_url}tables/{table_id}/data",
764-
status=409
822+
status=409,
765823
)
766824
with pytest.raises(ApiException) as excinfo:
767825
client.append_table_data(table_id, [batch], end_of_data=True)
768826
assert "Arrow ingestion request was rejected" not in str(excinfo.value)
769827

770-
def test__append_table_data__unsupported_type_raises(self, client: DataFrameClient, create_table):
828+
def test__append_table_data__unsupported_type_raises(
829+
self, client: DataFrameClient, create_table
830+
):
771831
table_id = create_table(basic_table_model)
772832
with pytest.raises(ValueError, match="Unsupported type"):
773833
client.append_table_data(table_id, 123)

0 commit comments

Comments
 (0)