Skip to content

Commit da39a6b

Browse files
authored
Merge pull request #138 from Exabyte-io/feature/SOF-7859
Feature/sof-7859 Points Path + PW Cutoff Providers
2 parents eafed1c + f2f6afd commit da39a6b

File tree

5 files changed

+127
-1
lines changed

5 files changed

+127
-1
lines changed
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from .planewave_cutoffs_context_provider import PlanewaveCutoffsContextProvider
12
from .points_grid_data_provider import PointsGridDataProvider
3+
from .points_path_data_provider import PointsPathDataProvider
24

3-
__all__ = ["PointsGridDataProvider"]
5+
__all__ = ["PlanewaveCutoffsContextProvider", "PointsGridDataProvider", "PointsPathDataProvider"]
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from typing import Any, Dict, Optional
2+
3+
from mat3ra.ade.context.context_provider import ContextProvider
4+
from mat3ra.esse.models.context_providers_directory.planewave_cutoffs_context_provider import (
5+
PlanewaveCutoffsContextProviderSchema,
6+
)
7+
from pydantic import Field
8+
9+
10+
class PlanewaveCutoffsContextProvider(PlanewaveCutoffsContextProviderSchema, ContextProvider):
11+
name: str = Field(default="cutoffs")
12+
wavefunction: Optional[float] = None
13+
density: Optional[float] = None
14+
15+
@property
16+
def is_edited_key(self) -> str:
17+
return "isCutoffsEdited"
18+
19+
@property
20+
def default_data(self) -> Dict[str, Any]:
21+
return {"wavefunction": self.wavefunction, "density": self.density}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from typing import Any, Dict, List
2+
3+
from mat3ra.ade.context.context_provider import ContextProvider
4+
from mat3ra.esse.models.context_providers_directory.points_path_data_provider import (
5+
PointsPathDataProviderSchemaItem,
6+
)
7+
from pydantic import Field
8+
9+
10+
class PointsPathDataProvider(ContextProvider):
11+
"""
12+
Context provider for k-path configuration in band structure calculations.
13+
14+
The 'point' is a high-symmetry label (e.g. "Γ", "K", "M") resolved to coordinates at render time.
15+
"""
16+
17+
name: str = Field(default="kpath")
18+
path: List[PointsPathDataProviderSchemaItem] = Field(default_factory=list)
19+
20+
@property
21+
def is_edited_key(self) -> str:
22+
return "isKpathEdited"
23+
24+
@property
25+
def default_data(self) -> List[Dict[str, Any]]:
26+
return [item.model_dump(exclude_none=True) for item in self.path]
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import pytest
2+
from mat3ra.wode.context.providers import PlanewaveCutoffsContextProvider
3+
4+
ECUTWFC = 50.0
5+
ECUTRHO = 200.0
6+
7+
CUTOFFS_DATA = {
8+
"cutoffs": {"wavefunction": ECUTWFC, "density": ECUTRHO},
9+
"isCutoffsEdited": True,
10+
}
11+
12+
13+
@pytest.mark.parametrize(
14+
"init_params,expected_wavefunction,expected_density",
15+
[
16+
({"wavefunction": ECUTWFC, "density": ECUTRHO}, ECUTWFC, ECUTRHO),
17+
({"wavefunction": ECUTWFC}, ECUTWFC, None),
18+
],
19+
)
20+
def test_planewave_cutoffs_context_provider_initialization(init_params, expected_wavefunction, expected_density):
21+
provider = PlanewaveCutoffsContextProvider(**init_params)
22+
assert provider.wavefunction == expected_wavefunction
23+
assert provider.density == expected_density
24+
25+
26+
@pytest.mark.parametrize(
27+
"init_params,expected_data",
28+
[
29+
({"wavefunction": ECUTWFC, "density": ECUTRHO, "is_edited": True}, CUTOFFS_DATA),
30+
],
31+
)
32+
def test_planewave_cutoffs_context_provider_yield_data(init_params, expected_data):
33+
provider = PlanewaveCutoffsContextProvider(**init_params)
34+
assert provider.yield_data() == expected_data
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import pytest
2+
from mat3ra.wode.context.providers import PointsPathDataProvider
3+
4+
KPATH_SINGLE = [{"point": "K", "steps": 20}]
5+
KPATH_FULL = [
6+
{"point": "K", "steps": 20},
7+
{"point": "Г", "steps": 20},
8+
{"point": "M", "steps": 20},
9+
{"point": "K", "steps": 1},
10+
]
11+
12+
KPATH_DATA_SINGLE = {
13+
"kpath": KPATH_SINGLE,
14+
"isKpathEdited": True,
15+
}
16+
KPATH_DATA_FULL = {
17+
"kpath": KPATH_FULL,
18+
"isKpathEdited": True,
19+
}
20+
21+
22+
@pytest.mark.parametrize(
23+
"init_params,expected_path",
24+
[
25+
({"path": KPATH_SINGLE}, KPATH_SINGLE),
26+
({"path": KPATH_FULL}, KPATH_FULL),
27+
],
28+
)
29+
def test_points_path_data_provider_initialization(init_params, expected_path):
30+
provider = PointsPathDataProvider(**init_params)
31+
assert provider.get_data() == expected_path
32+
33+
34+
@pytest.mark.parametrize(
35+
"init_params,expected_data",
36+
[
37+
({"path": KPATH_SINGLE, "is_edited": True}, KPATH_DATA_SINGLE),
38+
({"path": KPATH_FULL, "is_edited": True}, KPATH_DATA_FULL),
39+
],
40+
)
41+
def test_points_path_data_provider_yield_data(init_params, expected_data):
42+
provider = PointsPathDataProvider(**init_params)
43+
assert provider.yield_data() == expected_data

0 commit comments

Comments
 (0)