Skip to content

Commit 7a0859a

Browse files
authored
fix: fixes from fuzzy tests (#72)
* fix: fix problems found in fuzzy tests * fix: add more test cases * fix: assert column order in output_format tests * fix: assert column order in fetch_metric tests * fix: assert column order in fetch_series tests --------- Co-authored-by: Michał Sośnicki <michal.sosnicki@neptune.ai>
1 parent cad420d commit 7a0859a

File tree

4 files changed

+286
-38
lines changed

4 files changed

+286
-38
lines changed

src/neptune_query/internal/output_format.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ def generate_categorized_rows() -> Generator[Tuple, None, None]:
291291
]
292292

293293
if timestamp_column_name:
294-
types.append((timestamp_column_name, "uint64"))
294+
types.append((timestamp_column_name, "int64"))
295295

296296
if include_point_previews:
297297
types.append(("is_preview", "bool"))
@@ -304,7 +304,7 @@ def generate_categorized_rows() -> Generator[Tuple, None, None]:
304304
if timestamp_column_name:
305305
df[timestamp_column_name] = pd.to_datetime(df[timestamp_column_name], unit="ms", origin="unix", utc=True)
306306

307-
df = _pivot_df(df, include_point_previews, index_column_name, timestamp_column_name)
307+
df = _pivot_df(df, index_column_name, timestamp_column_name, extra_value_columns=types[4:])
308308
df = _restore_labels_in_index(df, index_column_name, label_mapping)
309309
df = _restore_path_column_names(df, path_mapping, "float_series" if type_suffix_in_column_names else None)
310310
df = _sort_index_and_columns(df, index_column_name)
@@ -383,7 +383,7 @@ def generate_categorized_rows() -> Generator[Tuple, None, None]:
383383
("value", "object"),
384384
]
385385
if timestamp_column_name:
386-
types.append((timestamp_column_name, "uint64"))
386+
types.append((timestamp_column_name, "int64"))
387387

388388
df = pd.DataFrame(
389389
np.fromiter(generate_categorized_rows(), dtype=types),
@@ -392,7 +392,7 @@ def generate_categorized_rows() -> Generator[Tuple, None, None]:
392392
if timestamp_column_name:
393393
df[timestamp_column_name] = pd.to_datetime(df[timestamp_column_name], unit="ms", origin="unix", utc=True)
394394

395-
df = _pivot_df(df, False, index_column_name, timestamp_column_name)
395+
df = _pivot_df(df, index_column_name, timestamp_column_name, extra_value_columns=types[4:])
396396
df = _restore_labels_in_index(df, index_column_name, label_mapping)
397397
df = _restore_path_column_names(df, path_mapping, None)
398398
df = _sort_index_and_columns(df, index_column_name)
@@ -468,24 +468,31 @@ def generate_categorized_rows() -> Generator[Tuple, None, None]:
468468
columns=[container_column_name, "path"],
469469
values=["x", "y"],
470470
observed=True,
471-
dropna=False,
471+
dropna=True,
472472
sort=False,
473473
)
474474

475475
df = _restore_labels_in_columns(df, container_column_name, label_mapping)
476476
df = _restore_path_column_names(df, path_mapping, None)
477477

