Skip to content

Commit 1b1f0d1

Browse files
committed
feat: directly fetch run_attribute_definitions
1 parent e315b32 commit 1b1f0d1

File tree

7 files changed

+210
-102
lines changed

7 files changed

+210
-102
lines changed

src/neptune_query/internal/composition/attribute_components.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
split,
3838
util,
3939
)
40+
from .run_attributes import fetch_run_attribute_definitions
4041

4142

4243
def fetch_attribute_definitions_split(
@@ -163,3 +164,29 @@ def fetch_attribute_values_split(
163164
downstream=downstream,
164165
),
165166
)
167+
168+
169+
def fetch_run_attribute_definitions_split(
170+
client: AuthenticatedClient,
171+
project_identifier: identifiers.ProjectIdentifier,
172+
attribute_filter: filters._BaseAttributeFilter,
173+
executor: Executor,
174+
fetch_attribute_definitions_executor: Executor,
175+
sys_ids: list[identifiers.SysId],
176+
downstream: Callable[[util.Page[identifiers.RunAttributeDefinition]], concurrency.OUT],
177+
) -> concurrency.OUT:
178+
return concurrency.generate_concurrently(
179+
items=split.split_sys_ids(sys_ids),
180+
executor=executor,
181+
downstream=lambda sys_ids_split: concurrency.generate_concurrently(
182+
fetch_run_attribute_definitions(
183+
client=client,
184+
project_identifier=project_identifier,
185+
run_identifiers=[identifiers.RunIdentifier(project_identifier, sys_id) for sys_id in sys_ids_split],
186+
attribute_filter=attribute_filter,
187+
executor=fetch_attribute_definitions_executor,
188+
),
189+
executor=executor,
190+
downstream=lambda run_definitions: downstream(run_definitions),
191+
),
192+
)

src/neptune_query/internal/composition/attributes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from ..composition import concurrency
3333
from ..retrieval import attribute_definitions as att_defs
3434
from ..retrieval import util
35+
from ..retrieval.attribute_filter import split_attribute_filters
3536
from ..retrieval.attribute_types import TYPE_AGGREGATIONS
3637

3738

@@ -126,7 +127,7 @@ def go_fetch_single(
126127
batch_size=batch_size,
127128
)
128129

129-
filters_ = att_defs.split_attribute_filters(attribute_filter)
130+
filters_ = split_attribute_filters(attribute_filter)
130131

