Skip to content

Commit 0e056f5

Browse files
authored
Merge branch 'main' into gabrys/bundle-neptune-api
2 parents 139e8b6 + c21679d commit 0e056f5

23 files changed

+63
-47
lines changed

.pre-commit-config.yaml

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,28 @@
11
exclude: ^src/neptune_query/generated/
22
repos:
33
- repo: https://github.com/pre-commit/pre-commit-hooks
4-
rev: v4.3.0
4+
rev: v6.0.0
55
hooks:
66
- id: check-yaml
77
- id: end-of-file-fixer
88
- id: trailing-whitespace
99
- repo: https://github.com/pycqa/isort
10-
rev: 5.13.2
10+
rev: 7.0.0
1111
hooks:
1212
- id: isort
1313
args: [--settings-path, pyproject.toml]
1414
- repo: https://github.com/psf/black
15-
rev: 22.6.0
15+
rev: 25.12.0
1616
hooks:
1717
- id: black
1818
args: [--config, pyproject.toml]
1919
- repo: https://github.com/pycqa/flake8
2020
rev: 7.0.0
2121
hooks:
2222
- id: flake8
23-
entry: pflake8
24-
# We use a custom version of flake8 that includes a fix for Python 3.12 compatibility
25-
additional_dependencies: ["git+https://github.com/neptune-ai/pyproject-flake8.git@7.0.0"]
23+
additional_dependencies: [Flake8-pyproject==1.2.4]
2624
- repo: https://github.com/Lucas-C/pre-commit-hooks
27-
rev: v1.5.4
25+
rev: v1.5.5
2826
hooks:
2927
- id: insert-license
3028
files: ^src/neptune_query.*[^/]+\.py$

src/neptune_query/filters.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ def any(*filters: "BaseAttributeFilter") -> "BaseAttributeFilter":
6363
return _AttributeFilterAlternative(filters=filters)
6464

6565
@abc.abstractmethod
66-
def _to_internal(self) -> _filters._BaseAttributeFilter:
67-
...
66+
def _to_internal(self) -> _filters._BaseAttributeFilter: ...
6867

6968

7069
@dataclass

src/neptune_query/internal/composition/fetch_metric_buckets.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@
6464
class _FetchInChunksProtocol(Protocol):
6565
def __call__(
6666
self, x_range: Optional[tuple[float, float]], bucket_limit: int
67-
) -> dict[RunAttributeDefinition, list[TimeseriesBucket]]:
68-
...
67+
) -> dict[RunAttributeDefinition, list[TimeseriesBucket]]: ...
6968

7069

7170
def fetch_metric_buckets(
@@ -252,9 +251,10 @@ def _compute_global_x_range(fetch_in_chunks: _FetchInChunksProtocol) -> Optional
252251
return x_range[0], x_range[1]
253252

254253

255-
def _update_range(
256-
current_range: tuple[Optional[float], Optional[float]], bucket: TimeseriesBucket
257-
) -> tuple[Optional[float], Optional[float],]:
254+
def _update_range(current_range: tuple[Optional[float], Optional[float]], bucket: TimeseriesBucket) -> tuple[
255+
Optional[float],
256+
Optional[float],
257+
]:
258258
# We're including from_x and to_x because some buckets might hold only non-finite points,
259259
# in which case first_x and last_x are None.
260260
candidates = [bucket.first_x, bucket.last_x, bucket.from_x, bucket.to_x] + list(current_range)

src/neptune_query/internal/composition/fetch_metrics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,9 @@ def go_fetch_sys_attrs() -> Generator[list[identifiers.SysId], None, None]:
181181
),
182182
)
183183

184-
results: Generator[
185-
dict[identifiers.RunAttributeDefinition, list[FloatPointValue]], None, None
186-
] = concurrency.gather_results(output)
184+
results: Generator[dict[identifiers.RunAttributeDefinition, list[FloatPointValue]], None, None] = (
185+
concurrency.gather_results(output)
186+
)
187187

188188
metrics_data: dict[identifiers.RunAttributeDefinition, list[FloatPointValue]] = {}
189189
for result in results:

src/neptune_query/internal/filters.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@ def any(filters: list["_BaseAttributeFilter"]) -> "_BaseAttributeFilter":
4949
@abc.abstractmethod
5050
def transform(
5151
self, map_attribute_filter: Callable[["_AttributeFilter"], "_AttributeFilter"]
52-
) -> "_BaseAttributeFilter":
53-
...
52+
) -> "_BaseAttributeFilter": ...
5453

5554

