Add --tmp-dir option for archive creation #6946
base: main
Changes from all commits: 52325fc, 3331f6d, 0e3d38a, 58b8d91, d3737aa, f7a7111, 74dc0c6, f38ecd0, f0ae6f8, c71d1ea, 2379d21, 4756f5e, ea4007f, a515785
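For orientation, the new flag would be used roughly like `verdi archive create --tmp-dir /scratch/tmp archive.aiida` (paths illustrative; the exact option handling lives in the PR's CLI changes, which are not part of this file's diff).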
@@ -12,6 +12,7 @@ The archive is a subset of the provenance graph,
 stored in a single file.
 """
 
+import os
 import shutil
 import tempfile
 from datetime import datetime

@@ -59,6 +60,7 @@ def create_archive(
     compression: int = 6,
     test_run: bool = False,
     backend: Optional[StorageBackend] = None,
+    tmp_dir: Optional[Union[str, Path]] = None,
     **traversal_rules: bool,
 ) -> Path:
     """Export AiiDA data to an archive file.

@@ -139,6 +141,11 @@ def create_archive(
 
     :param backend: the backend to export from. If not specified, the default backend is used.
 
+    :param tmp_dir: Location where the temporary directory will be written during archive creation.
+        The directory must exist and be writable, and defaults to the parent directory of the output file.
+        This parameter is useful when the output directory has limited space or when you want to use a specific
+        filesystem (e.g., faster storage) for temporary operations.
+
     :param traversal_rules: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules`
         what rule names are toggleable and what the defaults are.
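A minimal usage sketch of the new parameter from the Python API, assuming a loaded profile; the paths and PK are illustrative, and the import path follows the module touched by this diff (`aiida.tools.archive`):

```python
from pathlib import Path

from aiida import load_profile, orm
from aiida.tools.archive import create_archive

load_profile()

# Illustrative: export a single node, staging temporary files on a
# different filesystem than the final output location.
node = orm.load_node(1234)  # hypothetical PK
create_archive(
    [node],
    filename=Path('/data/export.aiida'),  # final archive location (illustrative)
    tmp_dir=Path('/scratch/tmp'),  # new parameter added by this PR
)
```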

@@ -239,7 +246,7 @@ def querybuilder():
             entity_ids[EntityTypes.USER].add(entry.pk)
         else:
             raise ArchiveExportError(
-                f'I was given {entry} ({type(entry)}),' ' which is not a User, Node, Computer, or Group instance'
+                f'I was given {entry} ({type(entry)}), which is not a User, Node, Computer, or Group instance'
             )
     group_nodes, link_data = _collect_required_entities(
         querybuilder,
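This change is purely cosmetic: the old line relied on Python's implicit concatenation of adjacent string literals, so both forms render the same message. A quick check:

```python
entry = 'x'
old = f'I was given {entry} ({type(entry)}),' ' which is not a User, Node, Computer, or Group instance'
new = f'I was given {entry} ({type(entry)}), which is not a User, Node, Computer, or Group instance'
assert old == new  # adjacent literals are joined at compile time
```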

@@ -280,94 +287,129 @@ def querybuilder():
 
     EXPORT_LOGGER.report(f'Creating archive with:\n{tabulate(count_summary)}')
 
+    # Handle temporary directory configuration
+    if tmp_dir is not None:
+        tmp_dir = Path(tmp_dir)
+        if not tmp_dir.exists():
+            EXPORT_LOGGER.warning(f"Specified temporary directory '{tmp_dir}' doesn't exist. Creating it.")
+            tmp_dir.mkdir(parents=True)
+        if not tmp_dir.is_dir():
+            msg = f"Specified temporary directory '{tmp_dir}' is not a directory"
+            raise ArchiveExportError(msg)
+        # Check if directory is writable
+        # Taken from: https://stackoverflow.com/a/2113511
+        if not os.access(tmp_dir, os.W_OK | os.X_OK):
+            msg = f"Specified temporary directory '{tmp_dir}' is not writable"
+            raise ArchiveExportError(msg)
+
+    else:
+        # Create temporary directory in the same folder as the output file
+        tmp_dir = filename.parent
+
     # Create and open the archive for writing.
     # We create in a temp dir then move to final place at end,
     # so that the user cannot end up with a half written archive on errors
-    with tempfile.TemporaryDirectory() as tmpdir:
-        tmp_filename = Path(tmpdir) / 'export.zip'
-        with archive_format.open(tmp_filename, mode='x', compression=compression) as writer:
-            # add metadata
-            writer.update_metadata(
-                {
-                    'ctime': datetime.now().isoformat(),
-                    'creation_parameters': {
-                        'entities_starting_set': None
-                        if entities is None
-                        else {etype.value: list(unique) for etype, unique in starting_uuids.items() if unique},
-                        'include_authinfos': include_authinfos,
-                        'include_comments': include_comments,
-                        'include_logs': include_logs,
-                        'graph_traversal_rules': full_traversal_rules,
-                    },
-                }
-            )
-            # stream entity data to the archive
-            with get_progress_reporter()(desc='Archiving database: ', total=sum(entity_counts.values())) as progress:
-                for etype, ids in entity_ids.items():
-                    if etype == EntityTypes.NODE and strip_checkpoints:
-
-                        def transform(row):
-                            data = row['entity']
-                            if data.get('node_type', '').startswith('process.'):
-                                data['attributes'].pop(orm.ProcessNode.CHECKPOINT_KEY, None)
-                            return data
-                    else:
-
-                        def transform(row):
-                            return row['entity']
-
-                    progress.set_description_str(f'Archiving database: {etype.value}s')
-                    if ids:
-                        for nrows, rows in batch_iter(
-                            querybuilder()
-                            .append(
-                                entity_type_to_orm[etype], filters={'id': {'in': ids}}, tag='entity', project=['**']
-                            )
-                            .iterdict(batch_size=batch_size),
-                            batch_size,
-                            transform,
-                        ):
-                            writer.bulk_insert(etype, rows)
-                            progress.update(nrows)
-
-                # stream links
-                progress.set_description_str(f'Archiving database: {EntityTypes.LINK.value}s')
-
-                def transform(d):
-                    return {
-                        'input_id': d.source_id,
-                        'output_id': d.target_id,
-                        'label': d.link_label,
-                        'type': d.link_type,
-                    }
-
-                for nrows, rows in batch_iter(link_data, batch_size, transform):
-                    writer.bulk_insert(EntityTypes.LINK, rows, allow_defaults=True)
-                    progress.update(nrows)
-                del link_data  # release memory
-
-                # stream group_nodes
-                progress.set_description_str(f'Archiving database: {EntityTypes.GROUP_NODE.value}s')
-
-                def transform(d):
-                    return {'dbgroup_id': d[0], 'dbnode_id': d[1]}
-
-                for nrows, rows in batch_iter(group_nodes, batch_size, transform):
-                    writer.bulk_insert(EntityTypes.GROUP_NODE, rows, allow_defaults=True)
-                    progress.update(nrows)
-                del group_nodes  # release memory
-
-            # stream node repository files to the archive
-            if entity_ids[EntityTypes.NODE]:
-                _stream_repo_files(archive_format.key_format, writer, entity_ids[EntityTypes.NODE], backend, batch_size)
-
-            EXPORT_LOGGER.report('Finalizing archive creation...')
-
-        if filename.exists():
-            filename.unlink()
-
-        filename.parent.mkdir(parents=True, exist_ok=True)
-        shutil.move(tmp_filename, filename)
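The writability test in the added block above follows the cited Stack Overflow pattern. A standalone sketch of the same idea (the helper name is hypothetical, not part of the PR):

```python
import os
from pathlib import Path


def check_writable_dir(path: Path) -> None:
    """Hypothetical helper: raise if ``path`` is not a writable, enterable directory."""
    if not path.is_dir():
        raise NotADirectoryError(f"'{path}' is not a directory")
    # W_OK: permission to create files inside; X_OK: permission to enter (search) it.
    if not os.access(path, os.W_OK | os.X_OK):
        raise PermissionError(f"'{path}' is not writable")


check_writable_dir(Path('/tmp'))  # typically passes on POSIX systems
```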
| 
[Review comment] Diff is large here because of the additional try-except and the indent that goes with it (to capture disk-space errors), but the actual code inside should be the same!

[Review comment]
 diff --git a/src/aiida/tools/archive/create.py b/src/aiida/tools/archive/create.py
index 94ca88cd4..c4f0671d5 100644
--- a/src/aiida/tools/archive/create.py
+++ b/src/aiida/tools/archive/create.py
@@ -12,6 +12,7 @@ The archive is a subset of the provenance graph,
 stored in a single file.
 """
 
+import os
 import shutil
 import tempfile
 from datetime import datetime
@@ -59,6 +60,7 @@ def create_archive(
     compression: int = 6,
     test_run: bool = False,
     backend: Optional[StorageBackend] = None,
+    tmp_dir: Optional[Union[str, Path]] = None,
     **traversal_rules: bool,
 ) -> Path:
     """Export AiiDA data to an archive file.
@@ -139,6 +141,12 @@ def create_archive(
 
     :param backend: the backend to export from. If not specified, the default backend is used.
 
+    :param tmp_dir: Directory to use for temporary files during archive creation.
+        If not specified, a temporary directory will be created in the same directory as the output file
+        with a '.aiida-export-' prefix. This parameter is useful when the output directory has limited
+        space or when you want to use a specific filesystem (e.g., faster storage) for temporary operations.
+        The directory must exist and be writable.
+
     :param traversal_rules: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules`
         what rule names are toggleable and what the defaults are.
 
@@ -280,10 +288,32 @@ def create_archive(
 
     EXPORT_LOGGER.report(f'Creating archive with:\n{tabulate(count_summary)}')
 
+    # Handle temporary directory configuration
+    tmp_prefix = '.aiida-export-'
+    if tmp_dir is not None:
+        tmp_dir = Path(tmp_dir)
+        if not tmp_dir.exists():
+            EXPORT_LOGGER.warning(f"Specified temporary directory '{tmp_dir}' doesn't exist. Creating it.")
+            tmp_dir.mkdir(parents=True)
+        if not tmp_dir.is_dir():
+            msg = f"Specified temporary directory '{tmp_dir}' is not a directory"
+            raise ArchiveExportError(msg)
+        # Check if directory is writable
+        # Taken from: https://stackoverflow.com/a/2113511
+        if not os.access(tmp_dir, os.W_OK | os.X_OK):
+            msg = f"Specified temporary directory '{tmp_dir}' is not writable"
+            raise ArchiveExportError(msg)
+
+    else:
+        # Create temporary directory in the same folder as the output file
+        tmp_dir = filename.parent
+
     # Create and open the archive for writing.
     # We create in a temp dir then move to final place at end,
     # so that the user cannot end up with a half written archive on errors
-    with tempfile.TemporaryDirectory() as tmpdir:
+    try:
+        tmp_dir.mkdir(parents=True, exist_ok=True)
+        with tempfile.TemporaryDirectory(dir=tmp_dir, prefix=tmp_prefix) as tmpdir:
             tmp_filename = Path(tmpdir) / 'export.zip'
             with archive_format.open(tmp_filename, mode='x', compression=compression) as writer:
                 # add metadata
@@ -302,7 +332,9 @@ def create_archive(
                     }
                 )
                 # stream entity data to the archive
-            with get_progress_reporter()(desc='Archiving database: ', total=sum(entity_counts.values())) as progress:
+                with get_progress_reporter()(
+                    desc='Archiving database: ', total=sum(entity_counts.values())
+                ) as progress:
                     for etype, ids in entity_ids.items():
                         if etype == EntityTypes.NODE and strip_checkpoints:
 
@@ -359,7 +391,9 @@ def create_archive(
 
                 # stream node repository files to the archive
                 if entity_ids[EntityTypes.NODE]:
-                _stream_repo_files(archive_format.key_format, writer, entity_ids[EntityTypes.NODE], backend, batch_size)
+                    _stream_repo_files(
+                        archive_format.key_format, writer, entity_ids[EntityTypes.NODE], backend, batch_size
+                    )
 
                 EXPORT_LOGGER.report('Finalizing archive creation...')
 
@@ -368,6 +402,16 @@ def create_archive(
 
             filename.parent.mkdir(parents=True, exist_ok=True)
             shutil.move(tmp_filename, filename)
+    except OSError as e:
+        if e.errno == 28:  # No space left on device
+            msg = (
+                f"Insufficient disk space in temporary directory '{tmp_dir}'. "
+                f'Consider using --tmp-dir to specify a location with more available space.'
+            )
+            raise ArchiveExportError(msg) from e
+
+        msg = f'Failed to create temporary directory: {e}'
+        raise ArchiveExportError(msg) from e
 
 
     EXPORT_LOGGER.report('Archive created successfully')

[Review comment] Thanks for the clarification @GeigerJ2! This'll help in my review <3
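For reviewers less familiar with the two `tempfile` arguments this patch leans on: `dir=` controls where the scratch directory is created and `prefix=` how it is named, and the directory is removed when the context exits. A minimal sketch (paths illustrative):

```python
import tempfile
from pathlib import Path

# Creates e.g. /tmp/.aiida-export-k3j2x9 and deletes it on exit,
# even if an exception is raised inside the block.
with tempfile.TemporaryDirectory(dir='/tmp', prefix='.aiida-export-') as tmpdir:
    staging = Path(tmpdir) / 'export.zip'
    staging.write_bytes(b'...')  # stand-in for the real archive writer
```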

+    try:
+        tmp_dir.mkdir(parents=True, exist_ok=True)

[Review comment] Why is this line added here? Everything related to checking the …

+        with tempfile.TemporaryDirectory(dir=tmp_dir, prefix='.aiida-export-') as tmpdir:
+            tmp_filename = Path(tmpdir) / 'export.zip'
+            with archive_format.open(tmp_filename, mode='x', compression=compression) as writer:
+                # add metadata
+                writer.update_metadata(
+                    {
+                        'ctime': datetime.now().isoformat(),
+                        'creation_parameters': {
+                            'entities_starting_set': None
+                            if entities is None
+                            else {etype.value: list(unique) for etype, unique in starting_uuids.items() if unique},
+                            'include_authinfos': include_authinfos,
+                            'include_comments': include_comments,
+                            'include_logs': include_logs,
+                            'graph_traversal_rules': full_traversal_rules,
+                        },
+                    }
+                )
+                # stream entity data to the archive
+                with get_progress_reporter()(
+                    desc='Archiving database: ', total=sum(entity_counts.values())
+                ) as progress:
+                    for etype, ids in entity_ids.items():
+                        if etype == EntityTypes.NODE and strip_checkpoints:
+
+                            def transform(row):
+                                data = row['entity']
+                                if data.get('node_type', '').startswith('process.'):
+                                    data['attributes'].pop(orm.ProcessNode.CHECKPOINT_KEY, None)
+                                return data
+                        else:
+
+                            def transform(row):
+                                return row['entity']
+
+                        progress.set_description_str(f'Archiving database: {etype.value}s')
+                        if ids:
+                            for nrows, rows in batch_iter(
+                                querybuilder()
+                                .append(
+                                    entity_type_to_orm[etype], filters={'id': {'in': ids}}, tag='entity', project=['**']
+                                )
+                                .iterdict(batch_size=batch_size),
+                                batch_size,
+                                transform,
+                            ):
+                                writer.bulk_insert(etype, rows)
+                                progress.update(nrows)
+
+                    # stream links
+                    progress.set_description_str(f'Archiving database: {EntityTypes.LINK.value}s')
+
+                    def transform(d):
+                        return {
+                            'input_id': d.source_id,
+                            'output_id': d.target_id,
+                            'label': d.link_label,
+                            'type': d.link_type,
+                        }
+
+                    for nrows, rows in batch_iter(link_data, batch_size, transform):
+                        writer.bulk_insert(EntityTypes.LINK, rows, allow_defaults=True)
+                        progress.update(nrows)
+                    del link_data  # release memory
+
+                    # stream group_nodes
+                    progress.set_description_str(f'Archiving database: {EntityTypes.GROUP_NODE.value}s')
+
+                    def transform(d):
+                        return {'dbgroup_id': d[0], 'dbnode_id': d[1]}
+
+                    for nrows, rows in batch_iter(group_nodes, batch_size, transform):
+                        writer.bulk_insert(EntityTypes.GROUP_NODE, rows, allow_defaults=True)
+                        progress.update(nrows)
+                    del group_nodes  # release memory
+
+                # stream node repository files to the archive
+                if entity_ids[EntityTypes.NODE]:
+                    _stream_repo_files(
+                        archive_format.key_format, writer, entity_ids[EntityTypes.NODE], backend, batch_size
+                    )
+
+                EXPORT_LOGGER.report('Finalizing archive creation...')
+
+            if filename.exists():

[Review comment] I know it's not part of your changes, you just put it in the …

+                filename.unlink()
+
+            filename.parent.mkdir(parents=True, exist_ok=True)

[Review comment] Again, not your code, but something to think about: apparently you can just write the archive to a folder that doesn't exist, and the command will create it and all non-existent parents. I think that's ok. But the behaviour is then different from the …

+            shutil.move(tmp_filename, filename)
+    except OSError as e:
+        if e.errno == 28:  # No space left on device
+            msg = (
+                f"Insufficient disk space in temporary directory '{tmp_dir}'. "
+                f'Consider using --tmp-dir to specify a location with more available space.'
+            )

[Review comment] Note that this can also happen when the …

[Review comment] I'm still checking the code above in more detail, but what happens if there is enough space in the …

+            raise ArchiveExportError(msg) from e
+
+        msg = f'Failed to create temporary directory: {e}'
+        raise ArchiveExportError(msg) from e
 
     EXPORT_LOGGER.report('Archive created successfully')
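The literal `28` in the except branch is POSIX `ENOSPC`; a sketch of the same check written against the `errno` module, which would avoid the hard-coded number:

```python
import errno

try:
    raise OSError(errno.ENOSPC, 'No space left on device')  # simulate a full disk
except OSError as exc:
    if exc.errno == errno.ENOSPC:  # same as the literal 28 used in the diff
        print('Insufficient disk space in the temporary directory.')
    else:
        raise
```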

@@ -668,7 +710,7 @@ def _check_unsealed_nodes(querybuilder: QbType, node_ids: set[int], batch_size:
     if unsealed_node_pks:
         raise ExportValidationError(
             'All ProcessNodes must be sealed before they can be exported. '
-            f"Node(s) with PK(s): {', '.join(str(pk) for pk in unsealed_node_pks)} is/are not sealed."
+            f'Node(s) with PK(s): {", ".join(str(pk) for pk in unsealed_node_pks)} is/are not sealed.'
         )
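This hunk and the one below only flip the quote nesting inside f-strings (a formatter preference); the rendered strings are identical, as a quick check shows:

```python
pks = [1, 2]
a = f"Node(s) with PK(s): {', '.join(str(pk) for pk in pks)}"
b = f'Node(s) with PK(s): {", ".join(str(pk) for pk in pks)}'
assert a == b
```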

@@ -759,18 +801,18 @@ def get_init_summary(
     """Get summary for archive initialisation"""
     parameters = [['Path', str(outfile)], ['Version', archive_version], ['Compression', compression]]
 
-    result = f"\n{tabulate(parameters, headers=['Archive Parameters', ''])}"
+    result = f'\n{tabulate(parameters, headers=["Archive Parameters", ""])}'
 
     inclusions = [
         ['Computers/Nodes/Groups/Users', 'All' if collect_all else 'Selected'],
         ['Computer Authinfos', include_authinfos],
         ['Node Comments', include_comments],
         ['Node Logs', include_logs],
     ]
-    result += f"\n\n{tabulate(inclusions, headers=['Inclusion rules', ''])}"
+    result += f'\n\n{tabulate(inclusions, headers=["Inclusion rules", ""])}'
 
     if not collect_all:
-        rules_table = [[f"Follow links {' '.join(name.split('_'))}s", value] for name, value in traversal_rules.items()]
-        result += f"\n\n{tabulate(rules_table, headers=['Traversal rules', ''])}"
+        rules_table = [[f'Follow links {" ".join(name.split("_"))}s', value] for name, value in traversal_rules.items()]
+        result += f'\n\n{tabulate(rules_table, headers=["Traversal rules", ""])}'
 
     return result + '\n'

[Review comment] This behaviour is different from the CLI command, which does expect the temporary directory to exist. It also seems to conflict with the docstring of the function.