Skip to content

Commit 0aea085

Browse files
authored
fix: fix Index contains duplicate entries, cannot reshape (#15)
* fix: fix Index contains duplicate entries, cannot reshape
* fix: expand the test
* fix: add a test for metrics
* fix: set observed=True

---------

Co-authored-by: Michał Sośnicki <michal.sosnicki@neptune.ai>
1 parent 0f974ce commit 0aea085

File tree

2 files changed

+109
-4
lines changed

2 files changed

+109
-4
lines changed

src/neptune_query/internal/output_format.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -412,10 +412,12 @@ def _pivot_and_reindex_df(
412412

413413
if include_point_previews or timestamp_column_name:
414414
# if there are multiple value columns, don't specify them and rely on pandas to create the column multi-index
415-
df = df.pivot(index=[index_column_name, "step"], columns="path")
415+
df = df.pivot_table(index=[index_column_name, "step"], columns="path", aggfunc="first", observed=True)
416416
else:
417417
# when there's only "value", define values explicitly, to make pandas generate a flat index
418-
df = df.pivot(index=[index_column_name, "step"], columns="path", values="value")
418+
df = df.pivot_table(
419+
index=[index_column_name, "step"], columns="path", values="value", aggfunc="first", observed=True
420+
)
419421

420422
df = df.reset_index()
421423
df[index_column_name] = df[index_column_name].astype(str)

tests/unit/internal/test_output_format.py

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,35 @@
55
timedelta,
66
timezone,
77
)
8+
from unittest.mock import patch
89

910
import numpy as np
1011
import pandas as pd
1112
import pytest
1213
from pandas._testing import assert_frame_equal
1314

15+
import neptune_query as npt
1416
from neptune_query.exceptions import ConflictingAttributeTypes
15-
from neptune_query.internal import identifiers
17+
from neptune_query.filters import AttributeFilter
18+
from neptune_query.internal import (
19+
context,
20+
identifiers,
21+
)
1622
from neptune_query.internal.identifiers import (
1723
AttributeDefinition,
1824
ProjectIdentifier,
1925
RunAttributeDefinition,
2026
RunIdentifier,
2127
SysId,
28+
SysName,
2229
)
2330
from neptune_query.internal.output_format import (
2431
convert_table_to_dataframe,
2532
create_files_dataframe,
2633
create_metrics_dataframe,
2734
create_series_dataframe,
2835
)
36+
from neptune_query.internal.retrieval import util
2937
from neptune_query.internal.retrieval.attribute_types import File as IFile
3038
from neptune_query.internal.retrieval.attribute_types import (
3139
FileSeriesAggregations,
@@ -38,7 +46,10 @@
3846
)
3947
from neptune_query.internal.retrieval.attribute_values import AttributeValue
4048
from neptune_query.internal.retrieval.metrics import FloatPointValue
41-
from neptune_query.internal.retrieval.search import ContainerType
49+
from neptune_query.internal.retrieval.search import (
50+
ContainerType,
51+
ExperimentSysAttrs,
52+
)
4253
from neptune_query.internal.retrieval.series import SeriesValue
4354
from neptune_query.types import File as OFile
4455
from neptune_query.types import Histogram as OHistogram
@@ -1195,3 +1206,95 @@ def test_create_files_dataframe_index_name_attribute_conflict():
11951206
expected_df.columns.names = ["attribute"]
11961207
expected_df.index.names = [index_column_name, "step"]
11971208
assert_frame_equal(dataframe, expected_df)
1209+
1210+
1211+
@pytest.mark.parametrize("duplicate_variant", [(2, 1, 1), (1, 2, 1), (1, 1, 2), (2, 2, 2)])
1212+
@pytest.mark.parametrize("include_time", [None, "absolute"])
1213+
def test_fetch_series_duplicate_values(duplicate_variant, include_time):
1214+
# given
1215+
project = ProjectIdentifier("project")
1216+
context.set_api_token("irrelevant")
1217+
experiments = [ExperimentSysAttrs(sys_id=SysId("sysid0"), sys_name=SysName("irrelevant"))]
1218+
attributes = [AttributeDefinition(name="attribute0", type="irrelevant")]
1219+
run_attribute_definitions = [
1220+
RunAttributeDefinition(
1221+
run_identifier=RunIdentifier(project_identifier=project, sys_id=experiments[0].sys_id),
1222+
attribute_definition=attributes[0],
1223+
)
1224+
]
1225+
1226+
duped_values, duped_attributes, duped_pages = duplicate_variant
1227+
series_values = [
1228+
(
1229+
run_attribute_definitions[0],
1230+
[SeriesValue(step=i, value=f"{i}", timestamp_millis=i) for i in range(100)] * duped_values,
1231+
)
1232+
] * duped_attributes
1233+
1234+
# when
1235+
with (
1236+
patch("neptune_query.internal.composition.fetch_series.get_client") as get_client,
1237+
patch("neptune_query.internal.retrieval.search.fetch_experiment_sys_attrs") as fetch_experiment_sys_attrs,
1238+
patch(
1239+
"neptune_query.internal.retrieval.attribute_definitions.fetch_attribute_definitions_single_filter"
1240+
) as fetch_attribute_definitions_single_filter,
1241+
patch("neptune_query.internal.retrieval.series.fetch_series_values") as fetch_series_values,
1242+
):
1243+
get_client.return_value = None
1244+
fetch_experiment_sys_attrs.return_value = iter([util.Page(experiments)])
1245+
fetch_attribute_definitions_single_filter.side_effect = lambda **kwargs: iter([util.Page(attributes)])
1246+
fetch_series_values.return_value = iter([util.Page(series_values)] * duped_pages)
1247+
1248+
df = npt.fetch_series(
1249+
project=project,
1250+
experiments="ignored",
1251+
attributes=AttributeFilter(name="ignored"),
1252+
include_time=include_time,
1253+
)
1254+
1255+
# then
1256+
assert df.shape == (100, 1 if not include_time else 2)
1257+
1258+
1259+
@pytest.mark.parametrize("include_time", [None, "absolute"])
1260+
def test_fetch_metrics_duplicate_values(include_time):
1261+
# given
1262+
project = ProjectIdentifier("project")
1263+
context.set_api_token("irrelevant")
1264+
experiments = [ExperimentSysAttrs(sys_id=SysId("sysid0"), sys_name=SysName("irrelevant"))]
1265+
attributes = [AttributeDefinition(name="attribute0", type="float_series")]
1266+
run_attribute_definitions = [
1267+
RunAttributeDefinition(
1268+
run_identifier=RunIdentifier(project_identifier=project, sys_id=experiments[0].sys_id),
1269+
attribute_definition=attributes[0],
1270+
)
1271+
]
1272+
series_values = {
1273+
run_attribute_definitions[0]: [SeriesValue(step=i, value=float(i), timestamp_millis=i) for i in range(100)] * 2
1274+
}
1275+
1276+
# when
1277+
with (
1278+
patch("neptune_query.internal.composition.fetch_metrics.get_client") as get_client,
1279+
patch("neptune_query.internal.retrieval.search.fetch_experiment_sys_attrs") as fetch_experiment_sys_attrs,
1280+
patch(
1281+
"neptune_query.internal.retrieval.attribute_definitions.fetch_attribute_definitions_single_filter"
1282+
) as fetch_attribute_definitions_single_filter,
1283+
patch(
1284+
"neptune_query.internal.composition.fetch_metrics.fetch_multiple_series_values"
1285+
) as fetch_multiple_series_values,
1286+
):
1287+
get_client.return_value = None
1288+
fetch_experiment_sys_attrs.return_value = iter([util.Page(experiments)])
1289+
fetch_attribute_definitions_single_filter.side_effect = lambda **kwargs: iter([util.Page(attributes)])
1290+
fetch_multiple_series_values.return_value = series_values
1291+
1292+
df = npt.fetch_metrics(
1293+
project=project,
1294+
experiments="ignored",
1295+
attributes=AttributeFilter(name="ignored"),
1296+
include_time=include_time,
1297+
)
1298+
1299+
# then
1300+
assert df.shape == (100, 1 if not include_time else 2)

0 commit comments

Comments (0)