Commit eaa452c

writable check
1 parent 688625a commit eaa452c

2 files changed: +113 -98 lines changed

src/aiida/cmdline/commands/cmd_archive.py

Lines changed: 7 additions & 2 deletions

@@ -137,8 +137,13 @@ def inspect(ctx, archive, version, meta_data, database):
 )
 @options.DRY_RUN(help='Determine entities to export, but do not create the archive.')
 @click.option(
-    '--base-tmp-dir',
-    help='Determine entities to export, but do not create the archive. Deprecated, please use `--dry-run` instead.',
+    '--tmp-dir',
+    type=click.Path(exists=True, file_okay=False, dir_okay=True, writable=True, path_type=Path),
+    help='Directory to use for temporary files during archive creation. '
+    'If not specified, a temporary directory will be created in the same directory as the output file '
+    'with a \'.aiida-export-\' prefix. This parameter is useful when the output directory has limited '
+    'space or when you want to use a specific filesystem (e.g., faster storage) for temporary operations. '
+    'The directory must exist and be writable.',
 )
 @decorators.with_dbenv()
 def create(
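For context, the new option surfaces in the Python API as the `tmp_dir` keyword of `create_archive` (second file below). A minimal sketch of the equivalent call, assuming a loaded AiiDA profile; the output and scratch paths are illustrative assumptions, not taken from this commit:

from pathlib import Path

from aiida import load_profile
from aiida.tools.archive import create_archive

load_profile()

# Stage temporary files on a scratch filesystem instead of next to the
# output file; both paths below are assumptions for illustration.
create_archive(
    entities=None,  # None exports the whole profile
    filename=Path('/home/user/export.aiida'),
    tmp_dir=Path('/scratch/tmp'),  # must already exist and be writable
)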

src/aiida/tools/archive/create.py

Lines changed: 106 additions & 96 deletions

@@ -12,6 +12,7 @@
 stored in a single file.
 """
 
+import os
 import shutil
 import tempfile
 from datetime import datetime
@@ -59,7 +60,7 @@ def create_archive(
     compression: int = 6,
     test_run: bool = False,
     backend: Optional[StorageBackend] = None,
-    temp_dir: Optional[Union[str, Path]] = None,
+    tmp_dir: Optional[Union[str, Path]] = None,
     **traversal_rules: bool,
 ) -> Path:
     """Export AiiDA data to an archive file.
@@ -187,22 +188,23 @@ def querybuilder():
     }
 
     # Handle temporary directory configuration
-    if temp_dir is not None:
-        temp_dir = Path(temp_dir)
-        if not temp_dir.exists():
-            msg = f"Specified temporary directory '{temp_dir}' does not exist"
+    if tmp_dir is not None:
+        tmp_dir = Path(tmp_dir)
+        if not tmp_dir.exists():
+            msg = f"Specified temporary directory '{tmp_dir}' does not exist"
             raise ArchiveExportError(msg)
-        if not temp_dir.is_dir():
-            msg = f"Specified temporary directory '{temp_dir}' is not a directory"
+        if not tmp_dir.is_dir():
+            msg = f"Specified temporary directory '{tmp_dir}' is not a directory"
             raise ArchiveExportError(msg)
         # Check if directory is writable
-        if not temp_dir.is_writable():
-            msg = f"Specified temporary directory '{temp_dir}' is not writable"
+        # Taken from: https://stackoverflow.com/a/2113511
+        if not os.access(tmp_dir, os.W_OK | os.X_OK):
+            msg = f"Specified temporary directory '{tmp_dir}' is not writable"
             raise ArchiveExportError(msg)
         tmp_prefix = None  # Use default tempfile prefix
     else:
         # Create temporary directory in the same folder as the output file
-        temp_dir = filename.parent
+        tmp_dir = filename.parent
         tmp_prefix = '.aiida-export-'
 
     initial_summary = get_init_summary(
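Worth noting: pathlib's `Path` has no `is_writable()` method, which is the error the `os.access(tmp_dir, os.W_OK | os.X_OK)` check fixes. `W_OK` alone is not enough on POSIX, since creating an entry in a directory also requires search permission (`X_OK`). A self-contained sketch of the same check:

import os
import tempfile
from pathlib import Path


def is_writable_dir(path: Path) -> bool:
    """Return True if new files can be created inside `path`."""
    return path.is_dir() and os.access(path, os.W_OK | os.X_OK)


# Sanity check against a directory that is guaranteed to be writable.
with tempfile.TemporaryDirectory() as scratch:
    assert is_writable_dir(Path(scratch))

Keep in mind that `os.access` tests against the process's real UID/GID, so the answer can differ from what `open()` would actually do under setuid or on some network filesystems.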
@@ -310,93 +312,101 @@ def querybuilder():
     # We create in a temp dir then move to final place at end,
     # so that the user cannot end up with a half written archive on errors
-    temp_dir = Path('/mount')  # or whatever directory you want to use
-    with tempfile.TemporaryDirectory(dir=temp_dir, prefix=tmp_prefix) as tmpdir:
-        # NOTE: Add the `tmp_prefix` to the directory or file?
-        tmp_filename = Path(tmpdir) / 'export.zip'
-        with archive_format.open(tmp_filename, mode='x', compression=compression) as writer:
-            # add metadata
-            writer.update_metadata(
-                {
-                    'ctime': datetime.now().isoformat(),
-                    'creation_parameters': {
-                        'entities_starting_set': None
-                        if entities is None
-                        else {etype.value: list(unique) for etype, unique in starting_uuids.items() if unique},
-                        'include_authinfos': include_authinfos,
-                        'include_comments': include_comments,
-                        'include_logs': include_logs,
-                        'graph_traversal_rules': full_traversal_rules,
-                    },
-                }
-            )
-            # stream entity data to the archive
-            with get_progress_reporter()(desc='Archiving database: ', total=sum(entity_counts.values())) as progress:
-                for etype, ids in entity_ids.items():
-                    if etype == EntityTypes.NODE and strip_checkpoints:
-
-                        def transform(row):
-                            data = row['entity']
-                            if data.get('node_type', '').startswith('process.'):
-                                data['attributes'].pop(orm.ProcessNode.CHECKPOINT_KEY, None)
-                            return data
-                    else:
-
-                        def transform(row):
-                            return row['entity']
-
-                    progress.set_description_str(f'Archiving database: {etype.value}s')
-                    if ids:
-                        for nrows, rows in batch_iter(
-                            querybuilder()
-                            .append(
-                                entity_type_to_orm[etype], filters={'id': {'in': ids}}, tag='entity', project=['**']
-                            )
-                            .iterdict(batch_size=batch_size),
-                            batch_size,
-                            transform,
-                        ):
-                            writer.bulk_insert(etype, rows)
-                            progress.update(nrows)
-
-                # stream links
-                progress.set_description_str(f'Archiving database: {EntityTypes.LINK.value}s')
-
-                def transform(d):
-                    return {
-                        'input_id': d.source_id,
-                        'output_id': d.target_id,
-                        'label': d.link_label,
-                        'type': d.link_type,
+    try:
+        with tempfile.TemporaryDirectory(dir=tmp_dir, prefix=tmp_prefix) as tmpdir:
+            tmp_filename = Path(tmpdir) / 'export.zip'
+            with archive_format.open(tmp_filename, mode='x', compression=compression) as writer:
+                # add metadata
+                writer.update_metadata(
+                    {
+                        'ctime': datetime.now().isoformat(),
+                        'creation_parameters': {
+                            'entities_starting_set': None
+                            if entities is None
+                            else {etype.value: list(unique) for etype, unique in starting_uuids.items() if unique},
+                            'include_authinfos': include_authinfos,
+                            'include_comments': include_comments,
+                            'include_logs': include_logs,
+                            'graph_traversal_rules': full_traversal_rules,
+                        },
                     }
-
-                for nrows, rows in batch_iter(link_data, batch_size, transform):
-                    writer.bulk_insert(EntityTypes.LINK, rows, allow_defaults=True)
-                    progress.update(nrows)
-                del link_data  # release memory
-
-                # stream group_nodes
-                progress.set_description_str(f'Archiving database: {EntityTypes.GROUP_NODE.value}s')
-
-                def transform(d):
-                    return {'dbgroup_id': d[0], 'dbnode_id': d[1]}
-
-                for nrows, rows in batch_iter(group_nodes, batch_size, transform):
-                    writer.bulk_insert(EntityTypes.GROUP_NODE, rows, allow_defaults=True)
-                    progress.update(nrows)
-                del group_nodes  # release memory
-
-            # stream node repository files to the archive
-            if entity_ids[EntityTypes.NODE]:
-                _stream_repo_files(archive_format.key_format, writer, entity_ids[EntityTypes.NODE], backend, batch_size)
-
-        EXPORT_LOGGER.report('Finalizing archive creation...')
-
-        if filename.exists():
-            filename.unlink()
-
-        filename.parent.mkdir(parents=True, exist_ok=True)
-        shutil.move(tmp_filename, filename)
+                )
+                # stream entity data to the archive
+                with get_progress_reporter()(desc='Archiving database: ', total=sum(entity_counts.values())) as progress:
+                    for etype, ids in entity_ids.items():
+                        if etype == EntityTypes.NODE and strip_checkpoints:
+
+                            def transform(row):
+                                data = row['entity']
+                                if data.get('node_type', '').startswith('process.'):
+                                    data['attributes'].pop(orm.ProcessNode.CHECKPOINT_KEY, None)
+                                return data
+                        else:
+
+                            def transform(row):
+                                return row['entity']
+
+                        progress.set_description_str(f'Archiving database: {etype.value}s')
+                        if ids:
+                            for nrows, rows in batch_iter(
+                                querybuilder()
+                                .append(
+                                    entity_type_to_orm[etype], filters={'id': {'in': ids}}, tag='entity', project=['**']
+                                )
+                                .iterdict(batch_size=batch_size),
+                                batch_size,
+                                transform,
+                            ):
+                                writer.bulk_insert(etype, rows)
+                                progress.update(nrows)
+
+                    # stream links
+                    progress.set_description_str(f'Archiving database: {EntityTypes.LINK.value}s')
+
+                    def transform(d):
+                        return {
+                            'input_id': d.source_id,
+                            'output_id': d.target_id,
+                            'label': d.link_label,
+                            'type': d.link_type,
+                        }
+
+                    for nrows, rows in batch_iter(link_data, batch_size, transform):
+                        writer.bulk_insert(EntityTypes.LINK, rows, allow_defaults=True)
+                        progress.update(nrows)
+                    del link_data  # release memory
+
+                    # stream group_nodes
+                    progress.set_description_str(f'Archiving database: {EntityTypes.GROUP_NODE.value}s')
+
+                    def transform(d):
+                        return {'dbgroup_id': d[0], 'dbnode_id': d[1]}
+
+                    for nrows, rows in batch_iter(group_nodes, batch_size, transform):
+                        writer.bulk_insert(EntityTypes.GROUP_NODE, rows, allow_defaults=True)
+                        progress.update(nrows)
+                    del group_nodes  # release memory
+
+                # stream node repository files to the archive
+                if entity_ids[EntityTypes.NODE]:
+                    _stream_repo_files(archive_format.key_format, writer, entity_ids[EntityTypes.NODE], backend, batch_size)
+
+            EXPORT_LOGGER.report('Finalizing archive creation...')
+
+            if filename.exists():
+                filename.unlink()
+
+            filename.parent.mkdir(parents=True, exist_ok=True)
+            shutil.move(tmp_filename, filename)
+    except OSError as e:
+        if e.errno == 28:  # No space left on device
+            raise ArchiveExportError(
+                f"Insufficient disk space in temporary directory '{tmp_dir}'. "
+                'Consider using --tmp-dir to specify a location with more available space.'
+            ) from e
+        raise ArchiveExportError(f'Archive creation failed: {e}') from e
 
     EXPORT_LOGGER.report('Archive created successfully')
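The `except` clause above matches the raw errno value 28; the stdlib `errno` module gives it a symbolic name. A standalone sketch of the same translate-and-re-raise pattern, with `ArchiveExportError` stubbed in and a hypothetical helper name (`finalize_archive` is not part of the commit):

import errno
import shutil
from pathlib import Path


class ArchiveExportError(Exception):
    """Stand-in for AiiDA's ArchiveExportError."""


def finalize_archive(tmp_filename: Path, filename: Path, tmp_dir: Path) -> None:
    """Hypothetical helper mirroring the commit's error translation."""
    try:
        filename.parent.mkdir(parents=True, exist_ok=True)
        shutil.move(tmp_filename, filename)
    except OSError as exc:
        if exc.errno == errno.ENOSPC:  # errno 28: no space left on device
            raise ArchiveExportError(
                f"Insufficient disk space in temporary directory '{tmp_dir}'. "
                'Consider using --tmp-dir to choose a location with more space.'
            ) from exc
        raise ArchiveExportError(f'Archive creation failed: {exc}') from exc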