Skip to content

Commit 1a5dfe8

Browse files
committed
ShellJob: Fix RemoteData handling
The `filenames` input was not taken into account for `RemoteData` input nodes.
1 parent 189df63 commit 1a5dfe8

File tree

4 files changed

+84
-12
lines changed

4 files changed

+84
-12
lines changed

src/aiida_shell/calculations/shell.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from aiida.common.datastructures import CalcInfo, CodeInfo, FileCopyOperation
1111
from aiida.common.folders import Folder
1212
from aiida.engine import CalcJob, CalcJobProcessSpec
13-
from aiida.orm import Data, Dict, FolderData, List, RemoteData, SinglefileData, to_aiida_type
13+
from aiida.orm import Computer, Data, Dict, FolderData, List, RemoteData, SinglefileData, to_aiida_type
1414
from aiida.parsers import Parser
1515

1616
from aiida_shell.data import EntryPointData, PickledData
@@ -281,9 +281,11 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
281281
inputs = {}
282282

283283
nodes = inputs.get('nodes', {})
284+
computer = inputs['code'].computer
284285
filenames = (inputs.get('filenames', None) or Dict()).get_dict()
285286
arguments = (inputs.get('arguments', None) or List()).get_list()
286287
outputs = (inputs.get('outputs', None) or List()).get_list()
288+
use_symlinks = inputs['metadata']['options']['use_symlinks']
287289
filename_stdin = inputs['metadata']['options'].get('filename_stdin', None)
288290
filename_stdout = inputs['metadata']['options'].get('output_filename', None)
289291
default_retrieved_temporary = list(self.DEFAULT_RETRIEVED_TEMPORARY)
@@ -300,7 +302,10 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
300302
if filename_stdin and filename_stdin in processed_arguments:
301303
processed_arguments.remove(filename_stdin)
302304

303-
remote_copy_list, remote_symlink_list = self.handle_remote_data_nodes(inputs)
305+
remote_data_nodes = {key: node for key, node in nodes.items() if isinstance(node, RemoteData)}
306+
remote_copy_list, remote_symlink_list = self.handle_remote_data_nodes(
307+
remote_data_nodes, filenames, computer, use_symlinks
308+
)
304309

305310
code_info = CodeInfo()
306311
code_info.code_uuid = inputs['code'].uuid
@@ -329,16 +334,22 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
329334
return calc_info
330335

