Skip to content

Commit 21fac4e

Browse files
committed
Fix output formatting logic
1 parent 25c75d1 commit 21fac4e

File tree

2 files changed

+42
-34
lines changed

2 files changed

+42
-34
lines changed

src/neptune_query/internal/output_format.py

Lines changed: 25 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -406,14 +406,28 @@ def _pivot_and_reindex_df(
406406
index_column_name: str,
407407
timestamp_column_name: Optional[str],
408408
) -> pd.DataFrame:
409+
# Holds all existing (experiment, step) pairs
410+
# This is needed because pivot_table will add rows for all combinations of (experiment, step)
411+
# even if they don't exist in the original data, filling the rows with NaNs.
412+
observed_idx = pd.MultiIndex.from_frame(
413+
df[[index_column_name, "step"]]
414+
.astype(
415+
{
416+
index_column_name: "category",
417+
"step": "float64",
418+
}
419+
)
420+
.drop_duplicates()
421+
)
422+
409423
if df.empty and timestamp_column_name:
410424
# Handle empty DataFrame case to avoid pandas dtype errors
411425
df[timestamp_column_name] = pd.Series(dtype="datetime64[ns]")
412426

413427
if include_point_previews or timestamp_column_name:
414428
# if there are multiple value columns, don't specify them and rely on pandas to create the column multi-index
415429
df = df.pivot_table(
416-
index=[index_column_name, "step"], columns="path", aggfunc="first", observed=True, dropna=False
430+
index=[index_column_name, "step"], columns="path", aggfunc="first", observed=True, dropna=False, sort=False
417431
)
418432
else:
419433
# when there's only "value", define values explicitly, to make pandas generate a flat index
@@ -424,14 +438,19 @@ def _pivot_and_reindex_df(
424438
aggfunc="first",
425439
observed=True,
426440
dropna=False,
441+
sort=False,
427442
)
428443

429-
df = df.reset_index()
430-
df[index_column_name] = df[index_column_name].astype(str)
431-
df = df.sort_values(by=[index_column_name, "step"], ignore_index=True)
432-
df = df.set_index([index_column_name, "step"])
444+
# Include only observed (experiment, step) pairs
445+
df = df.reindex(index=observed_idx)
433446

434-
return df
447+
# Replace categorical codes in `index_column_name` with strings
448+
df.index = df.index.set_levels(
449+
df.index.get_level_values(index_column_name).unique().astype(str),
450+
level=index_column_name,
451+
)
452+
453+
return df.sort_index(level=[index_column_name, "step"])
435454

436455

437456
def _restore_path_column_names(

tests/unit/internal/test_output_format.py

Lines changed: 17 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -911,8 +911,6 @@ def test_create_metrics_dataframe_random_order():
911911
def test_create_empty_metrics_dataframe(
912912
type_suffix_in_column_names: bool, include_preview: bool, timestamp_column_name: str
913913
):
914-
# Given empty dataframe
915-
916914
# When
917915
df = create_metrics_dataframe(
918916
metrics_data={},
@@ -924,21 +922,17 @@ def test_create_empty_metrics_dataframe(
924922
)
925923

926924
# Then
925+
expected_df = (
926+
pd.DataFrame(data={"experiment": [], "step": []})
927+
.astype(dtype={"experiment": "object", "step": "float64"})
928+
.set_index(["experiment", "step"])
929+
)
930+
931+
# With previews or timestamps, MultiIndex columns are returned
927932
if include_preview or timestamp_column_name:
928-
expected_df = pd.DataFrame(
929-
index=pd.MultiIndex.from_tuples([], names=["experiment", "step"]),
930-
columns=pd.MultiIndex.from_tuples([], names=["path", "metric"]), # Create empty MultiIndex for columns
931-
)
932-
expected_df.columns.names = None, None
933-
else:
934-
expected_df = pd.DataFrame(
935-
{
936-
"experiment": [],
937-
"step": [],
938-
}
939-
).set_index(["experiment", "step"])
933+
expected_df.columns = pd.MultiIndex.from_tuples([], names=[None, None])
940934

941-
pd.testing.assert_frame_equal(df, expected_df, check_index_type=False)
935+
pd.testing.assert_frame_equal(df, expected_df, check_column_type=False)
942936

943937

944938
@pytest.mark.parametrize("timestamp_column_name", [None, "absolute"])
@@ -955,21 +949,16 @@ def test_create_empty_series_dataframe(timestamp_column_name: str):
955949
)
956950

957951
# Then
952+
expected_df = (
953+
pd.DataFrame(data={"experiment": [], "step": []})
954+
.astype(dtype={"experiment": "object", "step": "float64"})
955+
.set_index(["experiment", "step"])
956+
)
957+
958958
if timestamp_column_name:
959-
expected_df = pd.DataFrame(
960-
index=pd.MultiIndex.from_tuples([], names=["experiment", "step"]),
961-
columns=pd.MultiIndex.from_tuples([], names=["path", "metric"]), # Create empty MultiIndex for columns
962-
)
963-
expected_df.columns.names = None, None
964-
else:
965-
expected_df = pd.DataFrame(
966-
{
967-
"experiment": [],
968-
"step": [],
969-
}
970-
).set_index(["experiment", "step"])
959+
expected_df.columns = pd.MultiIndex.from_tuples([], names=[None, None])
971960

972-
pd.testing.assert_frame_equal(df, expected_df, check_index_type=False)
961+
pd.testing.assert_frame_equal(df, expected_df, check_index_type=False, check_column_type=False)
973962

974963

975964
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)