Skip to content

Commit 1bd2f7e

Browse files
authored
fix: lineage to the root respects experiment (#14)
* fix: lineage to the root respects experiment
* Add tests
* Fix unit test
* Fix import
* Add fix for fetch_series + test
* Import fix
* Improve series test
* Improve metrics test
* Fix tests
1 parent 0aea085 commit 1bd2f7e

File tree

10 files changed

+208
-22
lines changed

10 files changed

+208
-22
lines changed

src/neptune_query/internal/composition/fetch_metrics.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def go_fetch_sys_attrs() -> Generator[list[identifiers.SysId], None, None]:
171171
run_attribute_definitions=run_attribute_definitions_split,
172172
include_inherited=lineage_to_the_root,
173173
include_preview=include_point_previews,
174+
container_type=container_type,
174175
step_range=step_range,
175176
tail_limit=tail_limit,
176177
)

src/neptune_query/internal/composition/fetch_series.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def go_fetch_sys_attrs() -> Generator[list[identifiers.SysId], None, None]:
127127
client=client,
128128
run_attribute_definitions=run_attribute_definitions_split,
129129
include_inherited=lineage_to_the_root,
130+
container_type=container_type,
130131
step_range=step_range,
131132
tail_limit=tail_limit,
132133
),

src/neptune_query/internal/retrieval/metrics.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
retry,
3434
util,
3535
)
36+
from .search import ContainerType
3637

3738
logger = logging.getLogger(__name__)
3839

@@ -53,6 +54,7 @@ def fetch_multiple_series_values(
5354
client: AuthenticatedClient,
5455
run_attribute_definitions: list[identifiers.RunAttributeDefinition],
5556
include_inherited: bool,
57+
container_type: ContainerType,
5658
include_preview: bool,
5759
step_range: tuple[Union[float, None], Union[float, None]] = (None, None),
5860
tail_limit: Optional[int] = None,
@@ -81,6 +83,7 @@ def fetch_multiple_series_values(
8183
},
8284
"attribute": run_attribute.attribute_definition.name,
8385
"lineage": "FULL" if include_inherited else "NONE",
86+
"lineageEntityType": "EXPERIMENT" if container_type == ContainerType.EXPERIMENT else "RUN",
8487
"includePreview": include_preview,
8588
},
8689
}

src/neptune_query/internal/retrieval/series.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
File,
4444
Histogram,
4545
)
46+
from .search import ContainerType
4647

4748
SeriesValue = NamedTuple("SeriesValue", [("step", float), ("value", Any), ("timestamp_millis", float)])
4849

@@ -51,6 +52,7 @@ def fetch_series_values(
5152
client: AuthenticatedClient,
5253
run_attribute_definitions: Iterable[RunAttributeDefinition],
5354
include_inherited: bool,
55+
container_type: ContainerType,
5456
step_range: Tuple[Union[float, None], Union[float, None]] = (None, None),
5557
tail_limit: Optional[int] = None,
5658
) -> Generator[util.Page[tuple[RunAttributeDefinition, list[SeriesValue]]], None, None]:
@@ -75,6 +77,7 @@ def fetch_series_values(
7577
},
7678
"attribute": run_definition.attribute_definition.name,
7779
"lineage": "FULL" if include_inherited else "NONE",
80+
"lineageEntityType": "EXPERIMENT" if container_type == ContainerType.EXPERIMENT else "RUN",
7881
},
7982
}
8083
for request_id, run_definition in request_id_to_run_attr_definition.items()

tests/e2e/internal/retrieval/test_series.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44

55
from neptune_query.internal.identifiers import AttributeDefinition
6+
from neptune_query.internal.retrieval.search import ContainerType
67
from neptune_query.internal.retrieval.series import (
78
RunAttributeDefinition,
89
fetch_series_values,
@@ -27,6 +28,7 @@ def test_fetch_series_values_does_not_exist(client, project, experiment_identifi
2728
client,
2829
[run_definition],
2930
include_inherited=False,
31+
container_type=ContainerType.EXPERIMENT,
3032
)
3133
)
3234

@@ -59,6 +61,7 @@ def test_fetch_series_values_single_series(
5961
client,
6062
[run_definition],
6163
include_inherited=False,
64+
container_type=ContainerType.EXPERIMENT,
6265
)
6366
)
6467

