diff --git a/src/neptune_query/internal/retrieval/metric_buckets.py b/src/neptune_query/internal/retrieval/metric_buckets.py index 0d8bcdef..c4bb13d9 100644 --- a/src/neptune_query/internal/retrieval/metric_buckets.py +++ b/src/neptune_query/internal/retrieval/metric_buckets.py @@ -14,7 +14,6 @@ # limitations under the License. from dataclasses import dataclass -from io import BytesIO from typing import ( Iterable, Literal, @@ -34,13 +33,13 @@ XSteps, ) from neptune_api.proto.neptune_pb.api.v1.model.series_values_pb2 import ProtoTimeseriesBucketsDTO -from neptune_api.types import File from ..identifiers import RunAttributeDefinition from ..logger import get_logger from ..query_metadata_context import with_neptune_client_metadata from . import retry from .search import ContainerType +from .util import ProtobufPayload logger = get_logger() @@ -147,7 +146,7 @@ def fetch_time_series_buckets( call_api = retry.handle_errors_default(with_neptune_client_metadata(get_timeseries_buckets_proto.sync_detailed)) response = call_api( client=client, - body=File(payload=BytesIO(request_object.SerializeToString())), + body=ProtobufPayload(request_object), ) logger.debug( diff --git a/src/neptune_query/internal/retrieval/util.py b/src/neptune_query/internal/retrieval/util.py index e1b09da5..bdc9480d 100644 --- a/src/neptune_query/internal/retrieval/util.py +++ b/src/neptune_query/internal/retrieval/util.py @@ -16,6 +16,7 @@ from __future__ import annotations from dataclasses import dataclass +from io import BytesIO from typing import ( Any, Callable, @@ -25,7 +26,9 @@ TypeVar, ) +from google.protobuf.message import Message from neptune_api import AuthenticatedClient +from neptune_api.types import File T = TypeVar("T") R = TypeVar("R") @@ -50,3 +53,15 @@ def fetch_pages( page = process_page(data) yield page page_params = make_new_page_params(page_params, data) + + +class ProtobufPayload(File): + """A version of the neptune_api.types.File class that uses a protobuf message as payload""" + + @property + def payload(self) -> BytesIO: + return self.get_payload() + + @payload.setter + def payload(self, message: Message) -> None: + self.get_payload = lambda: BytesIO(message.SerializeToString()) diff --git a/tests/unit/internal/retrieval/test_protobuf_payload.py b/tests/unit/internal/retrieval/test_protobuf_payload.py new file mode 100644 index 00000000..72554209 --- /dev/null +++ b/tests/unit/internal/retrieval/test_protobuf_payload.py @@ -0,0 +1,39 @@ +from io import BytesIO +from unittest import mock + +from neptune_api.proto.protobuf_v4plus.neptune_pb.api.v1.model.requests_pb2 import ( + ProtoCustomExpression, + ProtoGetTimeseriesBucketsRequest, + ProtoScale, + ProtoView, +) +from neptune_api.types import File + +from neptune_query.internal.retrieval.util import ProtobufPayload + + +@mock.patch("neptune_query.internal.retrieval.util.BytesIO", wraps=BytesIO) +def test_bytesio_recreation_on_retry(mock_bytesio): + """Test that BytesIO instance is recreated when the API call is retried.""" + + file = ProtobufPayload( + ProtoGetTimeseriesBucketsRequest( + expressions=[ProtoCustomExpression(requestId="0123", customYFormula="${abc}")], + view=ProtoView(xScale=ProtoScale.linear), + ) + ) + + # Verify file is neptune_api.types.File as expected + assert isinstance(file, File) + + # Read file's payload two times to simulate two API calls: + read1 = file.payload.read() + read2 = file.payload.read() + + # Verify that both reads return the same content + # If the same BytesIO instance was used, the second read would return b'' + assert read1 == read2 + assert read2 != b"" + + # Verify BytesIO was called twice + assert mock_bytesio.call_count == 2