478-
# Clear out any columns that were not requested, but got added because of dropna=False
479-
desired_columns = [
480-
(
481-
dim,
482-
sys_id_label_mapping[run_attr_definition.run_identifier.sys_id],
483-
run_attr_definition.attribute_definition.name,
478+
# Add back any columns that were removed because they were all NaN
479+
if buckets_data:
480+
desired_columns = pd.MultiIndex.from_tuples(
481+
[
482+
(
483+
dim,
484+
sys_id_label_mapping[run_attr_definition.run_identifier.sys_id],
485+
run_attr_definition.attribute_definition.name,
486+
)
487+
for run_attr_definition in buckets_data.keys()
488+
for dim in ("x", "y")
489+
],
490+
names=["bucket", container_column_name, "metric"],
484491
)
485-
for run_attr_definition in buckets_data.keys()
486-
for dim in ("x", "y")
487-
]
488-
df = df.filter(desired_columns, axis="columns")
492+
df = df.reindex(columns=desired_columns)
493+
else:
494+
# Handle empty case - create expected column structure
495+
df.columns = pd.MultiIndex.from_product([["x", "y"], [], []], names=["bucket", container_column_name, "metric"])
489496

490497
df = df.reorder_levels([1, 2, 0], axis="columns")
491498
df = df.sort_index(axis="columns", level=[0, 1])
@@ -500,8 +507,9 @@ def generate_categorized_rows() -> Generator[Tuple, None, None]:
500507

501508
def _collapse_open_buckets(df: pd.DataFrame) -> pd.DataFrame:
502509
"""
503-
1st returned bucket is always (-inf, first_point], which we merge with the 2nd bucket (first_point, end],
510+
1st returned bucket is (-inf, first_point], which we merge with the 2nd bucket (first_point, end],
504511
resulting in a new bucket [first_point, end].
512+
If there's only one bucket, it should have form (first_point, inf). We transform it to [first_point, first_point].
505513
"""
506514
df.index = df.index.astype(object) # IntervalIndex cannot mix Intervals closed from different sides
507515

@@ -541,13 +549,12 @@ def _collapse_open_buckets(df: pd.DataFrame) -> pd.DataFrame:
541549

542550
def _pivot_df(
543551
df: pd.DataFrame,
544-
include_point_previews: bool,
545552
index_column_name: str,
546553
timestamp_column_name: Optional[str],
554+
extra_value_columns: list[tuple[str, str]],
547555
) -> pd.DataFrame:
548556
# Holds all existing (experiment, step) pairs
549-
# This is needed because pivot_table will add rows for all combinations of (experiment, step)
550-
# even if they don't exist in the original data, filling the rows with NaNs.
557+
# This is needed because pivot_table will remove rows if they are all NaN
551558
observed_idx = pd.MultiIndex.from_frame(
552559
df[[index_column_name, "step"]]
553560
.astype(
@@ -563,12 +570,21 @@ def _pivot_df(
563570
# Handle empty DataFrame case to avoid pandas dtype errors
564571
df[timestamp_column_name] = pd.Series(dtype="datetime64[ns]")
565572

566-
if include_point_previews or timestamp_column_name:
573+
if extra_value_columns:
574+
# Holds all existing columns
575+
# This is needed because pivot_table will remove columns if they are all NaN
576+
value_columns = ["value"] + [col[0] for col in extra_value_columns]
577+
observed_columns = pd.MultiIndex.from_tuples(
578+
[(value, path) for path in df["path"].unique() for value in value_columns], names=[None, "path"]
579+
)
580+
567581
# if there are multiple value columns, don't specify them and rely on pandas to create the column multi-index
568582
df = df.pivot_table(
569583
index=[index_column_name, "step"], columns="path", aggfunc="first", observed=True, dropna=True, sort=False
570584
)
571585
else:
586+
observed_columns = df["path"].unique()
587+
572588
# when there's only "value", define values explicitly, to make pandas generate a flat index
573589
df = df.pivot_table(
574590
index=[index_column_name, "step"],
@@ -580,8 +596,8 @@ def _pivot_df(
580596
sort=False,
581597
)
582598

583-
# Include only observed (experiment, step) pairs
584-
return df.reindex(observed_idx)
599+
# Add back any columns that were removed because they were all NaN
600+
return df.reindex(index=observed_idx, columns=observed_columns)
585601

586602

587603
def _restore_labels_in_index(
@@ -637,7 +653,7 @@ def _sort_index_and_columns(df: pd.DataFrame, index_column_name: str) -> pd.Data
637653
if isinstance(df.columns, pd.MultiIndex):
638654
df.columns.names = (None, None)
639655
df = df.swaplevel(axis="columns")
640-
df = df.sort_index(axis="columns", level=0)
656+
df = df.sort_index(axis="columns", level=0, kind="stable", sort_remaining=False)
641657
else:
642658
df.columns.name = None
643659
df = df.sort_index(axis="columns")

tests/e2e/v1/test_fetch_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def create_expected_data(
156156

157157
sorted_columns = list(sorted(columns))
158158
if include_time == "absolute":
159-
absolute_columns = [[(c, "absolute_time"), (c, "value")] for c in sorted_columns]
159+
absolute_columns = [[(c, "value"), (c, "absolute_time")] for c in sorted_columns]
160160
return df, list(chain.from_iterable(absolute_columns)), filtered_experiments
161161
else:
162162
return df, sorted_columns, filtered_experiments

tests/e2e/v1/test_fetch_series.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def create_expected_data_string_series(
104104

105105
sorted_columns = list(sorted(columns))
106106
if include_time == "absolute":
107-
absolute_columns = [[(c, "absolute_time"), (c, "value")] for c in sorted_columns]
107+
absolute_columns = [[(c, "value"), (c, "absolute_time")] for c in sorted_columns]
108108
return df, list(it.chain.from_iterable(absolute_columns)), filtered_exps
109109
else:
110110
return df, sorted_columns, filtered_exps
@@ -332,7 +332,7 @@ def create_expected_data_histogram_series(
332332

333333
sorted_columns = list(sorted(columns))
334334
if include_time == "absolute":
335-
absolute_columns = [[(c, "absolute_time"), (c, "value")] for c in sorted_columns]
335+
absolute_columns = [[(c, "value"), (c, "absolute_time")] for c in sorted_columns]
336336
return df, list(it.chain.from_iterable(absolute_columns)), filtered_exps
337337
else:
338338
return df, sorted_columns, filtered_exps
@@ -552,7 +552,7 @@ def create_expected_data_file_series(
552552

553553
sorted_columns = list(sorted(columns))
554554
if include_time == "absolute":
555-
absolute_columns = [[(c, "absolute_time"), (c, "value")] for c in sorted_columns]
555+
absolute_columns = [[(c, "value"), (c, "absolute_time")] for c in sorted_columns]
556556
return df, list(it.chain.from_iterable(absolute_columns)), filtered_exps
557557
else:
558558
return df, sorted_columns, filtered_exps

0 commit comments

Comments (0)