Skip to content

Commit d866856

Browse files
authored
Merge pull request #90 from Exabyte-io/feature/SOF-7768
Feature/SOF-7768
2 parents 8ae0e81 + c3a866d commit d866856

File tree

10 files changed

+389
-3
lines changed

10 files changed

+389
-3
lines changed

src/py/mat3ra/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__path__ = __import__("pkgutil").extend_path(__path__, __name__)
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from collections import defaultdict
2+
from typing import Dict, List
3+
4+
from .base import Standata, StandataData
5+
from .data.applications import applications_data
6+
7+
8+
class ApplicationStandata(Standata):
9+
data_dict: Dict = applications_data
10+
data: StandataData = StandataData(data_dict)
11+
12+
@classmethod
13+
def list_all(cls) -> Dict[str, List[dict]]:
14+
"""
15+
Lists all applications with their versions and build information and prints in a human-readable format.
16+
Returns a dict grouped by application name.
17+
"""
18+
grouped = defaultdict(list)
19+
for app in cls.get_as_list():
20+
version_info = {
21+
"version": app.get("version"),
22+
"build": app.get("build"),
23+
}
24+
if app.get("isLicensed"):
25+
version_info["isLicensed"] = True
26+
grouped[app.get("name")].append(version_info)
27+
28+
lines = []
29+
for app_name in sorted(grouped.keys()):
30+
for info in grouped[app_name]:
31+
licensed = " (licensed)" if info.get("isLicensed") else ""
32+
lines.append(f"{app_name}:\n version: {info['version']}, build: {info['build']}{licensed}")
33+
34+
print("\n".join(lines))
35+
return dict(grouped)
36+
37+

src/py/mat3ra/standata/base.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import re
2+
from enum import Enum
23
from typing import Dict, List
34

45
import pandas as pd
@@ -211,8 +212,11 @@ def get_by_name_first_match(cls, name: str) -> dict:
211212
name: Name of the entity.
212213
"""
213214
matching_filenames = cls.data.standataConfig.get_filenames_by_regex(name)
214-
return cls.data.filesMapByName.get_objects_by_filenames(matching_filenames)[0]
215-
215+
objects = cls.data.filesMapByName.get_objects_by_filenames(matching_filenames)
216+
if not objects:
217+
raise ValueError(f"No matches found for name '{name}'")
218+
return objects[0]
219+
216220
@classmethod
217221
def get_by_categories(cls, *tags: str) -> List[dict]:
218222
"""
@@ -246,3 +250,37 @@ def get_by_name_and_categories(cls, name: str, *tags: str) -> dict:
246250
raise ValueError(f"No matches found for name '{name}' and categories {tags}")
247251

248252
return cls.data.filesMapByName.get_objects_by_filenames(matching_filenames)[0]
253+
254+
@classmethod
255+
def _create_filtered_data(cls, filenames: List[str]) -> StandataData:
256+
filtered_files_map = {k: v for k, v in cls.data.filesMapByName.dictionary.items() if k in filenames}
257+
filtered_entities = [e for e in cls.data.standataConfig.entities if e.filename in filenames]
258+
return StandataData({
259+
"filesMapByName": filtered_files_map,
260+
"standataConfig": {
261+
"categories": cls.data.standataConfig.categories,
262+
"entities": [{"filename": e.filename, "categories": e.categories} for e in filtered_entities]
263+
}
264+
})
265+
266+
@classmethod
267+
def _normalize_enum_name(cls, name: str) -> str:
268+
return name.upper().replace("-", "_")
269+
270+
@classmethod
271+
def _create_enum_from_values(cls, values: List[str], enum_name: str) -> type[Enum]:
272+
enum_dict = {cls._normalize_enum_name(value): value for value in values}
273+
return Enum(enum_name, enum_dict)
274+
275+
@classmethod
276+
def filter_by_name(cls, name: str) -> "Standata":
277+
matching_filenames = cls.data.standataConfig.get_filenames_by_regex(name)
278+
filtered_data = cls._create_filtered_data(matching_filenames)
279+
return type(cls.__name__, (cls,), {"data": filtered_data})
280+
281+
@classmethod
282+
def filter_by_tags(cls, *tags: str) -> "Standata":
283+
categories = cls.data.standataConfig.convert_tags_to_categories_list(*tags)
284+
matching_filenames = cls.data.standataConfig.get_filenames_by_categories(*categories)
285+
filtered_data = cls._create_filtered_data(matching_filenames)
286+
return type(cls.__name__, (cls,), {"data": filtered_data})

src/py/mat3ra/standata/model_tree.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from enum import Enum
12
from typing import Any, Dict, List, Optional
23