@@ -110,7 +113,13 @@ def test_fetch_series_values_single_series_stop_range(
110113

111114
# when
112115
series = extract_pages(
113-
fetch_series_values(client, [run_definition], include_inherited=False, step_range=step_range)
116+
fetch_series_values(
117+
client,
118+
[run_definition],
119+
include_inherited=False,
120+
container_type=ContainerType.EXPERIMENT,
121+
step_range=step_range,
122+
)
114123
)
115124

116125
# then
@@ -149,7 +158,13 @@ def test_fetch_series_values_single_series_tail_limit(
149158

150159
# when
151160
series = extract_pages(
152-
fetch_series_values(client, [run_definition], include_inherited=False, tail_limit=tail_limit)
161+
fetch_series_values(
162+
client,
163+
[run_definition],
164+
include_inherited=False,
165+
container_type=ContainerType.EXPERIMENT,
166+
tail_limit=tail_limit,
167+
)
153168
)
154169

155170
# then

tests/e2e/internal/test_split.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
fetch_attribute_values,
1818
)
1919
from neptune_query.internal.retrieval.metrics import fetch_multiple_series_values
20+
from neptune_query.internal.retrieval.search import ContainerType
2021
from neptune_query.internal.retrieval.series import (
2122
SeriesValue,
2223
fetch_series_values,
@@ -186,7 +187,12 @@ def test_fetch_string_series_values_retrieval(client, project, experiment_identi
186187
try:
187188
result = extract_pages(
188189
fetch_series_values(
189-
client, attribute_definitions, include_inherited=True, step_range=(None, None), tail_limit=None
190+
client,
191+
attribute_definitions,
192+
include_inherited=True,
193+
container_type=ContainerType.EXPERIMENT,
194+
step_range=(None, None),
195+
tail_limit=None,
190196
)
191197
)
192198
except (NeptuneRetryError, NeptuneUnexpectedResponseError) as e:
@@ -265,6 +271,7 @@ def test_fetch_float_series_values_retrieval(client, project, experiment_identif
265271
client,
266272
attribute_definitions,
267273
include_inherited=True,
274+
container_type=ContainerType.EXPERIMENT,
268275
include_preview=False,
269276
step_range=(None, None),
270277
tail_limit=None,

tests/e2e/v1/generator.py

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import concurrent.futures
2-
from dataclasses import dataclass
2+
from dataclasses import (
3+
dataclass,
4+
field,
5+
)
36
from datetime import (
47
datetime,
58
timedelta,
@@ -23,15 +26,16 @@
2326
class GeneratedRun:
2427
custom_run_id: str
2528
experiment_name: str
26-
fork_run_id: Union[str, None]
27-
fork_level: Optional[int]
28-
fork_point: Optional[int]
29-
configs: dict[AttributeName, Union[float, bool, int, str, datetime, list, set, tuple]]
30-
metrics: dict[AttributeName, dict[Step, Value]]
31-
tags: list[str]
29+
fork_run_id: Union[str, None] = None
30+
fork_level: Optional[int] = None
31+
fork_point: Optional[int] = None
32+
configs: dict[AttributeName, Union[float, bool, int, str, datetime, list, set, tuple]] = field(default_factory=dict)
33+
metrics: dict[AttributeName, dict[Step, Value]] = field(default_factory=dict)
34+
string_series: dict[AttributeName, dict[Step, str]] = field(default_factory=dict)
35+
tags: list[str] = field(default_factory=list)
3236

3337
def attributes(self):
34-
return set().union(self.configs.keys(), self.metrics.keys())
38+
return set().union(self.configs.keys(), self.metrics.keys(), self.string_series.keys())
3539

3640
def metrics_values(self, name: AttributeName) -> list[tuple[Step, Value]]:
3741
return list(self.metrics[name].items())
@@ -51,9 +55,6 @@ def metrics_values(self, name: AttributeName) -> list[tuple[Step, Value]]:
5155
GeneratedRun(
5256
custom_run_id="linear_history_root",
5357
experiment_name=LINEAR_TREE_EXP_NAME,
54-
fork_level=None,
55-
fork_point=None,
56-
fork_run_id=None,
5758
tags=["linear_root", "linear"],
5859
configs={
5960
"int-value": 1,
@@ -127,9 +128,6 @@ def metrics_values(self, name: AttributeName) -> list[tuple[Step, Value]]:
127128
GeneratedRun(
128129
custom_run_id="forked_history_root",
129130
experiment_name=FORKED_TREE_EXP_NAME,
130-
fork_level=None,
131-
fork_point=None,
132-
fork_run_id=None,
133131
tags=["forked_history_root", "forked_history"],
134132
configs={
135133
"int-value": 1,
@@ -186,7 +184,54 @@ def metrics_values(self, name: AttributeName) -> list[tuple[Step, Value]]:
186184
),
187185
]
188186

189-
ALL_STATIC_RUNS = LINEAR_HISTORY_TREE + FORKED_HISTORY_TREE
187+
# Tree structure:
188+
#
189+
# multi_experiment_history:
190+
# root (level: None, experiment: exp_with_multi_experiment_history_1)
191+
# └── fork1 (level: 1, fork_point: 4, experiment: exp_with_multi_experiment_history_2)
192+
#     └── fork2 (level: 2, fork_point: 8, experiment: exp_with_multi_experiment_history_2)
193+
MULT_EXPERIMENT_HISTORY_EXP_1 = "exp_with_multi_experiment_history_1"
194+
MULT_EXPERIMENT_HISTORY_EXP_2 = "exp_with_multi_experiment_history_2"
195+
MULTI_EXPERIMENT_HISTORY = [
196+
GeneratedRun(
197+
custom_run_id="mult_exp_history_run_1",
198+
experiment_name=MULT_EXPERIMENT_HISTORY_EXP_1,
199+
metrics={
200+
"metrics/m1": {step: step * 0.1 for step in range(0, 5)},
201+
},
202+
string_series={
203+
"string_series/s1": {step: f"val_run1_{step}" for step in range(0, 5)},
204+
},
205+
),
206+
GeneratedRun(
207+
custom_run_id="mult_exp_history_run_2",
208+
experiment_name=MULT_EXPERIMENT_HISTORY_EXP_2,
209+
fork_level=1,
210+
fork_point=4,
211+
fork_run_id="mult_exp_history_run_1",
212+
metrics={
213+
"metrics/m1": {step: step * 0.2 for step in range(5, 9)},
214+
},
215+
string_series={
216+
"string_series/s1": {step: f"val_run2_{step}" for step in range(5, 9)},
217+
},
218+
),
219+
GeneratedRun(
220+
custom_run_id="mult_exp_history_run_3",
221+
experiment_name=MULT_EXPERIMENT_HISTORY_EXP_2,
222+
fork_level=2,
223+
fork_point=8,
224+
fork_run_id="mult_exp_history_run_2",
225+
metrics={
226+
"metrics/m1": {step: step * 0.3 for step in range(9, 12)},
227+
},
228+
string_series={
229+
"string_series/s1": {step: f"val_run3_{step}" for step in range(9, 12)},
230+
},
231+
),
232+
]
233+
234+
ALL_STATIC_RUNS = LINEAR_HISTORY_TREE + FORKED_HISTORY_TREE + MULTI_EXPERIMENT_HISTORY
190235
RUN_BY_ID = {run.custom_run_id: run for run in ALL_STATIC_RUNS}
191236

192237

@@ -209,6 +254,10 @@ def log_run(generated: GeneratedRun, api_token: str, e2e_alpha_project: str):
209254
for step, value in metric_values.items():
210255
run.log_metrics(step=step, data={metric_name: value}, timestamp=timestamp_for_step(step))
211256

257+
for string_series_name, string_series_values in generated.string_series.items():
258+
for step, value in string_series_values.items():
259+
run.log_string_series(step=step, data={string_series_name: value}, timestamp=timestamp_for_step(step))
260+
212261

213262
def log_runs(api_token: str, e2e_alpha_project: str, runs: list[GeneratedRun]):
214263
max_level = max(run.fork_level or 0 for run in runs)

tests/e2e/v1/test_fetch_metrics.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,21 @@
3434
TEST_DATA,
3535
ExperimentData,
3636
)
37+
from tests.e2e.v1.generator import (
38+
MULT_EXPERIMENT_HISTORY_EXP_2,
39+
timestamp_for_step,
40+
)
41+
42+
43+
def _to_run_attribute_definition(project, run, metric_name):
44+
return RunAttributeDefinition(
45+
RunIdentifier(ProjectIdentifier(project), SysId(run)),
46+
AttributeDefinition(metric_name, "float_series"),
47+
)
48+
49+
50+
def _to_float_point_value(step, value):
51+
return int(timestamp_for_step(step).timestamp() * 1000), step, value, False, 1.0
3752

3853

3954
def create_expected_data(
@@ -295,3 +310,39 @@ def test__fetch_metrics_unique__output_format_variants(
295310
assert result.columns.tolist() == columns
296311
assert result.index.names == ["experiment", "step"]
297312
assert {t[0] for t in result.index.tolist()} == filtred_exps
313+
314+
315+
@pytest.mark.parametrize(
316+
"lineage_to_the_root,expected_values",
317+
[
318+
(
319+
True,
320+
[(step, step * 0.1) for step in range(0, 5)]
321+
+ [(step, step * 0.2) for step in range(5, 9)]
322+
+ [(step, step * 0.3) for step in range(9, 12)],
323+
),
324+
(False, [(step, step * 0.2) for step in range(5, 9)] + [(step, step * 0.3) for step in range(9, 12)]),
325+
],
326+
)
327+
def test__fetch_metrics__lineage(new_project_id, lineage_to_the_root, expected_values):
328+
df = fetch_metrics(
329+
project=new_project_id,
330+
experiments=[MULT_EXPERIMENT_HISTORY_EXP_2],
331+
attributes=r"metrics/m1",
332+
lineage_to_the_root=lineage_to_the_root,
333+
)
334+
335+
expected = create_metrics_dataframe(
336+
metrics_data={
337+
_to_run_attribute_definition(new_project_id, MULT_EXPERIMENT_HISTORY_EXP_2, "metrics/m1"): [
338+
_to_float_point_value(step, value) for step, value in expected_values
339+
]
340+
},
341+
sys_id_label_mapping={SysId(MULT_EXPERIMENT_HISTORY_EXP_2): MULT_EXPERIMENT_HISTORY_EXP_2},
342+
type_suffix_in_column_names=False,
343+
include_point_previews=False,
344+
timestamp_column_name=None,
345+
index_column_name="experiment",
346+
)
347+
348+
pd.testing.assert_frame_equal(df, expected)

0 commit comments

Comments
 (0)