|
12 | 12 | stored in a single file. |
13 | 13 | """ |
14 | 14 |
|
| 15 | +import os |
15 | 16 | import shutil |
16 | 17 | import tempfile |
17 | 18 | from datetime import datetime |
@@ -59,7 +60,7 @@ def create_archive( |
59 | 60 | compression: int = 6, |
60 | 61 | test_run: bool = False, |
61 | 62 | backend: Optional[StorageBackend] = None, |
62 | | - temp_dir: Optional[Union[str, Path]] = None, |
| 63 | + tmp_dir: Optional[Union[str, Path]] = None, |
63 | 64 | **traversal_rules: bool, |
64 | 65 | ) -> Path: |
65 | 66 | """Export AiiDA data to an archive file. |
@@ -187,22 +188,23 @@ def querybuilder(): |
187 | 188 | } |
188 | 189 |
|
189 | 190 | # Handle temporary directory configuration |
190 | | - if temp_dir is not None: |
191 | | - temp_dir = Path(temp_dir) |
192 | | - if not temp_dir.exists(): |
193 | | - msg = f"Specified temporary directory '{temp_dir}' does not exist" |
| 191 | + if tmp_dir is not None: |
| 192 | + tmp_dir = Path(tmp_dir) |
| 193 | + if not tmp_dir.exists(): |
| 194 | + msg = f"Specified temporary directory '{tmp_dir}' does not exist" |
194 | 195 | raise ArchiveExportError(msg) |
195 | | - if not temp_dir.is_dir(): |
196 | | - msg = f"Specified temporary directory '{temp_dir}' is not a directory" |
| 196 | + if not tmp_dir.is_dir(): |
| 197 | + msg = f"Specified temporary directory '{tmp_dir}' is not a directory" |
197 | 198 | raise ArchiveExportError(msg) |
198 | 199 | # Check if directory is writable |
199 | | - if not temp_dir.is_writable(): |
200 | | - msg = f"Specified temporary directory '{temp_dir}' is not writable" |
| 200 | + # Taken from: https://stackoverflow.com/a/2113511 |
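| | + # os.access checks with the real uid/gid; W_OK tests that we can write to the directory and X_OK that we can enter it |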
| 201 | + if not os.access(tmp_dir, os.W_OK | os.X_OK): |
| 202 | + msg = f"Specified temporary directory '{tmp_dir}' is not writable" |
201 | 203 | raise ArchiveExportError(msg) |
202 | 204 | tmp_prefix = None # Use default tempfile prefix |
203 | 205 | else: |
204 | 206 | # Create temporary directory in the same folder as the output file |
205 | | - temp_dir = filename.parent |
| 207 | + tmp_dir = filename.parent |
206 | 208 | tmp_prefix = '.aiida-export-' |
207 | 209 |
|
208 | 210 | initial_summary = get_init_summary( |
@@ -310,93 +312,101 @@ def querybuilder(): |
310 | 312 | # We create in a temp dir then move to final place at end, |
311 | 313 | # so that the user cannot end up with a half-written archive on errors |
312 | 314 | # import ipdb; ipdb.set_trace() |
313 | | - temp_dir = Path('/mount') # or whatever directory you want to use |
314 | | - with tempfile.TemporaryDirectory(dir=temp_dir, prefix=tmp_prefix) as tmpdir: |
315 | | - # NOTE: Add the `tmp_prefix` to the directory or file? |
316 | | - tmp_filename = Path(tmpdir) / 'export.zip' |
317 | | - with archive_format.open(tmp_filename, mode='x', compression=compression) as writer: |
318 | | - # add metadata |
319 | | - writer.update_metadata( |
320 | | - { |
321 | | - 'ctime': datetime.now().isoformat(), |
322 | | - 'creation_parameters': { |
323 | | - 'entities_starting_set': None |
324 | | - if entities is None |
325 | | - else {etype.value: list(unique) for etype, unique in starting_uuids.items() if unique}, |
326 | | - 'include_authinfos': include_authinfos, |
327 | | - 'include_comments': include_comments, |
328 | | - 'include_logs': include_logs, |
329 | | - 'graph_traversal_rules': full_traversal_rules, |
330 | | - }, |
331 | | - } |
332 | | - ) |
333 | | - # stream entity data to the archive |
334 | | - with get_progress_reporter()(desc='Archiving database: ', total=sum(entity_counts.values())) as progress: |
335 | | - for etype, ids in entity_ids.items(): |
336 | | - if etype == EntityTypes.NODE and strip_checkpoints: |
337 | | - |
338 | | - def transform(row): |
339 | | - data = row['entity'] |
340 | | - if data.get('node_type', '').startswith('process.'): |
341 | | - data['attributes'].pop(orm.ProcessNode.CHECKPOINT_KEY, None) |
342 | | - return data |
343 | | - else: |
344 | | - |
345 | | - def transform(row): |
346 | | - return row['entity'] |
347 | | - |
348 | | - progress.set_description_str(f'Archiving database: {etype.value}s') |
349 | | - if ids: |
350 | | - for nrows, rows in batch_iter( |
351 | | - querybuilder() |
352 | | - .append( |
353 | | - entity_type_to_orm[etype], filters={'id': {'in': ids}}, tag='entity', project=['**'] |
354 | | - ) |
355 | | - .iterdict(batch_size=batch_size), |
356 | | - batch_size, |
357 | | - transform, |
358 | | - ): |
359 | | - writer.bulk_insert(etype, rows) |
360 | | - progress.update(nrows) |
361 | | - |
362 | | - # stream links |
363 | | - progress.set_description_str(f'Archiving database: {EntityTypes.LINK.value}s') |
364 | | - |
365 | | - def transform(d): |
366 | | - return { |
367 | | - 'input_id': d.source_id, |
368 | | - 'output_id': d.target_id, |
369 | | - 'label': d.link_label, |
370 | | - 'type': d.link_type, |
| 316 | + |
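| | + # any OSError raised while creating the temporary directory or writing the archive is converted to an ArchiveExportError in the except block below |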
| 317 | + try: |
| 318 | + with tempfile.TemporaryDirectory(dir=tmp_dir, prefix=tmp_prefix) as tmpdir: |
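| | + # the temporary directory, including any partially written archive, is removed automatically when this context exits |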
| 319 | + tmp_filename = Path(tmpdir) / 'export.zip' |
| 320 | + with archive_format.open(tmp_filename, mode='x', compression=compression) as writer: |
| 321 | + # add metadata |
| 322 | + writer.update_metadata( |
| 323 | + { |
| 324 | + 'ctime': datetime.now().isoformat(), |
| 325 | + 'creation_parameters': { |
| 326 | + 'entities_starting_set': None |
| 327 | + if entities is None |
| 328 | + else {etype.value: list(unique) for etype, unique in starting_uuids.items() if unique}, |
| 329 | + 'include_authinfos': include_authinfos, |
| 330 | + 'include_comments': include_comments, |
| 331 | + 'include_logs': include_logs, |
| 332 | + 'graph_traversal_rules': full_traversal_rules, |
| 333 | + }, |
371 | 334 | } |
372 | | - |
373 | | - for nrows, rows in batch_iter(link_data, batch_size, transform): |
374 | | - writer.bulk_insert(EntityTypes.LINK, rows, allow_defaults=True) |
375 | | - progress.update(nrows) |
376 | | - del link_data # release memory |
377 | | - |
378 | | - # stream group_nodes |
379 | | - progress.set_description_str(f'Archiving database: {EntityTypes.GROUP_NODE.value}s') |
380 | | - |
381 | | - def transform(d): |
382 | | - return {'dbgroup_id': d[0], 'dbnode_id': d[1]} |
383 | | - |
384 | | - for nrows, rows in batch_iter(group_nodes, batch_size, transform): |
385 | | - writer.bulk_insert(EntityTypes.GROUP_NODE, rows, allow_defaults=True) |
386 | | - progress.update(nrows) |
387 | | - del group_nodes # release memory |
388 | | - |
389 | | - # stream node repository files to the archive |
390 | | - if entity_ids[EntityTypes.NODE]: |
391 | | - _stream_repo_files(archive_format.key_format, writer, entity_ids[EntityTypes.NODE], backend, batch_size) |
392 | | - |
393 | | - EXPORT_LOGGER.report('Finalizing archive creation...') |
394 | | - |
395 | | - if filename.exists(): |
396 | | - filename.unlink() |
397 | | - |
398 | | - filename.parent.mkdir(parents=True, exist_ok=True) |
399 | | - shutil.move(tmp_filename, filename) |
| 335 | + ) |
| 336 | + # stream entity data to the archive |
| 337 | + with get_progress_reporter()(desc='Archiving database: ', total=sum(entity_counts.values())) as progress: |
| 338 | + for etype, ids in entity_ids.items(): |
| 339 | + if etype == EntityTypes.NODE and strip_checkpoints: |
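| | + # drop the stored engine checkpoint from process nodes so it is not written to the archive |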
| 340 | + |
| 341 | + def transform(row): |
| 342 | + data = row['entity'] |
| 343 | + if data.get('node_type', '').startswith('process.'): |
| 344 | + data['attributes'].pop(orm.ProcessNode.CHECKPOINT_KEY, None) |
| 345 | + return data |
| 346 | + else: |
| 347 | + |
| 348 | + def transform(row): |
| 349 | + return row['entity'] |
| 350 | + |
| 351 | + progress.set_description_str(f'Archiving database: {etype.value}s') |
| 352 | + if ids: |
| 353 | + for nrows, rows in batch_iter( |
| 354 | + querybuilder() |
| 355 | + .append( |
| 356 | + entity_type_to_orm[etype], filters={'id': {'in': ids}}, tag='entity', project=['**'] |
| 357 | + ) |
| 358 | + .iterdict(batch_size=batch_size), |
| 359 | + batch_size, |
| 360 | + transform, |
| 361 | + ): |
| 362 | + writer.bulk_insert(etype, rows) |
| 363 | + progress.update(nrows) |
| 364 | + |
| 365 | + # stream links |
| 366 | + progress.set_description_str(f'Archiving database: {EntityTypes.LINK.value}s') |
| 367 | + |
| 368 | + def transform(d): |
| 369 | + return { |
| 370 | + 'input_id': d.source_id, |
| 371 | + 'output_id': d.target_id, |
| 372 | + 'label': d.link_label, |
| 373 | + 'type': d.link_type, |
| 374 | + } |
| 375 | + |
| 376 | + for nrows, rows in batch_iter(link_data, batch_size, transform): |
| 377 | + writer.bulk_insert(EntityTypes.LINK, rows, allow_defaults=True) |
| 378 | + progress.update(nrows) |
| 379 | + del link_data # release memory |
| 380 | + |
| 381 | + # stream group_nodes |
| 382 | + progress.set_description_str(f'Archiving database: {EntityTypes.GROUP_NODE.value}s') |
| 383 | + |
| 384 | + def transform(d): |
| 385 | + return {'dbgroup_id': d[0], 'dbnode_id': d[1]} |
| 386 | + |
| 387 | + for nrows, rows in batch_iter(group_nodes, batch_size, transform): |
| 388 | + writer.bulk_insert(EntityTypes.GROUP_NODE, rows, allow_defaults=True) |
| 389 | + progress.update(nrows) |
| 390 | + del group_nodes # release memory |
| 391 | + |
| 392 | + # stream node repository files to the archive |
| 393 | + if entity_ids[EntityTypes.NODE]: |
| 394 | + _stream_repo_files(archive_format.key_format, writer, entity_ids[EntityTypes.NODE], backend, batch_size) |
| 395 | + |
| 396 | + EXPORT_LOGGER.report('Finalizing archive creation...') |
| 397 | + |
| 398 | + if filename.exists(): |
| 399 | + filename.unlink() |
| 400 | + |
| 401 | + filename.parent.mkdir(parents=True, exist_ok=True) |
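| | + # shutil.move falls back to a copy when the temporary directory and the final destination are on different filesystems |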
| 402 | + shutil.move(tmp_filename, filename) |
| 403 | + except OSError as e: |
| 404 | + if e.errno == 28: # No space left on device |
| 405 | + raise ArchiveExportError( |
| 406 | + f"Insufficient disk space in temporary directory '{tmp_dir}'. " |
| 407 | + f"Consider using --tmp-dir to specify a location with more available space." |
| 408 | + ) from e |
| 409 | + raise ArchiveExportError(f"Failed to create the archive in temporary directory '{tmp_dir}': {e}") from e |
400 | 410 |
|
401 | 411 | EXPORT_LOGGER.report('Archive created successfully') |
402 | 412 |
|
|