34
from mat3ra.esse.models.method.categorized_method import SlugifiedEntry
@@ -32,3 +33,47 @@ def get_default_model_type_for_application(self, application: Dict[str, Any]) ->
3233
tree = self.get_tree_by_application_name_and_version(name, application.get("version", ""))
3334
keys = list(tree.keys())
3435
return keys[0] if keys else None
36+
37+
@classmethod
38+
def get_subtypes_by_model_type(cls, model_type: str) -> type[Enum]:
39+
model_tree = MODEL_TREE.get(model_type, {})
40+
subtypes = list(model_tree.keys())
41+
return cls._create_enum_from_values(subtypes, f"{model_type.upper()}Subtypes")
42+
43+
@classmethod
44+
def get_functionals_by_subtype(cls, model_type: str, subtype_enum: Enum) -> type[Enum]:
45+
model_tree = MODEL_TREE.get(model_type, {})
46+
subtype_value = subtype_enum.value if isinstance(subtype_enum, Enum) else subtype_enum
47+
subtype_tree = model_tree.get(subtype_value, {})
48+
functionals = subtype_tree.get("functionals", [])
49+
enum_name = f"{model_type.upper()}{cls._normalize_enum_name(subtype_value)}Functionals"
50+
return cls._create_enum_from_values(functionals, enum_name)
51+
52+
@classmethod
53+
def get_default_subtype(cls, model_tree: Dict[str, Any]) -> Optional[str]:
54+
subtypes = [key for key in model_tree.keys() if key not in ["refiners", "modifiers", "methods"]]
55+
return subtypes[0] if subtypes else None
56+
57+
@classmethod
58+
def get_model_by_parameters(cls, type: str, subtype: Optional[str], functional: Optional[str]) -> Dict[str, Any]:
59+
model_tree = MODEL_TREE.get(type, {})
60+
if not model_tree:
61+
return {}
62+
63+
result = {"type": type}
64+
65+
resolved_subtype = subtype or cls.get_default_subtype(model_tree)
66+
subtype_tree = model_tree.get(resolved_subtype, {}) if resolved_subtype else {}
67+
if not subtype_tree:
68+
return result
69+
70+
result["subtype"] = resolved_subtype
71+
72+
functionals_from_tree = subtype_tree.get("functionals", [])
73+
if functionals_from_tree:
74+
if functional and functional in functionals_from_tree:
75+
result["functional"] = functional
76+
else:
77+
result["functional"] = functionals_from_tree[0]
78+
79+
return result
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from typing import Dict
2+
3+
from .base import Standata, StandataData
4+
from .data.subworkflows import subworkflows_data
5+
6+
7+
class SubworkflowStandata(Standata):
8+
data_dict: Dict = subworkflows_data
9+
data: StandataData = StandataData(data_dict)
10+
11+
@classmethod
12+
def filter_by_application(cls, application: str) -> "SubworkflowStandata":
13+
return cls.filter_by_tags(application)
14+

src/py/mat3ra/standata/workflows.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,15 @@
44
from .data.workflows import workflows_data
55

66

7-
class Workflows(Standata):
7+
class WorkflowStandata(Standata):
88
data_dict: Dict = workflows_data
99
data: StandataData = StandataData(data_dict)
10+
11+
@classmethod
12+
def filter_by_application(cls, application: str) -> "WorkflowStandata":
13+
return cls.filter_by_tags(application)
14+
15+
@classmethod
16+
def filter_by_application_config(cls, application_config: Dict) -> "WorkflowStandata":
17+
application_name = application_config.get("name", "")
18+
return cls.filter_by_application(application_name)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from mat3ra.standata.data.applications import applications_data
2+
from mat3ra.standata.applications import ApplicationStandata
3+
4+
5+
def test_get_by_name():
6+
application = ApplicationStandata.get_by_name_first_match("espresso")
7+
assert type(application) == dict
8+
assert application["name"] == "espresso"
9+
assert application["version"] == "6.3"
10+
11+
12+
def test_get_by_categories():
13+
applications = ApplicationStandata.get_by_categories("quantum-mechanical")
14+
assert isinstance(applications, list)
15+
assert applications[0]["name"] == "espresso"
16+
17+
18+
def test_get_application_data():
19+
application = applications_data["filesMapByName"]["espresso/espresso_gnu_6.3.json"]
20+
assert type(application) == dict
21+
assert application["name"] == "espresso"
22+
assert application["version"] == "6.3"
23+
24+
25+
def test_get_by_name_and_categories():
26+
application = ApplicationStandata.get_by_name_and_categories("vasp", "quantum-mechanical")
27+
assert type(application) == dict
28+
assert application["name"] == "vasp"
29+
assert application["version"] == "5.4.4"
30+
31+
32+
def test_list_all():
33+
applications = ApplicationStandata.list_all()
34+
assert isinstance(applications, dict)
35+
assert len(applications) >= 1
36+
assert "espresso" in applications
37+
assert isinstance(applications["espresso"], list)
38+
assert len(applications["espresso"]) >= 1
39+
assert isinstance(applications["espresso"][0], dict)
40+
assert "version" in applications["espresso"][0]
41+
assert "build" in applications["espresso"][0]
42+
assert applications["espresso"][0]["version"] == "6.3"
43+
assert applications["espresso"][0]["build"] == "GNU"
44+
45+
def test_get_as_list():
46+
applications_list = ApplicationStandata.get_as_list()
47+
assert isinstance(applications_list, list)
48+
assert len(applications_list) >= 1
49+
assert isinstance(applications_list[0], dict)
50+
assert applications_list[0]["name"] == "espresso"
51+