5655
@dataclass
@@ -351,8 +350,7 @@ def name_eq(name: str) -> "_Filter":
351350
return _Filter.eq(name_attribute, name)
352351

353352
@abc.abstractmethod
354-
def to_query(self) -> str:
355-
...
353+
def to_query(self) -> str: ...
356354

357355
def __str__(self) -> str:
358356
return self.to_query()

src/neptune_query/internal/retrieval/global_search.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,17 @@ def fetch_global_entries(
117117
type=AttributeTypeDTO(
118118
map_attribute_type_python_to_backend(sort_by.type) if sort_by.type is not None else "string"
119119
),
120-
aggregation_mode=QueryLeaderboardParamsFieldDTOAggregationMode(sort_by.aggregation)
121-
if sort_by.aggregation is not None
122-
else UNSET,
120+
aggregation_mode=(
121+
QueryLeaderboardParamsFieldDTOAggregationMode(sort_by.aggregation)
122+
if sort_by.aggregation is not None
123+
else UNSET
124+
),
125+
),
126+
dir_=(
127+
QueryLeaderboardParamsSortingParamsDTODir.ASCENDING
128+
if sort_direction == "asc"
129+
else QueryLeaderboardParamsSortingParamsDTODir.DESCENDING
123130
),
124-
dir_=QueryLeaderboardParamsSortingParamsDTODir.ASCENDING
125-
if sort_direction == "asc"
126-
else QueryLeaderboardParamsSortingParamsDTODir.DESCENDING,
127131
),
128132
),
129133
)

src/neptune_query/internal/retrieval/metric_buckets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from ..query_metadata_context import with_neptune_client_metadata
4242
from . import retry
4343
from .search import ContainerType
44-
from .util import ProtobufPayload
44+
from .util import body_from_protobuf
4545

4646
logger = get_logger()
4747

@@ -153,7 +153,7 @@ def fetch_time_series_buckets(
153153
call_api = retry.handle_errors_default(with_neptune_client_metadata(get_timeseries_buckets_proto.sync_detailed))
154154
response = call_api(
155155
client=client,
156-
body=ProtobufPayload(request_object),
156+
body=body_from_protobuf(request_object),
157157
)
158158

159159
logger.debug(

src/neptune_query/internal/retrieval/metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
logger = get_logger()
4040

4141
# Tuples are used here to enhance performance
42-
FloatPointValue = tuple[float, float, float, bool, float]
42+
FloatPointValue = tuple[int, float, float, bool, float]
4343
(
4444
TimestampIndex,
4545
StepIndex,

src/neptune_query/internal/retrieval/search.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,10 @@ class ContainerType(Enum):
7171

7272
class SysIdLabel(Protocol):
7373
@property
74-
def sys_id(self) -> identifiers.SysId:
75-
...
74+
def sys_id(self) -> identifiers.SysId: ...
7675

7776
@property
78-
def label(self) -> str:
79-
...
77+
def label(self) -> str: ...
8078

8179

8280
@dataclass(frozen=True)
@@ -136,8 +134,7 @@ def __call__(
136134
limit: Optional[int] = None,
137135
batch_size: int = env.NEPTUNE_QUERY_SYS_ATTRS_BATCH_SIZE.get(),
138136
container_type: ContainerType = ContainerType.EXPERIMENT,
139-
) -> Generator[util.Page[T], None, None]:
140-
...
137+
) -> Generator[util.Page[T], None, None]: ...
141138

142139

143140
def _create_fetch_sys_attrs(

src/neptune_query/internal/retrieval/util.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,19 @@ def fetch_pages(
5757
page_params = make_new_page_params(page_params, data)
5858

5959

60-
class ProtobufPayload(File):
61-
"""A version of the neptune_api.types.File class that uses a protobuf message as payload"""
60+
class ReusableFile(File):
61+
"""A File that recreates its payload on each access to support retries."""
6262

6363
@property
6464
def payload(self) -> BinaryIO:
6565
return self.get_payload()
6666

6767
@payload.setter
68-
def payload(self, message: Message) -> None:
69-
self.get_payload = lambda: BytesIO(message.SerializeToString())
68+
def payload(self, payload: BinaryIO) -> None:
69+
stored_payload = payload.read()
70+
self.get_payload = lambda: BytesIO(stored_payload)
71+
72+
73+
def body_from_protobuf(message: Message) -> ReusableFile:
74+
"""Generate a ReusableFile from a protobuf message."""
75+
return ReusableFile(payload=BytesIO(message.SerializeToString()))

0 commit comments

Comments
 (0)