Skip to content

Commit b618cb7

Browse files
danielhollasadityagh006agoscinski
authored
Fix QueryBuilder filtering for AbstractCode (aiidateam#6866)
Adding an entry point for `AbstractCode` to make it queryable.Users searching for `AbstractCode` (sub)classes want to get all the codes (`PortableCode`, `InstalledCode` etc.) Unfortunately, because AbstractCode was introduced later, its entry point is 'data.core.code.abstract', while the 'core.code' entry point is claimed by the Legacy Code class. So to get all the code types, including the Legacy Code, we adjust the filter to 'data.core.code'. Note, this only make sense if `subclassing` parameter is True! Fixes issue aiidateam#6687. --------- Co-authored-by: unknown <[email protected]> Co-authored-by: Alexander Goscinski <[email protected]>
1 parent 2fd4b89 commit b618cb7

File tree

5 files changed

+109
-1
lines changed

5 files changed

+109
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ requires-python = '>=3.9'
103103
'core.bool' = 'aiida.orm.nodes.data.bool:Bool'
104104
'core.cif' = 'aiida.orm.nodes.data.cif:CifData'
105105
'core.code' = 'aiida.orm.nodes.data.code.legacy:Code'
106+
'core.code.abstract' = 'aiida.orm.nodes.data.code.abstract:AbstractCode'
106107
'core.code.containerized' = 'aiida.orm.nodes.data.code.containerized:ContainerizedCode'
107108
'core.code.installed' = 'aiida.orm.nodes.data.code.installed:InstalledCode'
108109
'core.code.portable' = 'aiida.orm.nodes.data.code.portable:PortableCode'

src/aiida/orm/querybuilder.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,6 +1338,14 @@ def _get_node_type_filter(classifiers: Classifier, subclassing: bool) -> dict:
13381338

13391339
value = classifiers.ormclass_type_string
13401340

1341+
# Users searching for `AbstractCode` (sub)classes want to get all the codes (Portable, Installed etc.)
1342+
# Unfortunately, because AbstractCode was introduced later, its entry point is 'data.core.code.abstract',
1343+
# while the 'core.code' entry point is claimed by the Legacy Code class.
1344+
# So to get all the code types, including the Legacy Code, we adjust the filter to 'data.core.code'.
1345+
# Note, this only make sense if `subclassing` parameter is True!
1346+
if value == 'data.core.code.abstract.AbstractCode.':
1347+
value = 'data.core.code.'
1348+
13411349
if not subclassing:
13421350
filters = {'==': value}
13431351
else:

tests/orm/nodes/data/test_data.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,11 @@ def _generate_class_instance(data_class):
154154

155155
@pytest.fixture(
156156
scope='function',
157-
params=[entry_point for entry_point in plugins.get_entry_points('aiida.data') if entry_point.name != 'core.code'],
157+
params=[
158+
entry_point
159+
for entry_point in plugins.get_entry_points('aiida.data')
160+
if entry_point.name not in ('core.code', 'core.code.abstract')
161+
],
158162
)
159163
def data_plugin(request):
160164
"""Fixture that parametrizes over all the registered entry points of the ``aiida.data`` entry point group."""
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
append_text: QbStrField('append_text', dtype=<class 'str'>, is_attribute=True)
2+
attributes: QbDictField('attributes', dtype=typing.Optional[typing.Dict[str, typing.Any]],
3+
is_attribute=False, is_subscriptable=True)
4+
computer: QbNumericField('computer', dtype=typing.Optional[int], is_attribute=False)
5+
ctime: QbNumericField('ctime', dtype=typing.Optional[datetime.datetime], is_attribute=False)
6+
default_calc_job_plugin: QbStrField('default_calc_job_plugin', dtype=typing.Optional[str],
7+
is_attribute=True)
8+
description: QbStrField('description', dtype=<class 'str'>, is_attribute=True)
9+
extras: QbDictField('extras', dtype=typing.Optional[typing.Dict[str, typing.Any]],
10+
is_attribute=False, is_subscriptable=True)
11+
label: QbStrField('label', dtype=<class 'str'>, is_attribute=True)
12+
mtime: QbNumericField('mtime', dtype=typing.Optional[datetime.datetime], is_attribute=False)
13+
node_type: QbStrField('node_type', dtype=typing.Optional[str], is_attribute=False)
14+
pk: QbNumericField('pk', dtype=typing.Optional[int], is_attribute=False)
15+
prepend_text: QbStrField('prepend_text', dtype=<class 'str'>, is_attribute=True)
16+
process_type: QbStrField('process_type', dtype=typing.Optional[str], is_attribute=False)
17+
repository_content: QbDictField('repository_content', dtype=typing.Optional[dict[str,
18+
bytes]], is_attribute=False)
19+
repository_metadata: QbDictField('repository_metadata', dtype=typing.Optional[typing.Dict[str,
20+
typing.Any]], is_attribute=False)
21+
source: QbDictField('source', dtype=typing.Optional[dict], is_attribute=True, is_subscriptable=True)
22+
use_double_quotes: QbField('use_double_quotes', dtype=<class 'bool'>, is_attribute=True)
23+
user: QbNumericField('user', dtype=typing.Optional[int], is_attribute=False)
24+
uuid: QbStrField('uuid', dtype=typing.Optional[str], is_attribute=False)
25+
with_mpi: QbField('with_mpi', dtype=typing.Optional[bool], is_attribute=True)

tests/orm/test_querybuilder.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -852,6 +852,76 @@ def test_empty_filters(self):
852852
qb = orm.QueryBuilder().append(orm.Data, filters={'or': [{}, {}]})
853853
assert qb.count() == count
854854

855+
@pytest.mark.usefixtures('suppress_internal_deprecations')
856+
@pytest.mark.usefixtures('aiida_profile_clean')
857+
def test_abstract_code_filtering(self, aiida_localhost, aiida_code, tmp_path):
858+
"""Test that querying for AbstractCode correctly returns all code instances.
859+
860+
This tests the fix for issue #6687, where QueryBuilder couldn't find codes
861+
when looking for AbstractCode due to a node_type mismatch.
862+
"""
863+
installed_code = aiida_code(
864+
'core.code.installed',
865+
label='installed-code',
866+
computer=aiida_localhost,
867+
filepath_executable='/bin/bash',
868+
)
869+
(tmp_path / 'fake_exec').touch()
870+
portable_code = aiida_code(
871+
'core.code.portable',
872+
label='portable-code',
873+
filepath_executable='fake_exec',
874+
filepath_files=tmp_path,
875+
)
876+
legacy_code = aiida_code(
877+
'core.code',
878+
label='legacy-code',
879+
remote_computer_exec=(aiida_localhost, '/bin/bash'),
880+
)
881+
882+
qb = orm.QueryBuilder
883+
884+
# Verify specific code type queries work as expected
885+
installed_results = qb().append(orm.InstalledCode).all(flat=True)
886+
assert installed_code in installed_results
887+
assert len(installed_results) == 1
888+
889+
portable_results = qb().append(orm.PortableCode).all(flat=True)
890+
assert portable_code in portable_results
891+
assert len(portable_results) == 1
892+
893+
# Using orm.Code actually matches all codes.
894+
# for backwards compatibility reasons we will not fix this.
895+
legacy_results = qb().append(orm.Code).all(flat=True)
896+
assert legacy_code in legacy_results
897+
assert len(legacy_results) == 3
898+
899+
# Turning off subclassing should however only match the one legacy Code
900+
legacy_results = qb().append(orm.Code, subclassing=False).all(flat=True)
901+
assert legacy_code in legacy_results
902+
assert len(legacy_results) == 1
903+
904+
# AbstractCode query should find all code types
905+
abstract_results = qb().append(orm.AbstractCode).all(flat=True)
906+
assert (
907+
installed_code in abstract_results
908+
), f'InstalledCode not found with AbstractCode query. Result: {abstract_results}'
909+
assert (
910+
portable_code in abstract_results
911+
), f'PortableCode not found with AbstractCode query. Result: {abstract_results}'
912+
assert legacy_code in abstract_results, f'Code not found with AbstractCode query. Result: {abstract_results}'
913+
assert len(abstract_results) == 3
914+
915+
# AbstractCode with basic filtering
916+
qb_filtered = qb().append(orm.AbstractCode, filters={'label': 'installed-code'})
917+
filtered_results = qb_filtered.all(flat=True)
918+
assert installed_code in filtered_results
919+
assert len(filtered_results) == 1
920+
921+
# QB should find no codes if subclassing is False
922+
subclassing_off_results = qb().append(orm.AbstractCode, subclassing=False).all(flat=True)
923+
assert len(subclassing_off_results) == 0
924+
855925

856926
class TestAttributes:
857927
@pytest.mark.requires_psql

0 commit comments

Comments
 (0)