tests/py/unit/test_model_tree.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from enum import Enum
12
from types import SimpleNamespace
23

34
import pytest
@@ -10,6 +11,22 @@
1011
VERSION = SimpleNamespace(V6_0="6.0", V1_0="1.0")
1112
PSEUDOPOTENTIAL_TYPE = SimpleNamespace(PAW="paw", NC="nc", NC_FR="nc-fr", US="us")
1213
MODEL_TYPE = SimpleNamespace(DFT="dft")
14+
SUBTYPE = SimpleNamespace(GGA="gga", LDA="lda", HYBRID="hybrid", OTHER="other", INVALID="invalid_subtype")
15+
FUNCTIONAL = SimpleNamespace(
16+
PBE="pbe",
17+
PBESOL="pbesol",
18+
PW91="pw91",
19+
PZ="pz",
20+
PW="pw",
21+
VWN="vwn",
22+
OTHER="other",
23+
INVALID="invalid_functional",
24+
)
25+
GGA_FUNCTIONALS = {"PBE": FUNCTIONAL.PBE, "PBESOL": FUNCTIONAL.PBESOL, "PW91": FUNCTIONAL.PW91,
26+
"OTHER": FUNCTIONAL.OTHER}
27+
LDA_FUNCTIONALS = {"PZ": FUNCTIONAL.PZ, "PW": FUNCTIONAL.PW, "VWN": FUNCTIONAL.VWN, "OTHER": FUNCTIONAL.OTHER}
28+
EXPECTED_MODEL_BY_PARAMETERS_GGA = {"type": MODEL.DFT, "subtype": SUBTYPE.GGA, "functional": FUNCTIONAL.PBE}
29+
EXPECTED_MODEL_BY_PARAMETERS_LDA = {"type": MODEL.DFT, "subtype": SUBTYPE.LDA, "functional": FUNCTIONAL.PZ}
1330

1431

