Skip to content

Commit cab9c4e

Browse files
committed
PY-180 Change type inference in neptune-query for better UX (#406)
* PY-180 Change type inference in neptune-query for better UX 1. If the type cannot be inferred because the attribute doesn't exist in the project -> show a warning and infer as string 2. If the type cannot be inferred because the attribute has multiple types across the runs and experiments of the project -> raise an exception prompting the user to specify the type explicitly * don't try to be smart and limit the attribute search to only runs/experiments matching some filters, etc * Emit inference warnings for list_containers and list_attributes. * Test warning is emmitted for missing attributes
1 parent 0263a8f commit cab9c4e

File tree

8 files changed

+237
-243
lines changed

8 files changed

+237
-243
lines changed

src/neptune_query/internal/composition/fetch_metrics.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,20 +84,10 @@ def fetch_metrics(
8484
client=client,
8585
project_identifier=project_identifier,
8686
filter_=filter_,
87-
executor=executor,
8887
fetch_attribute_definitions_executor=fetch_attribute_definitions_executor,
89-
container_type=container_type,
9088
)
91-
if inference_result.is_run_domain_empty():
92-
return create_metrics_dataframe(
93-
metrics_data={},
94-
sys_id_label_mapping={},
95-
index_column_name="experiment" if container_type == ContainerType.EXPERIMENT else "run",
96-
timestamp_column_name="absolute_time" if include_time == "absolute" else None,
97-
include_point_previews=include_point_previews,
98-
type_suffix_in_column_names=type_suffix_in_column_names,
99-
)
10089
inferred_filter = inference_result.get_result_or_raise()
90+
inference_result.emit_warnings()
10191

10292
metrics_data, sys_id_to_label_mapping = _fetch_metrics(
10393
filter_=inferred_filter,

src/neptune_query/internal/composition/fetch_series.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,19 +81,10 @@ def fetch_series(
8181
client=client,
8282
project_identifier=project_identifier,
8383
filter_=filter_,
84-
executor=executor,
8584
fetch_attribute_definitions_executor=fetch_attribute_definitions_executor,
86-
container_type=container_type,
8785
)
88-
if inference_result.is_run_domain_empty():
89-
return create_series_dataframe(
90-
series_data={},
91-
project_identifier=project_identifier,
92-
sys_id_label_mapping={},
93-
index_column_name="experiment" if container_type == ContainerType.EXPERIMENT else "run",
94-
timestamp_column_name="absolute_time" if include_time == "absolute" else None,
95-
)
9686
inferred_filter = inference_result.get_result_or_raise()
87+
inference_result.emit_warnings()
9788

9889
sys_id_label_mapping: dict[identifiers.SysId, str] = {}
9990

src/neptune_query/internal/composition/fetch_table.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -79,38 +79,19 @@ def fetch_table(
7979
client=client,
8080
project_identifier=project_identifier,
8181
filter_=filter_,
82-
executor=executor,
8382
fetch_attribute_definitions_executor=fetch_attribute_definitions_executor,
84-
container_type=container_type,
8583
)
86-
if inference_result.is_run_domain_empty():
87-
return output_format.convert_table_to_dataframe(
88-
table_data={},
89-
project_identifier=project_identifier,
90-
selected_aggregations={},
91-
type_suffix_in_column_names=type_suffix_in_column_names,
92-
index_column_name="experiment" if container_type == search.ContainerType.EXPERIMENT else "run",
93-
)
9484
filter_ = inference_result.get_result_or_raise()
85+
inference_result.emit_warnings()
9586

9687
sort_by_inference_result = type_inference.infer_attribute_types_in_sort_by(
9788
client=client,
9889
project_identifier=project_identifier,
99-
filter_=filter_,
10090
sort_by=sort_by,
101-
executor=executor,
10291
fetch_attribute_definitions_executor=fetch_attribute_definitions_executor,
103-
container_type=container_type,
10492
)
105-
if sort_by_inference_result.is_run_domain_empty():
106-
return output_format.convert_table_to_dataframe(
107-
table_data={},
108-
project_identifier=project_identifier,
109-
selected_aggregations={},
110-
type_suffix_in_column_names=type_suffix_in_column_names,
111-
index_column_name="experiment" if container_type == search.ContainerType.EXPERIMENT else "run",
112-
)
113-
sort_by = sort_by_inference_result.get_result_or_raise()
93+
sort_by = sort_by_inference_result.result
94+
sort_by_inference_result.emit_warnings()
11495

11596
sys_id_label_mapping: dict[identifiers.SysId, str] = {}
11697
result_by_id: dict[identifiers.SysId, list[att_vals.AttributeValue]] = {}
@@ -184,6 +165,7 @@ def go_fetch_sys_attrs() -> Generator[list[identifiers.SysId], None, None]:
184165
index_column_name="experiment" if container_type == search.ContainerType.EXPERIMENT else "run",
185166
flatten_aggregations=flatten_aggregations,
186167
)
168+
187169
return dataframe
188170

189171

src/neptune_query/internal/composition/list_attributes.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,10 @@ def list_attributes(
6262
client,
6363
project_identifier,
6464
filter_,
65-
executor=executor,
6665
fetch_attribute_definitions_executor=fetch_attribute_definitions_executor,
67-
container_type=container_type,
6866
)
69-
if inference_result.is_run_domain_empty():
70-
return []
7167
filter_ = inference_result.get_result_or_raise()
68+
inference_result.emit_warnings()
7269

7370
output = _components.fetch_attribute_definitions_complete(
7471
client=client,

src/neptune_query/internal/composition/list_containers.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,15 @@ def list_containers(
3737
validated_context = _context.validate_context(context or _context.get_context())
3838
client = _client.get_client(context=validated_context)
3939

40-
with (
41-
concurrency.create_thread_pool_executor() as executor,
42-
concurrency.create_thread_pool_executor() as fetch_attribute_definitions_executor,
43-
):
40+
with concurrency.create_thread_pool_executor() as fetch_attribute_definitions_executor:
4441
inference_result = type_inference.infer_attribute_types_in_filter(
4542
client=client,
4643
project_identifier=project_identifier,
4744
filter_=filter_,
48-
executor=executor,
4945
fetch_attribute_definitions_executor=fetch_attribute_definitions_executor,
5046
)
51-
if inference_result.is_run_domain_empty():
52-
return []
5347
filter_ = inference_result.get_result_or_raise()
48+
inference_result.emit_warnings()
5449

5550
sys_attr_pages = search.fetch_sys_id_labels(container_type)(client, project_identifier, filter_)
5651
return list(sorted(attrs.label for page in sys_attr_pages for attrs in page.items))

0 commit comments

Comments
 (0)