diff --git a/docs/source/reference/command_line.rst b/docs/source/reference/command_line.rst
index 6bc2272694..2a85592330 100644
--- a/docs/source/reference/command_line.rst
+++ b/docs/source/reference/command_line.rst
@@ -453,7 +453,7 @@ Below is a list with all available subcommands.
       --broker-host HOSTNAME        Hostname for the message broker. [default: 127.0.0.1]
       --broker-port INTEGER         Port for the message broker. [default: 5672]
       --broker-virtual-host TEXT    Name of the virtual host for the message broker without
-                                    leading forward slash. [default: ""]
+                                    leading forward slash.
       --repository DIRECTORY        Absolute path to the file repository.
       --test-profile                Designate the profile to be used for running the test suite only.
diff --git a/src/aiida/cmdline/commands/cmd_archive.py b/src/aiida/cmdline/commands/cmd_archive.py
index 94746536a0..71b947fe88 100644
--- a/src/aiida/cmdline/commands/cmd_archive.py
+++ b/src/aiida/cmdline/commands/cmd_archive.py
@@ -138,6 +138,16 @@ def inspect(ctx, archive, version, meta_data, database):
     help='Determine entities to export, but do not create the archive. Deprecated, please use `--dry-run` instead.',
 )
 @options.DRY_RUN(help='Determine entities to export, but do not create the archive.')
+@click.option(
+    '--tmp-dir',
+    type=click.Path(exists=True, file_okay=False, dir_okay=True, writable=True, path_type=Path),
+    help=(
+        'Location where the temporary directory will be written during archive creation. '
+        'The directory must exist and be writable, and defaults to the parent directory of the output file. '
+        'This parameter is useful when the output directory has limited space or when you want to use a specific '
+        'filesystem (e.g., faster storage) for temporary operations.'
+    ),
+)
 @decorators.with_dbenv()
 def create(
     output_file,
@@ -160,6 +170,7 @@ def create(
     batch_size,
     test_run,
     dry_run,
+    tmp_dir,
 ):
     """Create an archive from all or part of a profiles's data.
@@ -211,6 +222,7 @@ def create(
         'compression': compress,
         'batch_size': batch_size,
         'test_run': dry_run,
+        'tmp_dir': tmp_dir,
     }
     if AIIDA_LOGGER.level <= logging.REPORT:  # type: ignore[attr-defined]
@@ -327,7 +339,7 @@ class ExtrasImportCode(Enum):
     '--extras-mode-new',
     type=click.Choice(EXTRAS_MODE_NEW),
     default='import',
-    help='Specify whether to import extras of new nodes: ' 'import: import extras. ' 'none: do not import extras.',
+    help='Specify whether to import extras of new nodes: import: import extras. none: do not import extras.',
 )
 @click.option(
     '--comment-mode',
diff --git a/src/aiida/tools/archive/create.py b/src/aiida/tools/archive/create.py
index e5c88cde49..c4239cbd93 100644
--- a/src/aiida/tools/archive/create.py
+++ b/src/aiida/tools/archive/create.py
@@ -12,6 +12,7 @@
 stored in a single file.
 """
+import os
 import shutil
 import tempfile
 from datetime import datetime
@@ -59,6 +60,7 @@ def create_archive(
     compression: int = 6,
     test_run: bool = False,
     backend: Optional[StorageBackend] = None,
+    tmp_dir: Optional[Union[str, Path]] = None,
     **traversal_rules: bool,
 ) -> Path:
     """Export AiiDA data to an archive file.
@@ -139,6 +141,11 @@ def create_archive(
     :param backend: the backend to export from. If not specified, the default backend is used.
+    :param tmp_dir: Location where the temporary directory will be written during archive creation.
+        The directory must be writable and defaults to the parent directory of the output file; it is created if missing.
+ This parameter is useful when the output directory has limited space or when you want to use a specific + filesystem (e.g., faster storage) for temporary operations. + :param traversal_rules: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` what rule names are toggleable and what the defaults are. @@ -239,7 +246,7 @@ def querybuilder(): entity_ids[EntityTypes.USER].add(entry.pk) else: raise ArchiveExportError( - f'I was given {entry} ({type(entry)}),' ' which is not a User, Node, Computer, or Group instance' + f'I was given {entry} ({type(entry)}), which is not a User, Node, Computer, or Group instance' ) group_nodes, link_data = _collect_required_entities( querybuilder, @@ -280,94 +287,129 @@ def querybuilder(): EXPORT_LOGGER.report(f'Creating archive with:\n{tabulate(count_summary)}') + # Handle temporary directory configuration + if tmp_dir is not None: + tmp_dir = Path(tmp_dir) + if not tmp_dir.exists(): + EXPORT_LOGGER.warning(f"Specified temporary directory '{tmp_dir}' doesn't exist. Creating it.") + tmp_dir.mkdir(parents=True) + if not tmp_dir.is_dir(): + msg = f"Specified temporary directory '{tmp_dir}' is not a directory" + raise ArchiveExportError(msg) + # Check if directory is writable + # Taken from: https://stackoverflow.com/a/2113511 + if not os.access(tmp_dir, os.W_OK | os.X_OK): + msg = f"Specified temporary directory '{tmp_dir}' is not writable" + raise ArchiveExportError(msg) + + else: + # Create temporary directory in the same folder as the output file + tmp_dir = filename.parent + # Create and open the archive for writing. # We create in a temp dir then move to final place at end, # so that the user cannot end up with a half written archive on errors - with tempfile.TemporaryDirectory() as tmpdir: - tmp_filename = Path(tmpdir) / 'export.zip' - with archive_format.open(tmp_filename, mode='x', compression=compression) as writer: - # add metadata - writer.update_metadata( - { - 'ctime': datetime.now().isoformat(), - 'creation_parameters': { - 'entities_starting_set': None - if entities is None - else {etype.value: list(unique) for etype, unique in starting_uuids.items() if unique}, - 'include_authinfos': include_authinfos, - 'include_comments': include_comments, - 'include_logs': include_logs, - 'graph_traversal_rules': full_traversal_rules, - }, - } - ) - # stream entity data to the archive - with get_progress_reporter()(desc='Archiving database: ', total=sum(entity_counts.values())) as progress: - for etype, ids in entity_ids.items(): - if etype == EntityTypes.NODE and strip_checkpoints: - - def transform(row): - data = row['entity'] - if data.get('node_type', '').startswith('process.'): - data['attributes'].pop(orm.ProcessNode.CHECKPOINT_KEY, None) - return data - else: - - def transform(row): - return row['entity'] - - progress.set_description_str(f'Archiving database: {etype.value}s') - if ids: - for nrows, rows in batch_iter( - querybuilder() - .append( - entity_type_to_orm[etype], filters={'id': {'in': ids}}, tag='entity', project=['**'] - ) - .iterdict(batch_size=batch_size), - batch_size, - transform, - ): - writer.bulk_insert(etype, rows) - progress.update(nrows) - - # stream links - progress.set_description_str(f'Archiving database: {EntityTypes.LINK.value}s') - - def transform(d): - return { - 'input_id': d.source_id, - 'output_id': d.target_id, - 'label': d.link_label, - 'type': d.link_type, + try: + tmp_dir.mkdir(parents=True, exist_ok=True) + with tempfile.TemporaryDirectory(dir=tmp_dir, prefix='.aiida-export-') as tmpdir: + 
tmp_filename = Path(tmpdir) / 'export.zip' + with archive_format.open(tmp_filename, mode='x', compression=compression) as writer: + # add metadata + writer.update_metadata( + { + 'ctime': datetime.now().isoformat(), + 'creation_parameters': { + 'entities_starting_set': None + if entities is None + else {etype.value: list(unique) for etype, unique in starting_uuids.items() if unique}, + 'include_authinfos': include_authinfos, + 'include_comments': include_comments, + 'include_logs': include_logs, + 'graph_traversal_rules': full_traversal_rules, + }, } + ) + # stream entity data to the archive + with get_progress_reporter()( + desc='Archiving database: ', total=sum(entity_counts.values()) + ) as progress: + for etype, ids in entity_ids.items(): + if etype == EntityTypes.NODE and strip_checkpoints: + + def transform(row): + data = row['entity'] + if data.get('node_type', '').startswith('process.'): + data['attributes'].pop(orm.ProcessNode.CHECKPOINT_KEY, None) + return data + else: + + def transform(row): + return row['entity'] + + progress.set_description_str(f'Archiving database: {etype.value}s') + if ids: + for nrows, rows in batch_iter( + querybuilder() + .append( + entity_type_to_orm[etype], filters={'id': {'in': ids}}, tag='entity', project=['**'] + ) + .iterdict(batch_size=batch_size), + batch_size, + transform, + ): + writer.bulk_insert(etype, rows) + progress.update(nrows) + + # stream links + progress.set_description_str(f'Archiving database: {EntityTypes.LINK.value}s') + + def transform(d): + return { + 'input_id': d.source_id, + 'output_id': d.target_id, + 'label': d.link_label, + 'type': d.link_type, + } + + for nrows, rows in batch_iter(link_data, batch_size, transform): + writer.bulk_insert(EntityTypes.LINK, rows, allow_defaults=True) + progress.update(nrows) + del link_data # release memory + + # stream group_nodes + progress.set_description_str(f'Archiving database: {EntityTypes.GROUP_NODE.value}s') + + def transform(d): + return {'dbgroup_id': d[0], 'dbnode_id': d[1]} + + for nrows, rows in batch_iter(group_nodes, batch_size, transform): + writer.bulk_insert(EntityTypes.GROUP_NODE, rows, allow_defaults=True) + progress.update(nrows) + del group_nodes # release memory + + # stream node repository files to the archive + if entity_ids[EntityTypes.NODE]: + _stream_repo_files( + archive_format.key_format, writer, entity_ids[EntityTypes.NODE], backend, batch_size + ) + + EXPORT_LOGGER.report('Finalizing archive creation...') + + if filename.exists(): + filename.unlink() + + filename.parent.mkdir(parents=True, exist_ok=True) + shutil.move(tmp_filename, filename) + except OSError as e: + if e.errno == 28: # No space left on device + msg = ( + f"Insufficient disk space in temporary directory '{tmp_dir}'. " + f'Consider using --tmp-dir to specify a location with more available space.' 
+ ) + raise ArchiveExportError(msg) from e - for nrows, rows in batch_iter(link_data, batch_size, transform): - writer.bulk_insert(EntityTypes.LINK, rows, allow_defaults=True) - progress.update(nrows) - del link_data # release memory - - # stream group_nodes - progress.set_description_str(f'Archiving database: {EntityTypes.GROUP_NODE.value}s') - - def transform(d): - return {'dbgroup_id': d[0], 'dbnode_id': d[1]} - - for nrows, rows in batch_iter(group_nodes, batch_size, transform): - writer.bulk_insert(EntityTypes.GROUP_NODE, rows, allow_defaults=True) - progress.update(nrows) - del group_nodes # release memory - - # stream node repository files to the archive - if entity_ids[EntityTypes.NODE]: - _stream_repo_files(archive_format.key_format, writer, entity_ids[EntityTypes.NODE], backend, batch_size) - - EXPORT_LOGGER.report('Finalizing archive creation...') - - if filename.exists(): - filename.unlink() - - filename.parent.mkdir(parents=True, exist_ok=True) - shutil.move(tmp_filename, filename) + msg = f'Failed to create temporary directory: {e}' + raise ArchiveExportError(msg) from e EXPORT_LOGGER.report('Archive created successfully') @@ -668,7 +710,7 @@ def _check_unsealed_nodes(querybuilder: QbType, node_ids: set[int], batch_size: if unsealed_node_pks: raise ExportValidationError( 'All ProcessNodes must be sealed before they can be exported. ' - f"Node(s) with PK(s): {', '.join(str(pk) for pk in unsealed_node_pks)} is/are not sealed." + f'Node(s) with PK(s): {", ".join(str(pk) for pk in unsealed_node_pks)} is/are not sealed.' ) @@ -759,7 +801,7 @@ def get_init_summary( """Get summary for archive initialisation""" parameters = [['Path', str(outfile)], ['Version', archive_version], ['Compression', compression]] - result = f"\n{tabulate(parameters, headers=['Archive Parameters', ''])}" + result = f'\n{tabulate(parameters, headers=["Archive Parameters", ""])}' inclusions = [ ['Computers/Nodes/Groups/Users', 'All' if collect_all else 'Selected'], @@ -767,10 +809,10 @@ def get_init_summary( ['Node Comments', include_comments], ['Node Logs', include_logs], ] - result += f"\n\n{tabulate(inclusions, headers=['Inclusion rules', ''])}" + result += f'\n\n{tabulate(inclusions, headers=["Inclusion rules", ""])}' if not collect_all: - rules_table = [[f"Follow links {' '.join(name.split('_'))}s", value] for name, value in traversal_rules.items()] - result += f"\n\n{tabulate(rules_table, headers=['Traversal rules', ''])}" + rules_table = [[f'Follow links {" ".join(name.split("_"))}s', value] for name, value in traversal_rules.items()] + result += f'\n\n{tabulate(rules_table, headers=["Traversal rules", ""])}' return result + '\n' diff --git a/tests/cmdline/commands/test_archive_create.py b/tests/cmdline/commands/test_archive_create.py index 5fea646714..37d7f53f31 100644 --- a/tests/cmdline/commands/test_archive_create.py +++ b/tests/cmdline/commands/test_archive_create.py @@ -208,3 +208,17 @@ def test_info_empty_archive(run_cli_command): filename_input = get_archive_file('empty.aiida', filepath='export/migrate') result = run_cli_command(cmd_archive.archive_info, [filename_input], raises=True) assert 'archive file unreadable' in result.output + + +def test_create_tmp_dir_option(run_cli_command, tmp_path): + """Test that the --tmp-dir CLI option passes through correctly.""" + node = Dict().store() + + custom_tmp = tmp_path / 'custom_tmp' + custom_tmp.mkdir() + filename_output = tmp_path / 'archive.aiida' + + options = ['--tmp-dir', str(custom_tmp), '-N', node.pk, '--', filename_output] + + 
run_cli_command(cmd_archive.create, options) + assert filename_output.is_file() diff --git a/tests/tools/archive/test_simple.py b/tests/tools/archive/test_simple.py index ac6209ab95..6c1df441a3 100644 --- a/tests/tools/archive/test_simple.py +++ b/tests/tools/archive/test_simple.py @@ -18,6 +18,7 @@ from aiida.common.exceptions import IncompatibleStorageSchema, LicensingException from aiida.common.links import LinkType from aiida.tools.archive import create_archive, import_archive +from aiida.tools.archive.exceptions import ArchiveExportError @pytest.mark.parametrize('entities', ['all', 'specific']) @@ -154,3 +155,114 @@ def crashing_filter(_): with pytest.raises(LicensingException): create_archive([struct], test_run=True, forbidden_licenses=crashing_filter) + + +@pytest.mark.usefixtures('aiida_profile_clean') +def test_tmp_dir_basic(tmp_path): + """Test that tmp_dir parameter is used correctly.""" + node = orm.Int(42).store() + custom_tmp = tmp_path / 'custom_tmp' + custom_tmp.mkdir() + filename = tmp_path / 'export.aiida' + + create_archive([node], filename=filename, tmp_dir=custom_tmp) + assert filename.exists() + + +@pytest.mark.usefixtures('aiida_profile_clean') +def test_tmp_dir_file_error(tmp_path): + """Test tmp_dir validation errors.""" + + node = orm.Int(42).store() + filename = tmp_path / 'export.aiida' + + # File instead of directory + not_a_dir = tmp_path / 'file.txt' + not_a_dir.write_text('content') + with pytest.raises(ArchiveExportError, match='is not a directory'): + create_archive([node], filename=filename, tmp_dir=not_a_dir) + + +@pytest.mark.usefixtures('aiida_profile_clean') +def test_tmp_dir_disk_space_error(tmp_path): + """Test disk space error handling.""" + from unittest.mock import patch + + node = orm.Int(42).store() + custom_tmp = tmp_path / 'custom_tmp' + custom_tmp.mkdir() + filename = tmp_path / 'export.aiida' + + def mock_temp_dir_error(*args, **kwargs): + error = OSError('No space left on device') + error.errno = 28 + raise error + + with patch('tempfile.TemporaryDirectory', side_effect=mock_temp_dir_error): + with pytest.raises(ArchiveExportError, match='Insufficient disk space.*--tmp-dir'): + create_archive([node], filename=filename, tmp_dir=custom_tmp) + + +@pytest.mark.usefixtures('aiida_profile_clean') +def test_tmp_dir_auto_create(tmp_path): + """Test automatic creation of non-existent tmp_dir.""" + node = orm.Int(42).store() + filename = tmp_path / 'export.aiida' + custom_tmp = tmp_path / 'nonexistent_tmp' # Don't create it! 
+ + create_archive([node], filename=filename, tmp_dir=custom_tmp) + assert filename.exists() + # Verify the directory was created + assert custom_tmp.exists() + + +@pytest.mark.usefixtures('aiida_profile_clean') +def test_tmp_dir_permission_error(tmp_path): + """Test tmp_dir permission validation.""" + import stat + + node = orm.Int(42).store() + filename = tmp_path / 'export.aiida' + readonly_tmp = tmp_path / 'readonly_tmp' + readonly_tmp.mkdir() + + # Make directory read-only + readonly_tmp.chmod(stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH) + + try: + with pytest.raises(ArchiveExportError, match='is not writable'): + create_archive([node], filename=filename, tmp_dir=readonly_tmp) + finally: + # Restore permissions for cleanup + readonly_tmp.chmod(stat.S_IRWXU) + + +@pytest.mark.usefixtures('aiida_profile_clean') +def test_tmp_dir_default_behavior(tmp_path): + """Test default tmp_dir behavior (no tmp_dir specified).""" + node = orm.Int(42).store() + filename = tmp_path / 'export.aiida' + + # Don't specify tmp_dir - test the default path + create_archive([node], filename=filename) + assert filename.exists() + + +@pytest.mark.usefixtures('aiida_profile_clean') +def test_tmp_dir_general_os_error(tmp_path): + """Test general OS error handling.""" + from unittest.mock import patch + + node = orm.Int(42).store() + custom_tmp = tmp_path / 'custom_tmp' + custom_tmp.mkdir() + filename = tmp_path / 'export.aiida' + + def mock_temp_dir_error(*args, **kwargs): + error = OSError('Permission denied') + error.errno = 13 # Different from 28 + raise error + + with patch('aiida.tools.archive.create.tempfile.TemporaryDirectory', side_effect=mock_temp_dir_error): + with pytest.raises(ArchiveExportError, match='Failed to create temporary directory'): + create_archive([node], filename=filename, tmp_dir=custom_tmp)
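A minimal sketch of how the new `tmp_dir` parameter can be exercised from the Python API; the profile, node, and scratch path below are illustrative only and assume a configured AiiDA profile:

    from pathlib import Path

    from aiida import load_profile, orm
    from aiida.tools.archive import create_archive

    load_profile()  # assumes a configured AiiDA profile

    node = orm.Int(42).store()

    # Stage the temporary export directory on a filesystem with more space
    # (or faster I/O) than the directory that will hold the final archive.
    scratch = Path('/scratch/aiida-tmp')  # illustrative path
    scratch.mkdir(parents=True, exist_ok=True)

    create_archive([node], filename=Path('export.aiida'), tmp_dir=scratch)

The CLI equivalent, as exercised in `test_create_tmp_dir_option`, is `verdi archive create --tmp-dir /scratch/aiida-tmp -N <PK> export.aiida`; note that on the command line the directory must already exist, since the option is declared with `click.Path(exists=True)`.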