1532
@pytest.mark.parametrize(
@@ -77,3 +94,65 @@ def test_get_default_model_type_for_application(application, expected):
7794
assert result == expected
7895
else:
7996
assert result is None
97+
98+
99+
@pytest.mark.parametrize(
100+
"model_type,expected_subtypes",
101+
[
102+
("dft", {"GGA": "gga", "LDA": "lda", "HYBRID": "hybrid", "OTHER": "other"}),
103+
("invalid_model", {}),
104+
],
105+
)
106+
def test_get_subtypes_by_model_type(model_type, expected_subtypes):
107+
subtypes = ModelTreeStandata.get_subtypes_by_model_type(model_type)
108+
assert issubclass(subtypes, Enum)
109+
assert len(list(subtypes)) == len(expected_subtypes)
110+
for enum_name, expected_value in expected_subtypes.items():
111+
assert hasattr(subtypes, enum_name)
112+
assert getattr(subtypes, enum_name).value == expected_value
113+
114+
115+
@pytest.mark.parametrize(
116+
"model_type,subtype_input,use_string,expected_functionals,excluded_functionals",
117+
[
118+
(MODEL.DFT, SUBTYPE.LDA, False, LDA_FUNCTIONALS, [FUNCTIONAL.PBE]),
119+
(MODEL.DFT, SUBTYPE.GGA, False, GGA_FUNCTIONALS, [FUNCTIONAL.PZ]),
120+
(MODEL.DFT, SUBTYPE.LDA, True, LDA_FUNCTIONALS, [FUNCTIONAL.PBE]),
121+
],
122+
)
123+
def test_get_functionals_by_subtype(model_type, subtype_input, use_string, expected_functionals, excluded_functionals):
124+
if use_string:
125+
subtype_arg = subtype_input
126+
else:
127+
subtypes = ModelTreeStandata.get_subtypes_by_model_type(model_type)
128+
subtype_arg = getattr(subtypes, subtype_input.upper())
129+
130+
functionals = ModelTreeStandata.get_functionals_by_subtype(model_type, subtype_arg)
131+
assert issubclass(functionals, Enum)
132+
133+
for enum_name, expected_value in expected_functionals.items():
134+
assert hasattr(functionals, enum_name)
135+
assert getattr(functionals, enum_name).value == expected_value
136+
137+
functional_values = [f.value for f in functionals]
138+
for excluded in excluded_functionals:
139+
assert excluded not in functional_values
140+
141+
142+
@pytest.mark.parametrize(
143+
"type,subtype,functional,expected",
144+
[
145+
(MODEL.DFT, SUBTYPE.GGA, FUNCTIONAL.PBE, EXPECTED_MODEL_BY_PARAMETERS_GGA),
146+
(MODEL.DFT, SUBTYPE.LDA, FUNCTIONAL.PZ, EXPECTED_MODEL_BY_PARAMETERS_LDA),
147+
(MODEL.DFT, SUBTYPE.GGA, None, EXPECTED_MODEL_BY_PARAMETERS_GGA),
148+
(MODEL.DFT, None, None, EXPECTED_MODEL_BY_PARAMETERS_GGA),
149+
(MODEL.DFT, None, FUNCTIONAL.PBE, EXPECTED_MODEL_BY_PARAMETERS_GGA),
150+
(MODEL.DFT, SUBTYPE.LDA, None, EXPECTED_MODEL_BY_PARAMETERS_LDA),
151+
(MODEL.INVALID, None, None, {}),
152+
(MODEL.DFT, SUBTYPE.INVALID, None, {"type": MODEL.DFT}),
153+
(MODEL.DFT, SUBTYPE.GGA, FUNCTIONAL.INVALID, EXPECTED_MODEL_BY_PARAMETERS_GGA),
154+
],
155+
)
156+
def test_get_model_by_parameters(type, subtype, functional, expected):
157+
result = ModelTreeStandata.get_model_by_parameters(type, subtype, functional)
158+
assert result == expected
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from types import SimpleNamespace
2+
3+
from mat3ra.standata.data.subworkflows import subworkflows_data
4+
from mat3ra.standata.subworkflows import SubworkflowStandata
5+
6+
APP = SimpleNamespace(ESPRESSO="espresso")
7+
SUBWORKFLOW = SimpleNamespace(
8+
SEARCH_NAME="pw_scf",
9+
FILENAME="espresso/pw_scf.json",
10+
EXACT_NAME="Preliminary SCF Calculation",
11+
)
12+
13+
14+
def test_get_by_name():
15+
subworkflow = SubworkflowStandata.get_by_name_first_match(SUBWORKFLOW.SEARCH_NAME)
16+
assert type(subworkflow) == dict
17+
assert "name" in subworkflow
18+
assert SUBWORKFLOW.EXACT_NAME in subworkflow["name"]
19+
20+
21+
def test_get_by_categories():
22+
subworkflows = SubworkflowStandata.get_by_categories(APP.ESPRESSO)
23+
assert isinstance(subworkflows, list)
24+
assert len(subworkflows) >= 1
25+
assert isinstance(subworkflows[0], dict)
26+
27+
28+
def test_get_subworkflow_data():
29+
subworkflow = subworkflows_data["filesMapByName"][SUBWORKFLOW.FILENAME]
30+
assert type(subworkflow) == dict
31+
assert "name" in subworkflow
32+
assert subworkflow["name"] == SUBWORKFLOW.EXACT_NAME
33+
34+
35+
def test_get_by_name_and_categories():
36+
subworkflow = SubworkflowStandata.get_by_name_and_categories(SUBWORKFLOW.SEARCH_NAME, APP.ESPRESSO)
37+
assert type(subworkflow) == dict
38+
assert "name" in subworkflow
39+
assert APP.ESPRESSO in str(subworkflow.get("application", {})).lower() or APP.ESPRESSO in str(subworkflow)
40+
41+
42+
def test_get_as_list():
43+
subworkflows_list = SubworkflowStandata.get_as_list()
44+
assert isinstance(subworkflows_list, list)
45+
assert len(subworkflows_list) >= 1
46+
assert isinstance(subworkflows_list[0], dict)
47+
assert "name" in subworkflows_list[0]
48+
49+
50+
def test_filter_by_application_and_get_by_name():
51+
subworkflow = SubworkflowStandata.filter_by_application(APP.ESPRESSO).get_by_name_first_match(
52+
SUBWORKFLOW.SEARCH_NAME)
53+
assert type(subworkflow) == dict
54+
assert "name" in subworkflow
55+
assert subworkflow["name"] == SUBWORKFLOW.EXACT_NAME
56+
assert APP.ESPRESSO in str(subworkflow.get("application", {})).lower()

0 commit comments

Comments
 (0)