4343 _Filter ,
4444)
4545from ..logger import get_logger
46+ from ..retrieval import attribute_values as att_vals
4647from ..retrieval import (
4748 retry ,
4849 util ,
4950)
50- from ..retrieval .attribute_types import map_attribute_type_python_to_backend
51+ from ..retrieval .attribute_types import (
52+ extract_value ,
53+ map_attribute_type_backend_to_python ,
54+ map_attribute_type_python_to_backend ,
55+ )
5156
5257logger = get_logger ()
5358
@@ -77,6 +82,13 @@ def sys_id(self) -> identifiers.SysId: ...
7782 def label (self ) -> str : ...
7883
7984
85+ @dataclass (frozen = True )
86+ class TableSearchEntry :
87+ sys_id : identifiers .SysId
88+ label : str
89+ values : list [att_vals .AttributeValue ]
90+
91+
8092@dataclass (frozen = True )
8193class ExperimentSysAttrs :
8294 sys_id : identifiers .SysId
@@ -137,6 +149,65 @@ def __call__(
137149 ) -> Generator [util .Page [T ], None , None ]: ...
138150
139151
152+ def _build_entries_search_params (
153+ * ,
154+ attribute_projection : list [str ],
155+ batch_size : int ,
156+ container_type : ContainerType ,
157+ filter_ : Optional [_Filter ],
158+ sort_by : _Attribute ,
159+ sort_direction : Literal ["asc" , "desc" ],
160+ ) -> dict [str , Any ]:
161+ params : dict [str , Any ] = {
162+ "attributeFilters" : [{"path" : attribute_name } for attribute_name in attribute_projection ],
163+ "pagination" : {"limit" : batch_size },
164+ "experimentLeader" : container_type == ContainerType .EXPERIMENT ,
165+ "sorting" : {
166+ "dir" : _map_direction (sort_direction ),
167+ "sortBy" : {"name" : sort_by .name },
168+ },
169+ }
170+ if filter_ is not None :
171+ params ["query" ] = {"query" : str (filter_ )}
172+ if sort_by .aggregation is not None :
173+ params ["sorting" ]["aggregationMode" ] = sort_by .aggregation
174+ if sort_by .type is not None :
175+ params ["sorting" ]["sortBy" ]["type" ] = map_attribute_type_python_to_backend (sort_by .type )
176+
177+ return params
178+
179+
180+ def _fetch_entries_with_projection (
181+ * ,
182+ client : AuthenticatedClient ,
183+ project_identifier : identifiers .ProjectIdentifier ,
184+ attribute_projection : list [str ],
185+ process_page : Callable [[ProtoLeaderboardEntriesSearchResultDTO ], util .Page [T ]],
186+ filter_ : Optional [_Filter ],
187+ sort_by : _Attribute ,
188+ sort_direction : Literal ["asc" , "desc" ],
189+ limit : Optional [int ],
190+ batch_size : int ,
191+ container_type : ContainerType ,
192+ ) -> Generator [util .Page [T ], None , None ]:
193+ params = _build_entries_search_params (
194+ attribute_projection = attribute_projection ,
195+ batch_size = batch_size ,
196+ container_type = container_type ,
197+ filter_ = filter_ ,
198+ sort_by = sort_by ,
199+ sort_direction = sort_direction ,
200+ )
201+
202+ return util .fetch_pages (
203+ client = client ,
204+ fetch_page = ft .partial (_fetch_sys_attrs_page , project_identifier = project_identifier ),
205+ process_page = process_page ,
206+ make_new_page_params = ft .partial (_make_new_sys_attrs_page_params , batch_size = batch_size , limit = limit ),
207+ initial_params = params ,
208+ )
209+
210+
140211def _create_fetch_sys_attrs (
141212 attribute_names : List [str ],
142213 make_record : Callable [[dict [str , Any ]], T ],
@@ -152,28 +223,17 @@ def fetch_sys_attrs(
152223 batch_size : int = env .NEPTUNE_QUERY_SYS_ATTRS_BATCH_SIZE .get (),
153224 container_type : ContainerType = default_container_type ,
154225 ) -> Generator [util .Page [T ], None , None ]:
155- params : dict [str , Any ] = {
156- "attributeFilters" : [{"path" : attribute_name } for attribute_name in attribute_names ],
157- "pagination" : {"limit" : batch_size },
158- "experimentLeader" : container_type == ContainerType .EXPERIMENT ,
159- "sorting" : {
160- "dir" : _map_direction (sort_direction ),
161- "sortBy" : {"name" : sort_by .name },
162- },
163- }
164- if filter_ is not None :
165- params ["query" ] = {"query" : str (filter_ )}
166- if sort_by .aggregation is not None :
167- params ["sorting" ]["aggregationMode" ] = sort_by .aggregation
168- if sort_by .type is not None :
169- params ["sorting" ]["sortBy" ]["type" ] = map_attribute_type_python_to_backend (sort_by .type )
170-
171- return util .fetch_pages (
226+ return _fetch_entries_with_projection (
172227 client = client ,
173- fetch_page = ft .partial (_fetch_sys_attrs_page , project_identifier = project_identifier ),
228+ project_identifier = project_identifier ,
229+ attribute_projection = attribute_names ,
174230 process_page = ft .partial (_process_sys_attrs_page , make_record = make_record ),
175- make_new_page_params = ft .partial (_make_new_sys_attrs_page_params , batch_size = batch_size , limit = limit ),
176- initial_params = params ,
231+ filter_ = filter_ ,
232+ sort_by = sort_by ,
233+ sort_direction = sort_direction ,
234+ limit = limit ,
235+ batch_size = batch_size ,
236+ container_type = container_type ,
177237 )
178238
179239 return fetch_sys_attrs
@@ -215,6 +275,42 @@ def fetch_sys_id_labels(container_type: ContainerType) -> FetchSysAttrs[SysIdLab
215275fetch_sys_ids = fetch_experiment_sys_ids
216276
217277
278+ def fetch_table_rows_exact_attributes (
279+ * ,
280+ client : AuthenticatedClient ,
281+ project_identifier : identifiers .ProjectIdentifier ,
282+ filter_ : Optional [_Filter ],
283+ requested_attribute_names : set [str ],
284+ sort_by : _Attribute ,
285+ sort_direction : Literal ["asc" , "desc" ],
286+ limit : Optional [int ],
287+ container_type : ContainerType ,
288+ ) -> Generator [util .Page [TableSearchEntry ], None , None ]:
289+ batch_size = env .NEPTUNE_QUERY_SYS_ATTRS_BATCH_SIZE .get ()
290+
291+ label_attribute_name = "sys/name" if container_type == ContainerType .EXPERIMENT else "sys/custom_run_id"
292+ projection_attribute_names = set (requested_attribute_names )
293+ projection_attribute_names .update ({"sys/id" , label_attribute_name })
294+
295+ yield from _fetch_entries_with_projection (
296+ client = client ,
297+ project_identifier = project_identifier ,
298+ attribute_projection = list (projection_attribute_names ),
299+ process_page = ft .partial (
300+ _process_table_rows_exact_attributes_page ,
301+ project_identifier = project_identifier ,
302+ label_attribute_name = label_attribute_name ,
303+ requested_attribute_names = requested_attribute_names ,
304+ ),
305+ filter_ = filter_ ,
306+ sort_by = sort_by ,
307+ sort_direction = sort_direction ,
308+ limit = limit ,
309+ batch_size = batch_size ,
310+ container_type = container_type ,
311+ )
312+
313+
218314def _fetch_sys_attrs_page (
219315 client : AuthenticatedClient ,
220316 params : dict [str , Any ],
@@ -249,6 +345,49 @@ def _process_sys_attrs_page(
249345 return util .Page (items = items )
250346
251347
348+ def _process_table_rows_exact_attributes_page (
349+ data : ProtoLeaderboardEntriesSearchResultDTO ,
350+ project_identifier : identifiers .ProjectIdentifier ,
351+ label_attribute_name : str ,
352+ requested_attribute_names : set [str ],
353+ ) -> util .Page [TableSearchEntry ]:
354+ items : list [TableSearchEntry ] = []
355+
356+ for entry in data .entries :
357+ attributes_by_name = {
358+ attr .name : attr
359+ for attr in entry .attributes
360+ if attr .name in ("sys/id" , label_attribute_name ) and attr .HasField ("string_properties" )
361+ }
362+ label = attributes_by_name [label_attribute_name ].string_properties .value
363+ sys_id = identifiers .SysId (attributes_by_name ["sys/id" ].string_properties .value )
364+ run_identifier = identifiers .RunIdentifier (project_identifier = project_identifier , sys_id = sys_id )
365+
366+ values : list [att_vals .AttributeValue ] = []
367+ for attr in entry .attributes :
368+ if attr .name not in requested_attribute_names :
369+ continue
370+
371+ item_value = extract_value (attr )
372+ if item_value is None :
373+ continue
374+
375+ values .append (
376+ att_vals .AttributeValue (
377+ attribute_definition = identifiers .AttributeDefinition (
378+ name = attr .name ,
379+ type = map_attribute_type_backend_to_python (attr .type ),
380+ ),
381+ value = item_value ,
382+ run_identifier = run_identifier ,
383+ )
384+ )
385+
386+ items .append (TableSearchEntry (sys_id = sys_id , label = label , values = values ))
387+
388+ return util .Page (items = items )
389+
390+
252391def _make_new_sys_attrs_page_params (
253392 params : dict [str , Any ],
254393 data : Optional [ProtoLeaderboardEntriesSearchResultDTO ],
0 commit comments