131132
output = concurrency.generate_concurrently(
132133
items=(filter_ for filter_ in filters_),

src/neptune_query/internal/composition/fetch_metrics.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
type_inference,
3131
validation,
3232
)
33-
from ..composition.attribute_components import fetch_attribute_definitions_split
33+
from ..composition.attribute_components import fetch_run_attribute_definitions_split
3434
from ..context import (
3535
Context,
3636
get_context,
@@ -145,25 +145,15 @@ def go_fetch_sys_attrs() -> Generator[list[identifiers.SysId], None, None]:
145145
output = concurrency.generate_concurrently(
146146
items=go_fetch_sys_attrs(),
147147
executor=executor,
148-
downstream=lambda sys_ids: fetch_attribute_definitions_split(
148+
downstream=lambda sys_ids: fetch_run_attribute_definitions_split(
149149
client=client,
150150
project_identifier=project_identifier,
151151
attribute_filter=attributes,
152152
executor=executor,
153153
fetch_attribute_definitions_executor=fetch_attribute_definitions_executor,
154154
sys_ids=sys_ids,
155-
downstream=lambda sys_ids_split, definitions_page: concurrency.generate_concurrently(
156-
items=split.split_series_attributes(
157-
items=(
158-
identifiers.RunAttributeDefinition(
159-
run_identifier=identifiers.RunIdentifier(project_identifier, sys_id),
160-
attribute_definition=definition,
161-
)
162-
for sys_id in sys_ids_split
163-
for definition in definitions_page.items
164-
if definition.type == "float_series"
165-
)
166-
),
155+
downstream=lambda run_definitions_page: concurrency.generate_concurrently(
156+
items=split.split_series_attributes(items=run_definitions_page.items),
167157
executor=executor,
168158
downstream=lambda run_attribute_definitions_split: concurrency.return_value(
169159
fetch_multiple_series_values(

src/neptune_query/internal/composition/run_attributes.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,9 @@
1414
# limitations under the License.
1515

1616
from concurrent.futures import Executor
17-
from dataclasses import dataclass
1817
from typing import (
1918
Generator,
2019
Iterable,
21-
Literal,
2220
Optional,
2321
)
2422

@@ -30,50 +28,50 @@
3028
identifiers,
3129
)
3230
from ..composition import concurrency
33-
from ..retrieval import attribute_definitions as att_defs
31+
from ..retrieval import attribute_values as att_vals
3432
from ..retrieval import util
35-
from ..retrieval.attribute_types import TYPE_AGGREGATIONS
33+
from ..retrieval.attribute_filter import split_attribute_filters
3634

3735

3836
def fetch_run_attribute_definitions(
3937
client: AuthenticatedClient,
40-
project_identifiers: Iterable[identifiers.ProjectIdentifier],
38+
project_identifier: identifiers.ProjectIdentifier,
4139
run_identifiers: Optional[Iterable[identifiers.RunIdentifier]],
4240
attribute_filter: filters._BaseAttributeFilter,
4341
executor: Executor,
4442
batch_size: int = env.NEPTUNE_QUERY_ATTRIBUTE_DEFINITIONS_BATCH_SIZE.get(),
4543
) -> Generator[util.Page[identifiers.RunAttributeDefinition], None, None]:
4644
pages_filters = _fetch_run_attribute_definitions(
47-
client, project_identifiers, run_identifiers, attribute_filter, batch_size, executor
45+
client, project_identifier, run_identifiers, attribute_filter, batch_size, executor
4846
)
4947

5048
seen_items: set[identifiers.RunAttributeDefinition] = set()
51-
for page, filter_ in pages_filters:
49+
for page in pages_filters:
5250
new_items = [item for item in page.items if item not in seen_items]
5351
seen_items.update(new_items)
5452
yield util.Page(items=new_items)
5553

5654

5755
def _fetch_run_attribute_definitions(
5856
client: AuthenticatedClient,
59-
project_identifiers: Iterable[identifiers.ProjectIdentifier],
57+
project_identifier: identifiers.ProjectIdentifier,
6058
run_identifiers: Optional[Iterable[identifiers.RunIdentifier]],
6159
attribute_filter: filters._BaseAttributeFilter,
6260
batch_size: int,
6361
executor: Executor,
6462
) -> Generator[util.Page[identifiers.RunAttributeDefinition], None, None]:
6563
def go_fetch_single(
6664
filter_: filters._AttributeFilter,
67-
) -> Generator[util.Page[identifiers.AttributeDefinition], None, None]:
68-
return att_defs.fetch_attribute_definitions_single_filter(
65+
) -> Generator[util.Page[identifiers.RunAttributeDefinition], None, None]:
66+
return att_vals.fetch_run_attribute_definitions_single_filter(
6967
client=client,
70-
project_identifiers=project_identifiers,
68+
project_identifier=project_identifier,
7169
run_identifiers=run_identifiers,
7270
attribute_filter=filter_,
7371
batch_size=batch_size,
7472
)
7573

76-
filters_ = att_defs.split_attribute_filters(attribute_filter)
74+
filters_ = split_attribute_filters(attribute_filter)
7775

7876
output = concurrency.generate_concurrently(
7977
items=(filter_ for filter_ in filters_),

src/neptune_query/internal/retrieval/attribute_definitions.py

Lines changed: 3 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,11 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import functools as ft
16-
import itertools as it
17-
import re
1816
from typing import (
1917
Any,
2018
Generator,
2119
Iterable,
2220
Optional,
23-
Union,
2421
)
2522

2623
from neptune_api.api.retrieval import query_attribute_definitions_within_project
@@ -40,17 +37,7 @@
4037
from ..retrieval import attribute_types as types # noqa: E402
4138
from ..retrieval import util # noqa: E402
4239
from ..retrieval import retry
43-
44-
45-
def split_attribute_filters(
46-
_attribute_filter: filters._BaseAttributeFilter,
47-
) -> list[filters._AttributeFilter]:
48-
if isinstance(_attribute_filter, filters._AttributeFilter):
49-
return [_attribute_filter]
50-
elif isinstance(_attribute_filter, filters._AttributeFilterAlternative):
51-
return list(it.chain.from_iterable(split_attribute_filters(child) for child in _attribute_filter.filters))
52-
else:
53-
raise RuntimeError(f"Unexpected filter type: {type(_attribute_filter)}")
40+
from .attribute_filter import transform_attribute_filter_into_params
5441

5542

5643
def fetch_attribute_definitions_single_filter(
@@ -69,34 +56,8 @@ def fetch_attribute_definitions_single_filter(
6956
if run_identifiers is not None:
7057
params["experimentIdsFilter"] = [str(e) for e in run_identifiers]
7158

72-
# Convert name_eq to an additional condition added to each of must_match_any alternatives.
73-
name_regexes = None
74-
if attribute_filter.name_eq is not None:
75-
name_regexes = _escape_name_eq(_variants_to_list(attribute_filter.name_eq))
76-
77-
if attribute_filter.must_match_any is not None:
78-
attribute_name_filter_dtos = []
79-
for alternative in attribute_filter.must_match_any:
80-
attribute_name_filter_dto = {}
81-
must_match_regexes = _union_options([name_regexes, alternative.must_match_regexes])
82-
if must_match_regexes is not None:
83-
attribute_name_filter_dto["mustMatchRegexes"] = must_match_regexes
84-
if alternative.must_not_match_regexes is not None:
85-
attribute_name_filter_dto["mustNotMatchRegexes"] = alternative.must_not_match_regexes
86-
if attribute_name_filter_dto:
87-
attribute_name_filter_dtos.append(attribute_name_filter_dto)
88-
params["attributeNameFilter"]["mustMatchAny"] = attribute_name_filter_dtos
89-
90-
elif name_regexes is not None:
91-
params["attributeNameFilter"]["mustMatchAny"] = [{"mustMatchRegexes": name_regexes}]
92-
93-
attribute_types = _variants_to_list(attribute_filter.type_in)
94-
if attribute_types is not None:
95-
params["attributeFilter"] = [
96-
{"attributeType": types.map_attribute_type_python_to_backend(_type)} for _type in attribute_types
97-
]
98-
99-
# note: attribute_filter.aggregations is intentionally ignored
59+
attribute_filter_params = transform_attribute_filter_into_params(attribute_filter)
60+
params.update(attribute_filter_params)
10061

10162
return util.fetch_pages(
10263
client=client,
@@ -149,36 +110,3 @@ def _make_new_attribute_definitions_page_params(
149110

150111
params["nextPage"]["nextPageToken"] = next_page_token
151112
return params
152-
153-
154-
def _escape_name_eq(names: Optional[list[str]]) -> Optional[list[str]]:
155-
if names is None:
156-
return None
157-
158-
escaped = [f"{re.escape(name)}" for name in names]
159-
160-
if len(escaped) == 1:
161-
return [f"^{escaped[0]}$"]
162-
else:
163-
joined = "|".join(escaped)
164-
return [f"^({joined})$"]
165-
166-
167-
def _variants_to_list(param: Union[str, Iterable[str], None]) -> Optional[list[str]]:
168-
if param is None:
169-
return None
170-
if isinstance(param, str):
171-
return [param]
172-
return list(param)
173-
174-
175-
def _union_options(options: list[Optional[list[str]]]) -> Optional[list[str]]:
176-
result = None
177-
178-
for option in options:
179-
if option is not None:
180-
if result is None:
181-
result = []
182-
result.extend(option)
183-
184-
return result
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
#
2+
# Copyright (c) 2025, Neptune Labs Sp. z o.o.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import itertools as it
16+
import re
17+
from typing import (
18+
Any,
19+
Iterable,
20+
Optional,
21+
Union,
22+
)
23+
24+
from .. import filters # noqa: E402
25+
from ..retrieval import attribute_types as types # noqa: E402
26+
27+
28+
def split_attribute_filters(
29+
_attribute_filter: filters._BaseAttributeFilter,
30+
) -> list[filters._AttributeFilter]:
31+
if isinstance(_attribute_filter, filters._AttributeFilter):
32+
return [_attribute_filter]
33+
elif isinstance(_attribute_filter, filters._AttributeFilterAlternative):
34+
return list(it.chain.from_iterable(split_attribute_filters(child) for child in _attribute_filter.filters))
35+
else:
36+
raise RuntimeError(f"Unexpected filter type: {type(_attribute_filter)}")
37+
38+
39+
def transform_attribute_filter_into_params(
40+
attribute_filter: filters._AttributeFilter,
41+
) -> dict[str, Any]:
42+
params: dict[str, Any] = {
43+
"attributeNameFilter": {},
44+
}
45+
46+
name_regexes = None
47+
if attribute_filter.name_eq is not None:
48+
name_regexes = _escape_name_eq(_variants_to_list(attribute_filter.name_eq))
49+
50+
if attribute_filter.must_match_any is not None:
51+
attribute_name_filter_dtos = []
52+
for alternative in attribute_filter.must_match_any:
53+
attribute_name_filter_dto = {}
54+
must_match_regexes = _union_options([name_regexes, alternative.must_match_regexes])
55+
if must_match_regexes is not None:
56+
attribute_name_filter_dto["mustMatchRegexes"] = must_match_regexes
57+
if alternative.must_not_match_regexes is not None:
58+
attribute_name_filter_dto["mustNotMatchRegexes"] = alternative.must_not_match_regexes
59+
if attribute_name_filter_dto:
60+
attribute_name_filter_dtos.append(attribute_name_filter_dto)
61+
params["attributeNameFilter"]["mustMatchAny"] = attribute_name_filter_dtos
62+
63+
elif name_regexes is not None:
64+
params["attributeNameFilter"]["mustMatchAny"] = [{"mustMatchRegexes": name_regexes}]
65+
66+
attribute_types = _variants_to_list(attribute_filter.type_in)
67+
if attribute_types is not None:
68+
params["attributeFilter"] = [
69+
{"attributeType": types.map_attribute_type_python_to_backend(_type)} for _type in attribute_types
70+
]
71+
72+
# note: attribute_filter.aggregations is intentionally ignored
73+
74+
return params
75+
76+
77+
def _escape_name_eq(names: Optional[list[str]]) -> Optional[list[str]]:
78+
if names is None:
79+
return None
80+
81+
escaped = [f"{re.escape(name)}" for name in names]
82+
83+
if len(escaped) == 1:
84+
return [f"^{escaped[0]}$"]
85+
else:
86+
joined = "|".join(escaped)
87+
return [f"^({joined})$"]
88+
89+
90+
def _variants_to_list(param: Union[str, Iterable[str], None]) -> Optional[list[str]]:
91+
if param is None:
92+
return None
93+
if isinstance(param, str):
94+
return [param]
95+
return list(param)
96+
97+
98+
def _union_options(options: list[Optional[list[str]]]) -> Optional[list[str]]:
99+
result = None
100+
101+
for option in options:
102+
if option is not None:
103+
if result is None:
104+
result = []
105+
result.extend(option)
106+
107+
return result

0 commit comments

Comments
 (0)