Skip to content

Commit a8210aa

Browse files
Conversion electrode + trajectory tweaks (#1041)
2 parents 76fa686 + ca4fd36 commit a8210aa

File tree

4 files changed

+22
-8
lines changed

4 files changed

+22
-8
lines changed

mp_api/client/routes/materials/electrodes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,4 +173,5 @@ class ConversionElectrodeRester(BaseElectrodeRester):
173173
"elements",
174174
"stability_charge",
175175
"stability_discharge",
176+
"exclude_elements",
176177
]

mp_api/client/routes/materials/tasks.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
from datetime import datetime
44
from typing import TYPE_CHECKING
55

6+
from emmet.core.mpid import MPID, AlphaID
67
from emmet.core.tasks import CoreTaskDoc
78

89
from mp_api.client.core import BaseRester, MPRestError
910
from mp_api.client.core.utils import validate_ids
1011

1112
if TYPE_CHECKING:
13+
from typing import Any
14+
1215
from pydantic import BaseModel
1316

1417

@@ -17,17 +20,21 @@ class TaskRester(BaseRester):
1720
document_model: type[BaseModel] = CoreTaskDoc # type: ignore
1821
primary_key: str = "task_id"
1922

20-
def get_trajectory(self, task_id):
23+
def get_trajectory(self, task_id: MPID | AlphaID | str) -> list[dict[str, Any]]:
2124
"""Returns a Trajectory object containing the geometry of the
2225
material throughout a calculation. This is most useful for
2326
observing how a material relaxes during a geometry optimization.
2427
2528
Args:
26-
task_id (str): Task ID
29+
task_id (str, MPID, AlphaID): Task ID
2730
31+
Returns:
32+
list of dict representing emmet.core.trajectory.Trajectory
2833
"""
2934
traj_data = self._query_resource_data(
30-
{"task_ids": [task_id]}, suburl="trajectory/", use_document_model=False
35+
{"task_ids": [AlphaID(task_id).string]},
36+
suburl="trajectory/",
37+
use_document_model=False,
3138
)[0].get(
3239
"trajectories", None
3340
) # type: ignore

tests/materials/test_electrodes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ def test_conversion_client(conversion_rester):
7777
custom_field_tests={
7878
"battery_ids": ["mp-1067_Al"],
7979
"working_ion": Element("Li"),
80-
"exclude_elements": ["Co", "O"],
8180
},
8281
sub_doc_fields=sub_doc_fields,
8382
)

tests/materials/test_tasks.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from core_function import client_search_testing
33
import pytest
44

5+
from emmet.core.mpid import MPID, AlphaID
6+
from emmet.core.trajectory import Trajectory
57
from emmet.core.utils import utcnow
68
from mp_api.client.routes.materials.tasks import TaskRester
79

@@ -52,8 +54,13 @@ def test_client(rester):
5254
)
5355

5456

55-
def test_get_trajectories(rester):
56-
trajectories = [traj for traj in rester.get_trajectory("mp-149")]
57+
@pytest.mark.parametrize("mpid", ["mp-149", MPID("mp-149"), AlphaID("mp-149")])
58+
def test_get_trajectories(rester, mpid):
59+
trajectories = [traj for traj in rester.get_trajectory(mpid)]
5760

58-
for traj in trajectories:
59-
assert ("@module", "pymatgen.core.trajectory") in traj.items()
61+
expected_model_fields = {
62+
field_name
63+
for field_name, field in Trajectory.model_fields.items()
64+
if not field.exclude
65+
}
66+
assert all(set(traj) == expected_model_fields for traj in trajectories)

0 commit comments

Comments
 (0)