331336
@staticmethod
332-
def handle_remote_data_nodes(inputs: dict[str, Data]) -> tuple[list[t.Any], list[t.Any]]:
333-
"""Handle a ``RemoteData`` that was passed in the ``nodes`` input.
337+
def handle_remote_data_nodes(
338+
remote_data_nodes: dict[str, RemoteData], filenames: dict[str, str], computer: Computer, use_symlinks: bool
339+
) -> tuple[list[t.Any], list[t.Any]]:
340+
"""Handle all ``RemoteData`` nodes that were passed in the ``nodes`` input.
334341
335-
:param inputs: The inputs dictionary.
342+
:param remote_data_nodes: The ``RemoteData`` input nodes.
343+
:param filenames: A dictionary of explicit filenames to use for the ``nodes`` to be written to ``dirpath``.
336344
:returns: A tuple of two lists, the ``remote_copy_list`` and the ``remote_symlink_list``.
337345
"""
338-
use_symlinks: bool = inputs['metadata']['options']['use_symlinks'] # type: ignore[index]
339-
computer_uuid = inputs['code'].computer.uuid # type: ignore[union-attr]
340-
remote_nodes = [node for node in inputs.get('nodes', {}).values() if isinstance(node, RemoteData)]
341-
instructions = [(computer_uuid, f'{node.get_remote_path()}/*', '.') for node in remote_nodes]
346+
instructions = []
347+
348+
for key, node in remote_data_nodes.items():
349+
if key in filenames:
350+
instructions.append((computer.uuid, node.get_remote_path(), filenames[key]))
351+
else:
352+
instructions.append((computer.uuid, f'{node.get_remote_path()}/*', '.'))
342353

343354
if use_symlinks:
344355
return [], instructions
@@ -407,7 +418,10 @@ def process_arguments_and_nodes(
407418
self.write_folder_data(node, dirpath, filename)
408419
argument_interpolated = argument.format(**{placeholder: filename or placeholder})
409420
elif isinstance(node, RemoteData):
410-
self.handle_remote_data(node)
421+
# Only the placeholder needs to be formatted. The content of the remote data itself is handled by the
422+
# engine through the instructions created in ``handle_remote_data_nodes``.
423+
filename = prepared_filenames[placeholder]
424+
argument_interpolated = argument.format(**{placeholder: filename or placeholder})
411425
else:
412426
argument_interpolated = argument.format(**{placeholder: str(node.value)})
413427

@@ -465,6 +479,8 @@ def prepare_filenames(self, nodes: dict[str, SinglefileData], filenames: dict[st
465479
raise RuntimeError(
466480
f'node `{key}` contains the file `{f}` which overlaps with a reserved output filename.'
467481
)
482+
elif isinstance(node, RemoteData):
483+
filename = filenames.get(key, None)
468484
else:
469485
continue
470486

tests/calculations/test_shell.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def test_nodes_folder_data(generate_calc_job, generate_code, tmp_path):
8888

8989
@pytest.mark.parametrize('use_symlinks', (True, False))
9090
def test_nodes_remote_data(generate_calc_job, generate_code, tmp_path, aiida_localhost, use_symlinks):
91-
"""Test the ``nodes`` input with ``RemoteData`` nodes ."""
91+
"""Test the ``nodes`` input with ``RemoteData`` nodes."""
9292
inputs = {
9393
'code': generate_code(),
9494
'arguments': [],
@@ -107,6 +107,39 @@ def test_nodes_remote_data(generate_calc_job, generate_code, tmp_path, aiida_loc
107107
assert sorted(calc_info.remote_copy_list) == [(aiida_localhost.uuid, str(tmp_path / '*'), '.')]
108108

109109

110+
def test_nodes_remote_data_filename(generate_calc_job, generate_code, tmp_path, aiida_localhost):
    """Test the ``nodes`` and ``filenames`` inputs with ``RemoteData`` nodes.

    Two ``RemoteData`` nodes are passed, only one of which has an explicit entry in ``filenames``:

    * ``remote_a`` has a target filename, so its remote path itself (without a glob) should be copied to that
      target and the ``{remote_a}`` placeholder should be replaced by the target name.
    * ``remote_b`` has no target filename, so the contents of its remote path (``<path>/*``) should be copied to
      the root of the working directory.
    """
    remote_path_a = tmp_path / 'remote_a'
    remote_path_b = tmp_path / 'remote_b'
    remote_path_a.mkdir()
    remote_path_b.mkdir()
    (remote_path_a / 'file_a.txt').write_text('content a')
    (remote_path_b / 'file_b.txt').write_text('content b')
    remote_data_a = RemoteData(remote_path=str(remote_path_a.absolute()), computer=aiida_localhost)
    remote_data_b = RemoteData(remote_path=str(remote_path_b.absolute()), computer=aiida_localhost)

    inputs = {
        'code': generate_code(),
        'arguments': ['{remote_a}'],
        'nodes': {
            'remote_a': remote_data_a,
            'remote_b': remote_data_b,
        },
        'filenames': {'remote_a': 'target_remote'},
    }
    dirpath, calc_info = generate_calc_job('core.shell', inputs)

    # The placeholder of the remote data node is replaced by its explicit filename.
    code_info = calc_info.codes_info[0]
    assert code_info.cmdline_params == ['target_remote']

    # Symlinks are not requested here, so everything should end up in the copy list.
    assert calc_info.remote_symlink_list == []
    assert sorted(calc_info.remote_copy_list) == [
        # With an explicit filename the remote path is copied as is, i.e. without a trailing ``/*`` glob.
        (aiida_localhost.uuid, str(remote_path_a), 'target_remote'),
        # Without an explicit filename the contents of the remote path are copied to the working directory root.
        (aiida_localhost.uuid, str(remote_path_b / '*'), '.'),
    ]

    # Remote data is copied by the engine, so nothing should have been written to the sandbox folder.
    assert sorted(p.name for p in dirpath.iterdir()) == []
141+
142+
110143
def test_nodes_base_types(generate_calc_job, generate_code):
111144
"""Test the ``nodes`` input with ``BaseType`` nodes ."""
112145
inputs = {

tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def factory(entry_point_name='core.shell', store_provenance=False, filepath_retr
5858

5959

6060
@pytest.fixture
61-
def generate_calc_job(tmp_path):
61+
def generate_calc_job(tmp_path_factory):
6262
"""Create a :class:`aiida.engine.CalcJob` instance with the given inputs.
6363
6464
The fixture will call ``prepare_for_submission`` and return a tuple of the temporary folder that was passed to it,
@@ -81,6 +81,7 @@ def factory(
8181
which ensures that all input files are written, including those by the scheduler plugin, such as the
8282
submission script.
8383
"""
84+
tmp_path = tmp_path_factory.mktemp('calc_job_submit_dir')
8485
manager = get_manager()
8586
runner = manager.get_runner()
8687

tests/test_launch.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,28 @@ def test_nodes_remote_data(tmp_path, aiida_localhost, use_symlinks):
164164
assert (dirpath_working / 'filled' / 'file_b.txt').read_text() == 'content b'
165165

166166

167+
def test_nodes_remote_data_filename(tmp_path_factory, aiida_localhost):
    """Test copying the contents of a ``RemoteData`` to a specific subdirectory.

    The remote folder contains ``source/file.txt``. With ``filenames={'remote': 'sub_directory'}`` the job is
    expected to receive ``sub_directory`` as its argument and to find ``sub_directory/source/file.txt`` in its
    working directory.
    """
    remote_root = tmp_path_factory.mktemp('remote')
    source_dir = remote_root / 'source'
    source_dir.mkdir()
    (source_dir / 'file.txt').touch()

    nodes = {'remote': RemoteData(remote_path=str(remote_root), computer=aiida_localhost)}
    filenames = {'remote': 'sub_directory'}

    results, node = launch_shell_job('echo', arguments=['{remote}'], nodes=nodes, filenames=filenames)

    assert node.is_finished_ok
    # ``echo`` prints the interpolated placeholder, which should be the explicit target name.
    assert results['stdout'].get_content().strip() == 'sub_directory'

    target_dir = pathlib.Path(node.outputs.remote_folder.get_remote_path()) / 'sub_directory'
    assert target_dir.is_dir()
    assert (target_dir / 'source').is_dir()
    assert (target_dir / 'source' / 'file.txt').is_file()
187+
188+
167189
def test_nodes_base_types():
168190
"""Test a shellfunction that specifies positional CLI arguments that are interpolated by the ``kwargs``."""
169191
nodes = {

0 commit comments

Comments
 (0)