Skip to content

Commit 5701269

Browse files
Optimize table fetches for explicit attribute lists via entries-search projection (#210)
Co-authored-by: Patryk Gała <patryk.gala@openai.com>
1 parent 4576a0a commit 5701269

File tree

7 files changed

+332
-44
lines changed

7 files changed

+332
-44
lines changed

src/neptune_query/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ def fetch_experiments_table(
304304
sort_direction=sort_direction,
305305
limit=limit,
306306
type_suffix_in_column_names=type_suffix_in_column_names,
307+
exact_attribute_names=attributes if isinstance(attributes, list) else None,
307308
container_type=_search.ContainerType.EXPERIMENT,
308309
)
309310

src/neptune_query/internal/composition/fetch_table.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@
2424
from ...exceptions import NeptuneUserError
2525
from .. import client as _client
2626
from .. import context as _context
27-
from .. import identifiers
27+
from .. import (
28+
env,
29+
identifiers,
30+
)
2831
from ..composition import attribute_components as _components
2932
from ..composition import (
3033
concurrency,
@@ -61,6 +64,7 @@ def fetch_table(
6164
sort_direction: Literal["asc", "desc"],
6265
limit: Optional[int],
6366
type_suffix_in_column_names: bool,
67+
exact_attribute_names: Optional[list[str]] = None,
6468
context: Optional[_context.Context] = None,
6569
container_type: search.ContainerType,
6670
) -> pd.DataFrame:
@@ -93,6 +97,30 @@ def fetch_table(
9397
sort_by = sort_by_inference_result.result
9498
sort_by_inference_result.emit_warnings()
9599

100+
if exact_attribute_names is not None:
101+
exact_attribute_names_set = set(exact_attribute_names)
102+
if len(exact_attribute_names) <= env.NEPTUNE_QUERY_ENTRIES_SEARCH_MAX_PROJECTION_ATTRIBUTES.get():
103+
table_rows: list[TableRow] = []
104+
for page in search.fetch_table_rows_exact_attributes(
105+
client=client,
106+
project_identifier=project_identifier,
107+
filter_=filter_,
108+
requested_attribute_names=exact_attribute_names_set,
109+
sort_by=sort_by,
110+
sort_direction=_sort_direction,
111+
limit=limit,
112+
container_type=container_type,
113+
):
114+
for item in page.items:
115+
table_rows.append(
116+
TableRow(values=item.values, label=item.label, project_identifier=project_identifier)
117+
)
118+
return create_runs_table(
119+
table_rows=table_rows,
120+
type_suffix_in_column_names=type_suffix_in_column_names,
121+
container_type=container_type,
122+
)
123+
96124
sys_id_label_mapping: dict[identifiers.SysId, str] = {}
97125
result_by_id: dict[identifiers.SysId, list[att_vals.AttributeValue]] = {}
98126

src/neptune_query/internal/env.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"NEPTUNE_PROJECT",
2828
"NEPTUNE_QUERY_ATTRIBUTE_DEFINITIONS_BATCH_SIZE",
2929
"NEPTUNE_QUERY_ATTRIBUTE_VALUES_BATCH_SIZE",
30+
"NEPTUNE_QUERY_ENTRIES_SEARCH_MAX_PROJECTION_ATTRIBUTES",
3031
"NEPTUNE_QUERY_FILES_MAX_CONCURRENCY",
3132
"NEPTUNE_QUERY_FILES_TIMEOUT",
3233
"NEPTUNE_QUERY_MAX_REQUEST_SIZE",
@@ -95,6 +96,9 @@ def _map_logging_level(value: str) -> str:
9596
"NEPTUNE_QUERY_ATTRIBUTE_DEFINITIONS_BATCH_SIZE", int, 10_000
9697
)
9798
NEPTUNE_QUERY_ATTRIBUTE_VALUES_BATCH_SIZE = EnvVariable[int]("NEPTUNE_QUERY_ATTRIBUTE_VALUES_BATCH_SIZE", int, 10_000)
99+
NEPTUNE_QUERY_ENTRIES_SEARCH_MAX_PROJECTION_ATTRIBUTES = EnvVariable[int](
100+
"NEPTUNE_QUERY_ENTRIES_SEARCH_MAX_PROJECTION_ATTRIBUTES", int, 1000
101+
)
98102
NEPTUNE_QUERY_FILES_BATCH_SIZE = EnvVariable[int]("NEPTUNE_QUERY_FILES_BATCH_SIZE", int, 200)
99103
NEPTUNE_QUERY_FILES_MAX_CONCURRENCY = EnvVariable[int]("NEPTUNE_QUERY_FILES_MAX_CONCURRENCY", int, 1)
100104
NEPTUNE_QUERY_FILES_TIMEOUT = EnvVariable[Optional[int]]("NEPTUNE_QUERY_FILES_TIMEOUT", _lift_optional(int), None)

src/neptune_query/internal/retrieval/search.py

Lines changed: 160 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,16 @@
4343
_Filter,
4444
)
4545
from ..logger import get_logger
46+
from ..retrieval import attribute_values as att_vals
4647
from ..retrieval import (
4748
retry,
4849
util,
4950
)
50-
from ..retrieval.attribute_types import map_attribute_type_python_to_backend
51+
from ..retrieval.attribute_types import (
52+
extract_value,
53+
map_attribute_type_backend_to_python,
54+
map_attribute_type_python_to_backend,
55+
)
5156

5257
logger = get_logger()
5358

@@ -77,6 +82,13 @@ def sys_id(self) -> identifiers.SysId: ...
7782
def label(self) -> str: ...
7883

7984

85+
@dataclass(frozen=True)
class TableSearchEntry:
    """One entry of an entries-search page, carrying its identity and attribute values."""

    # System identifier of the entry.
    sys_id: identifiers.SysId
    # Label string for the entry.
    label: str
    # Attribute values collected for this entry.
    values: list[att_vals.AttributeValue]
90+
91+
8092
@dataclass(frozen=True)
8193
class ExperimentSysAttrs:
8294
sys_id: identifiers.SysId
@@ -137,6 +149,65 @@ def __call__(
137149
) -> Generator[util.Page[T], None, None]: ...
138150

139151

152+
def _build_entries_search_params(
    *,
    attribute_projection: list[str],
    batch_size: int,
    container_type: ContainerType,
    filter_: Optional[_Filter],
    sort_by: _Attribute,
    sort_direction: Literal["asc", "desc"],
) -> dict[str, Any]:
    """Assemble the initial request parameters for an entries search.

    Args:
        attribute_projection: Attribute names; each becomes a ``{"path": name}``
            item under ``"attributeFilters"``.
        batch_size: Value placed under ``"pagination"`` / ``"limit"``.
        container_type: ``"experimentLeader"`` is set to whether this equals
            ``ContainerType.EXPERIMENT``.
        filter_: When not None, its string form is sent as the query.
        sort_by: Attribute to sort by; its aggregation and type are included
            only when they are not None.
        sort_direction: Sort direction, mapped through ``_map_direction``.

    Returns:
        The request parameters as a dict.
    """
    sorting: dict[str, Any] = {
        "dir": _map_direction(sort_direction),
        "sortBy": {"name": sort_by.name},
    }
    request: dict[str, Any] = {
        "attributeFilters": [{"path": path} for path in attribute_projection],
        "pagination": {"limit": batch_size},
        "experimentLeader": container_type == ContainerType.EXPERIMENT,
        "sorting": sorting,
    }

    # Optional parts are added only when the caller provided them.
    if filter_ is not None:
        request["query"] = {"query": str(filter_)}

    aggregation = sort_by.aggregation
    if aggregation is not None:
        sorting["aggregationMode"] = aggregation

    sort_type = sort_by.type
    if sort_type is not None:
        sorting["sortBy"]["type"] = map_attribute_type_python_to_backend(sort_type)

    return request
178+
179+
180+
def _fetch_entries_with_projection(
    *,
    client: AuthenticatedClient,
    project_identifier: identifiers.ProjectIdentifier,
    attribute_projection: list[str],
    process_page: Callable[[ProtoLeaderboardEntriesSearchResultDTO], util.Page[T]],
    filter_: Optional[_Filter],
    sort_by: _Attribute,
    sort_direction: Literal["asc", "desc"],
    limit: Optional[int],
    batch_size: int,
    container_type: ContainerType,
) -> Generator[util.Page[T], None, None]:
    """Run an entries search restricted to the given attribute projection.

    Builds the initial request with ``_build_entries_search_params`` and hands
    paging over to ``util.fetch_pages``. Pages are fetched with
    ``_fetch_sys_attrs_page`` and follow-up request parameters come from
    ``_make_new_sys_attrs_page_params`` (which receives ``batch_size`` and
    ``limit``).

    Args:
        client: Authenticated API client passed to ``util.fetch_pages``.
        project_identifier: Project bound into the page-fetching callable.
        attribute_projection: Attribute names to request for each entry.
        process_page: Converts a raw search result into a ``util.Page``.
        filter_: Optional filter forwarded to the request builder.
        sort_by: Attribute to sort by.
        sort_direction: Sort direction.
        limit: Optional limit forwarded to the paging-params helper.
        batch_size: Page size used in the request and by the paging-params helper.
        container_type: Container type forwarded to the request builder.

    Returns:
        Whatever ``util.fetch_pages`` returns for these arguments.
    """
    initial_params = _build_entries_search_params(
        attribute_projection=attribute_projection,
        batch_size=batch_size,
        container_type=container_type,
        filter_=filter_,
        sort_by=sort_by,
        sort_direction=sort_direction,
    )
    fetch_page = ft.partial(_fetch_sys_attrs_page, project_identifier=project_identifier)
    next_page_params = ft.partial(
        _make_new_sys_attrs_page_params,
        batch_size=batch_size,
        limit=limit,
    )

    return util.fetch_pages(
        client=client,
        fetch_page=fetch_page,
        process_page=process_page,
        make_new_page_params=next_page_params,
        initial_params=initial_params,
    )
209+
210+
140211
def _create_fetch_sys_attrs(
141212
attribute_names: List[str],
142213
make_record: Callable[[dict[str, Any]], T],
@@ -152,28 +223,17 @@ def fetch_sys_attrs(
152223
batch_size: int = env.NEPTUNE_QUERY_SYS_ATTRS_BATCH_SIZE.get(),
153224
container_type: ContainerType = default_container_type,
154225
) -> Generator[util.Page[T], None, None]:
155-
params: dict[str, Any] = {
156-
"attributeFilters": [{"path": attribute_name} for attribute_name in attribute_names],
157-
"pagination": {"limit": batch_size},
158-
"experimentLeader": container_type == ContainerType.EXPERIMENT,
159-
"sorting": {
160-
"dir": _map_direction(sort_direction),
161-
"sortBy": {"name": sort_by.name},
162-
},
163-
}
164-
if filter_ is not None:
165-
params["query"] = {"query": str(filter_)}
166-
if sort_by.aggregation is not None:
167-
params["sorting"]["aggregationMode"] = sort_by.aggregation
168-
if sort_by.type is not None:
169-
params["sorting"]["sortBy"]["type"] = map_attribute_type_python_to_backend(sort_by.type)
170-
171-
return util.fetch_pages(
226+
return _fetch_entries_with_projection(
172227
client=client,
173-
fetch_page=ft.partial(_fetch_sys_attrs_page, project_identifier=project_identifier),
228+
project_identifier=project_identifier,
229+
attribute_projection=attribute_names,
174230
process_page=ft.partial(_process_sys_attrs_page, make_record=make_record),
175-
make_new_page_params=ft.partial(_make_new_sys_attrs_page_params, batch_size=batch_size, limit=limit),
176-
initial_params=params,
231+
filter_=filter_,
232+
sort_by=sort_by,
233+
sort_direction=sort_direction,
234+
limit=limit,
235+
batch_size=batch_size,
236+
container_type=container_type,
177237
)
178238

179239
return fetch_sys_attrs
@@ -215,6 +275,42 @@ def fetch_sys_id_labels(container_type: ContainerType) -> FetchSysAttrs[SysIdLab
215275
fetch_sys_ids = fetch_experiment_sys_ids
216276

217277

278+
def fetch_table_rows_exact_attributes(
    *,
    client: AuthenticatedClient,
    project_identifier: identifiers.ProjectIdentifier,
    filter_: Optional[_Filter],
    requested_attribute_names: set[str],
    sort_by: _Attribute,
    sort_direction: Literal["asc", "desc"],
    limit: Optional[int],
    container_type: ContainerType,
) -> Generator[util.Page[TableSearchEntry], None, None]:
    """Yield pages of table entries holding values for an explicit attribute set.

    The search projection is the requested names plus ``"sys/id"`` and the
    label attribute (``"sys/name"`` for experiments, ``"sys/custom_run_id"``
    otherwise), so each entry can be identified and labelled. The page size
    comes from ``env.NEPTUNE_QUERY_SYS_ATTRS_BATCH_SIZE``.

    Args:
        client: Authenticated API client.
        project_identifier: Project to search in.
        filter_: Optional filter restricting the entries.
        requested_attribute_names: Attribute names whose values are returned.
            This set is not modified.
        sort_by: Attribute to sort by.
        sort_direction: Sort direction.
        limit: Optional limit forwarded to the underlying search.
        container_type: Selects the label attribute and is forwarded to the search.

    Yields:
        Pages of ``TableSearchEntry`` items.
    """
    batch_size = env.NEPTUNE_QUERY_SYS_ATTRS_BATCH_SIZE.get()

    if container_type == ContainerType.EXPERIMENT:
        label_attribute_name = "sys/name"
    else:
        label_attribute_name = "sys/custom_run_id"

    # Work on a copy so the caller's set is left untouched.
    projection = set(requested_attribute_names)
    projection |= {"sys/id", label_attribute_name}

    page_processor = ft.partial(
        _process_table_rows_exact_attributes_page,
        project_identifier=project_identifier,
        label_attribute_name=label_attribute_name,
        requested_attribute_names=requested_attribute_names,
    )

    yield from _fetch_entries_with_projection(
        client=client,
        project_identifier=project_identifier,
        attribute_projection=list(projection),
        process_page=page_processor,
        filter_=filter_,
        sort_by=sort_by,
        sort_direction=sort_direction,
        limit=limit,
        batch_size=batch_size,
        container_type=container_type,
    )
312+
313+
218314
def _fetch_sys_attrs_page(
219315
client: AuthenticatedClient,
220316
params: dict[str, Any],
@@ -249,6 +345,49 @@ def _process_sys_attrs_page(
249345
return util.Page(items=items)
250346

251347

348+
def _process_table_rows_exact_attributes_page(
    data: ProtoLeaderboardEntriesSearchResultDTO,
    project_identifier: identifiers.ProjectIdentifier,
    label_attribute_name: str,
    requested_attribute_names: set[str],
) -> util.Page[TableSearchEntry]:
    """Convert one raw entries-search result into a page of ``TableSearchEntry`` items.

    For every entry, the label and ``"sys/id"`` are read from the attributes
    that carry ``string_properties``. Values are then collected for the
    attributes named in ``requested_attribute_names``; an attribute is skipped
    when ``extract_value`` returns None for it.

    Args:
        data: Raw search result whose ``entries`` are processed.
        project_identifier: Project used to build each entry's run identifier.
        label_attribute_name: Name of the attribute that provides the label.
        requested_attribute_names: Names of the attributes to keep values for.

    Returns:
        A page with one ``TableSearchEntry`` per entry in ``data``.

    Raises:
        KeyError: If an entry has no string-valued label or ``"sys/id"`` attribute.
    """
    identity_attribute_names = ("sys/id", label_attribute_name)
    rows: list[TableSearchEntry] = []

    for entry in data.entries:
        # Gather the string-valued identity attributes; a later duplicate overrides an earlier one.
        identity_attrs = {}
        for attr in entry.attributes:
            if attr.name in identity_attribute_names and attr.HasField("string_properties"):
                identity_attrs[attr.name] = attr

        label = identity_attrs[label_attribute_name].string_properties.value
        sys_id = identifiers.SysId(identity_attrs["sys/id"].string_properties.value)
        run_identifier = identifiers.RunIdentifier(project_identifier=project_identifier, sys_id=sys_id)

        row_values: list[att_vals.AttributeValue] = []
        for attr in entry.attributes:
            if attr.name in requested_attribute_names:
                extracted = extract_value(attr)
                if extracted is not None:
                    definition = identifiers.AttributeDefinition(
                        name=attr.name,
                        type=map_attribute_type_backend_to_python(attr.type),
                    )
                    row_values.append(
                        att_vals.AttributeValue(
                            attribute_definition=definition,
                            value=extracted,
                            run_identifier=run_identifier,
                        )
                    )

        rows.append(TableSearchEntry(sys_id=sys_id, label=label, values=row_values))

    return util.Page(items=rows)
389+
390+
252391
def _make_new_sys_attrs_page_params(
253392
params: dict[str, Any],
254393
data: Optional[ProtoLeaderboardEntriesSearchResultDTO],

src/neptune_query/runs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def fetch_runs_table(
288288
sort_direction=sort_direction,
289289
limit=limit,
290290
type_suffix_in_column_names=type_suffix_in_column_names,
291+
exact_attribute_names=attributes if isinstance(attributes, list) else None,
291292
container_type=_search.ContainerType.RUN,
292293
)
293294

0 commit comments

Comments
 (0)