|
25 | 25 | import neptune_scale.types |
26 | 26 | from neptune_api import AuthenticatedClient |
27 | 27 |
|
28 | | -from neptune_query.internal.filters import ( |
29 | | - _Attribute, |
30 | | - _Filter, |
31 | | -) |
32 | | -from neptune_query.internal.identifiers import ( |
33 | | - ProjectIdentifier, |
34 | | - SysId, |
35 | | -) |
| 28 | +from neptune_query.internal.identifiers import ProjectIdentifier |
36 | 29 | from neptune_query.internal.retrieval import search |
37 | 30 |
|
38 | 31 | IngestionHistogram = neptune_scale.types.Histogram |
@@ -120,6 +113,30 @@ def get_run_by_run_id(self: IngestedProjectData, run_id: str) -> IngestedRunData |
120 | 113 | raise ValueError(f"Run not found: {run_id}") |
121 | 114 |
|
122 | 115 |
|
| 116 | +def _wait_for_ingestion( |
| 117 | + client: AuthenticatedClient, project_identifier: ProjectIdentifier, expected_data: ProjectData |
| 118 | +) -> None: |
| 119 | + for attempt in range(1): |
| 120 | + found_runs = 0 |
| 121 | + |
| 122 | + for page in search.fetch_run_sys_ids( |
| 123 | + client=client, |
| 124 | + project_identifier=project_identifier, |
| 125 | + filter_=None, |
| 126 | + ): |
| 127 | + found_runs += len(page.items) |
| 128 | + |
| 129 | + # Extra wait to ensure data is available to query before proceeding |
| 130 | + sleep(2) |
| 131 | + |
| 132 | + if found_runs == len(expected_data.runs): |
| 133 | + return |
| 134 | + |
| 135 | + raise RuntimeError( |
| 136 | + f"Timed out waiting for data ingestion, " f"found runs: {found_runs} out of expected: {len(expected_data.runs)}" |
| 137 | + ) |
| 138 | + |
| 139 | + |
123 | 140 | def ingest_project( |
124 | 141 | *, |
125 | 142 | client: AuthenticatedClient, |
@@ -235,58 +252,6 @@ def _ingest_runs(runs_data: list[RunData], api_token: str, project_identifier: s |
235 | 252 | run.close() |
236 | 253 |
|
237 | 254 |
|
238 | | -def _wait_for_ingestion( |
239 | | - client: AuthenticatedClient, project_identifier: ProjectIdentifier, expected_data: ProjectData |
240 | | -) -> None: |
241 | | - def fetch_sys_ids(attribute_name: str, attribute_value: str) -> list[SysId]: |
242 | | - sys_ids: list[SysId] = [] |
243 | | - for page in search.fetch_run_sys_ids( |
244 | | - client=client, |
245 | | - project_identifier=project_identifier, |
246 | | - filter_=_Filter.eq(_Attribute(attribute_name, type="string"), attribute_value), |
247 | | - ): |
248 | | - for item in page.items: |
249 | | - sys_ids.append(item) |
250 | | - return sys_ids |
251 | | - |
252 | | - all_runs = expected_data.runs |
253 | | - run_ids = [run.run_id for run in all_runs if run.run_id is not None] |
254 | | - experiment_names = [run.experiment_name for run in all_runs if run.experiment_name is not None] |
255 | | - |
256 | | - for attempt in range(24): |
257 | | - found_runs = 0 |
258 | | - found_experiments = 0 |
259 | | - |
260 | | - for run_id in run_ids: |
261 | | - sys_ids_by_run_id = fetch_sys_ids(attribute_name="sys/custom_run_id", attribute_value=run_id) |
262 | | - if len(sys_ids_by_run_id) > 1: |
263 | | - raise RuntimeError(f"Expected exactly one sys_id for run_id {run_id}, got {sys_ids_by_run_id}") |
264 | | - if len(sys_ids_by_run_id) == 1: |
265 | | - found_runs += 1 |
266 | | - |
267 | | - for experiment_name in experiment_names: |
268 | | - sys_ids_by_experiment_name = fetch_sys_ids(attribute_name="sys/name", attribute_value=experiment_name) |
269 | | - if len(sys_ids_by_experiment_name) > 1: |
270 | | - raise RuntimeError( |
271 | | - f"Expected exactly one sys_id for experiment_name {experiment_name}, " |
272 | | - f"got {sys_ids_by_experiment_name}" |
273 | | - ) |
274 | | - if len(sys_ids_by_experiment_name) == 1: |
275 | | - found_experiments += 1 |
276 | | - |
277 | | - # Extra wait to ensure data is available to query before proceeding |
278 | | - sleep(2) |
279 | | - |
280 | | - if found_runs == len(run_ids) and found_experiments == len(experiment_names): |
281 | | - return |
282 | | - |
283 | | - raise RuntimeError( |
284 | | - f"Timed out waiting for data ingestion. " |
285 | | - f"Found runs: {found_runs} out of expected: {len(run_ids)}. " |
286 | | - f"Found experiments: {found_experiments} out of expected: {len(experiment_names)}." |
287 | | - ) |
288 | | - |
289 | | - |
290 | 255 | def _get_all_steps(run_data: RunData) -> Iterable[float]: |
291 | 256 | # Collect all unique steps |
292 | 257 | all_steps = set() |
|
0 commit comments