|
5 | 5 | timedelta, |
6 | 6 | timezone, |
7 | 7 | ) |
| 8 | +from unittest.mock import patch |
8 | 9 |
|
9 | 10 | import numpy as np |
10 | 11 | import pandas as pd |
11 | 12 | import pytest |
12 | 13 | from pandas._testing import assert_frame_equal |
13 | 14 |
|
| 15 | +import neptune_query as npt |
14 | 16 | from neptune_query.exceptions import ConflictingAttributeTypes |
15 | | -from neptune_query.internal import identifiers |
| 17 | +from neptune_query.filters import AttributeFilter |
| 18 | +from neptune_query.internal import ( |
| 19 | + context, |
| 20 | + identifiers, |
| 21 | +) |
16 | 22 | from neptune_query.internal.identifiers import ( |
17 | 23 | AttributeDefinition, |
18 | 24 | ProjectIdentifier, |
19 | 25 | RunAttributeDefinition, |
20 | 26 | RunIdentifier, |
21 | 27 | SysId, |
| 28 | + SysName, |
22 | 29 | ) |
23 | 30 | from neptune_query.internal.output_format import ( |
24 | 31 | convert_table_to_dataframe, |
25 | 32 | create_files_dataframe, |
26 | 33 | create_metrics_dataframe, |
27 | 34 | create_series_dataframe, |
28 | 35 | ) |
| 36 | +from neptune_query.internal.retrieval import util |
29 | 37 | from neptune_query.internal.retrieval.attribute_types import File as IFile |
30 | 38 | from neptune_query.internal.retrieval.attribute_types import ( |
31 | 39 | FileSeriesAggregations, |
|
38 | 46 | ) |
39 | 47 | from neptune_query.internal.retrieval.attribute_values import AttributeValue |
40 | 48 | from neptune_query.internal.retrieval.metrics import FloatPointValue |
41 | | -from neptune_query.internal.retrieval.search import ContainerType |
| 49 | +from neptune_query.internal.retrieval.search import ( |
| 50 | + ContainerType, |
| 51 | + ExperimentSysAttrs, |
| 52 | +) |
42 | 53 | from neptune_query.internal.retrieval.series import SeriesValue |
43 | 54 | from neptune_query.types import File as OFile |
44 | 55 | from neptune_query.types import Histogram as OHistogram |
@@ -1195,3 +1206,95 @@ def test_create_files_dataframe_index_name_attribute_conflict(): |
1195 | 1206 | expected_df.columns.names = ["attribute"] |
1196 | 1207 | expected_df.index.names = [index_column_name, "step"] |
1197 | 1208 | assert_frame_equal(dataframe, expected_df) |
| 1209 | + |
| 1210 | + |
@pytest.mark.parametrize("duplicate_variant", [(2, 1, 1), (1, 2, 1), (1, 1, 2), (2, 2, 2)])
@pytest.mark.parametrize("include_time", [None, "absolute"])
def test_fetch_series_duplicate_values(duplicate_variant, include_time):
    """fetch_series must collapse duplicated rows no matter where the duplication
    originates: repeated values within a page, repeated attribute entries, or
    repeated pages from the backend.
    """
    # given: one experiment with a single series attribute
    project_id = ProjectIdentifier("project")
    context.set_api_token("irrelevant")
    sys_attrs = [ExperimentSysAttrs(sys_id=SysId("sysid0"), sys_name=SysName("irrelevant"))]
    attr_defs = [AttributeDefinition(name="attribute0", type="irrelevant")]
    run_attr_defs = [
        RunAttributeDefinition(
            run_identifier=RunIdentifier(project_identifier=project_id, sys_id=sys_attrs[0].sys_id),
            attribute_definition=attr_defs[0],
        )
    ]

    # Multiply the data at the chosen layer(s) to simulate the duplication source.
    value_copies, attribute_copies, page_copies = duplicate_variant
    base_points = [SeriesValue(step=i, value=f"{i}", timestamp_millis=i) for i in range(100)]
    page_content = [(run_attr_defs[0], base_points * value_copies)] * attribute_copies

    # when: every retrieval layer is mocked to serve the (possibly duplicated) data
    with (
        patch("neptune_query.internal.composition.fetch_series.get_client") as mock_get_client,
        patch("neptune_query.internal.retrieval.search.fetch_experiment_sys_attrs") as mock_fetch_sys_attrs,
        patch(
            "neptune_query.internal.retrieval.attribute_definitions.fetch_attribute_definitions_single_filter"
        ) as mock_fetch_attr_defs,
        patch("neptune_query.internal.retrieval.series.fetch_series_values") as mock_fetch_series_values,
    ):
        mock_get_client.return_value = None
        mock_fetch_sys_attrs.return_value = iter([util.Page(sys_attrs)])
        mock_fetch_attr_defs.side_effect = lambda **kwargs: iter([util.Page(attr_defs)])
        mock_fetch_series_values.return_value = iter([util.Page(page_content)] * page_copies)

        df = npt.fetch_series(
            project=project_id,
            experiments="ignored",
            attributes=AttributeFilter(name="ignored"),
            include_time=include_time,
        )

    # then: exactly the 100 unique steps survive; a timestamp column appears on demand
    expected_columns = 2 if include_time else 1
    assert df.shape == (100, expected_columns)
| 1257 | + |
| 1258 | + |
@pytest.mark.parametrize("include_time", [None, "absolute"])
def test_fetch_metrics_duplicate_values(include_time):
    """fetch_metrics must deduplicate repeated metric points returned by the
    retrieval layer, keeping one row per unique step.
    """
    # given: one experiment exposing a single float series attribute
    project_id = ProjectIdentifier("project")
    context.set_api_token("irrelevant")
    sys_attrs = [ExperimentSysAttrs(sys_id=SysId("sysid0"), sys_name=SysName("irrelevant"))]
    attr_defs = [AttributeDefinition(name="attribute0", type="float_series")]
    run_attr_defs = [
        RunAttributeDefinition(
            run_identifier=RunIdentifier(project_identifier=project_id, sys_id=sys_attrs[0].sys_id),
            attribute_definition=attr_defs[0],
        )
    ]
    # Each of the 100 steps appears twice in the mocked payload.
    duplicated_points = [SeriesValue(step=i, value=float(i), timestamp_millis=i) for i in range(100)] * 2
    mocked_series = {run_attr_defs[0]: duplicated_points}

    # when: all backend calls are patched to return the duplicated payload
    with (
        patch("neptune_query.internal.composition.fetch_metrics.get_client") as mock_get_client,
        patch("neptune_query.internal.retrieval.search.fetch_experiment_sys_attrs") as mock_fetch_sys_attrs,
        patch(
            "neptune_query.internal.retrieval.attribute_definitions.fetch_attribute_definitions_single_filter"
        ) as mock_fetch_attr_defs,
        patch(
            "neptune_query.internal.composition.fetch_metrics.fetch_multiple_series_values"
        ) as mock_fetch_multiple,
    ):
        mock_get_client.return_value = None
        mock_fetch_sys_attrs.return_value = iter([util.Page(sys_attrs)])
        mock_fetch_attr_defs.side_effect = lambda **kwargs: iter([util.Page(attr_defs)])
        mock_fetch_multiple.return_value = mocked_series

        df = npt.fetch_metrics(
            project=project_id,
            experiments="ignored",
            attributes=AttributeFilter(name="ignored"),
            include_time=include_time,
        )

    # then: 100 unique rows; a timestamp column only when requested
    expected_columns = 2 if include_time else 1
    assert df.shape == (100, expected_columns)
0 commit comments