|
| 1 | +from enum import Enum |
1 | 2 | from types import SimpleNamespace |
2 | 3 |
|
3 | 4 | import pytest |
|
10 | 11 | VERSION = SimpleNamespace(V6_0="6.0", V1_0="1.0") |
11 | 12 | PSEUDOPOTENTIAL_TYPE = SimpleNamespace(PAW="paw", NC="nc", NC_FR="nc-fr", US="us") |
12 | 13 | MODEL_TYPE = SimpleNamespace(DFT="dft") |
| 14 | +SUBTYPE = SimpleNamespace(GGA="gga", LDA="lda", HYBRID="hybrid", OTHER="other", INVALID="invalid_subtype") |
| 15 | +FUNCTIONAL = SimpleNamespace( |
| 16 | + PBE="pbe", |
| 17 | + PBESOL="pbesol", |
| 18 | + PW91="pw91", |
| 19 | + PZ="pz", |
| 20 | + PW="pw", |
| 21 | + VWN="vwn", |
| 22 | + OTHER="other", |
| 23 | + INVALID="invalid_functional", |
| 24 | +) |
| 25 | +GGA_FUNCTIONALS = {"PBE": FUNCTIONAL.PBE, "PBESOL": FUNCTIONAL.PBESOL, "PW91": FUNCTIONAL.PW91, |
| 26 | + "OTHER": FUNCTIONAL.OTHER} |
| 27 | +LDA_FUNCTIONALS = {"PZ": FUNCTIONAL.PZ, "PW": FUNCTIONAL.PW, "VWN": FUNCTIONAL.VWN, "OTHER": FUNCTIONAL.OTHER} |
| 28 | +EXPECTED_MODEL_BY_PARAMETERS_GGA = {"type": MODEL.DFT, "subtype": SUBTYPE.GGA, "functional": FUNCTIONAL.PBE} |
| 29 | +EXPECTED_MODEL_BY_PARAMETERS_LDA = {"type": MODEL.DFT, "subtype": SUBTYPE.LDA, "functional": FUNCTIONAL.PZ} |
13 | 30 |
|
14 | 31 |
|
15 | 32 | @pytest.mark.parametrize( |
@@ -77,3 +94,65 @@ def test_get_default_model_type_for_application(application, expected): |
77 | 94 | assert result == expected |
78 | 95 | else: |
79 | 96 | assert result is None |
| 97 | + |
| 98 | + |
| 99 | +@pytest.mark.parametrize( |
| 100 | + "model_type,expected_subtypes", |
| 101 | + [ |
| 102 | + ("dft", {"GGA": "gga", "LDA": "lda", "HYBRID": "hybrid", "OTHER": "other"}), |
| 103 | + ("invalid_model", {}), |
| 104 | + ], |
| 105 | +) |
| 106 | +def test_get_subtypes_by_model_type(model_type, expected_subtypes): |
| 107 | + subtypes = ModelTreeStandata.get_subtypes_by_model_type(model_type) |
| 108 | + assert issubclass(subtypes, Enum) |
| 109 | + assert len(list(subtypes)) == len(expected_subtypes) |
| 110 | + for enum_name, expected_value in expected_subtypes.items(): |
| 111 | + assert hasattr(subtypes, enum_name) |
| 112 | + assert getattr(subtypes, enum_name).value == expected_value |
| 113 | + |
| 114 | + |
| 115 | +@pytest.mark.parametrize( |
| 116 | + "model_type,subtype_input,use_string,expected_functionals,excluded_functionals", |
| 117 | + [ |
| 118 | + (MODEL.DFT, SUBTYPE.LDA, False, LDA_FUNCTIONALS, [FUNCTIONAL.PBE]), |
| 119 | + (MODEL.DFT, SUBTYPE.GGA, False, GGA_FUNCTIONALS, [FUNCTIONAL.PZ]), |
| 120 | + (MODEL.DFT, SUBTYPE.LDA, True, LDA_FUNCTIONALS, [FUNCTIONAL.PBE]), |
| 121 | + ], |
| 122 | +) |
| 123 | +def test_get_functionals_by_subtype(model_type, subtype_input, use_string, expected_functionals, excluded_functionals): |
| 124 | + if use_string: |
| 125 | + subtype_arg = subtype_input |
| 126 | + else: |
| 127 | + subtypes = ModelTreeStandata.get_subtypes_by_model_type(model_type) |
| 128 | + subtype_arg = getattr(subtypes, subtype_input.upper()) |
| 129 | + |
| 130 | + functionals = ModelTreeStandata.get_functionals_by_subtype(model_type, subtype_arg) |
| 131 | + assert issubclass(functionals, Enum) |
| 132 | + |
| 133 | + for enum_name, expected_value in expected_functionals.items(): |
| 134 | + assert hasattr(functionals, enum_name) |
| 135 | + assert getattr(functionals, enum_name).value == expected_value |
| 136 | + |
| 137 | + functional_values = [f.value for f in functionals] |
| 138 | + for excluded in excluded_functionals: |
| 139 | + assert excluded not in functional_values |
| 140 | + |
| 141 | + |
| 142 | +@pytest.mark.parametrize( |
| 143 | + "type,subtype,functional,expected", |
| 144 | + [ |
| 145 | + (MODEL.DFT, SUBTYPE.GGA, FUNCTIONAL.PBE, EXPECTED_MODEL_BY_PARAMETERS_GGA), |
| 146 | + (MODEL.DFT, SUBTYPE.LDA, FUNCTIONAL.PZ, EXPECTED_MODEL_BY_PARAMETERS_LDA), |
| 147 | + (MODEL.DFT, SUBTYPE.GGA, None, EXPECTED_MODEL_BY_PARAMETERS_GGA), |
| 148 | + (MODEL.DFT, None, None, EXPECTED_MODEL_BY_PARAMETERS_GGA), |
| 149 | + (MODEL.DFT, None, FUNCTIONAL.PBE, EXPECTED_MODEL_BY_PARAMETERS_GGA), |
| 150 | + (MODEL.DFT, SUBTYPE.LDA, None, EXPECTED_MODEL_BY_PARAMETERS_LDA), |
| 151 | + (MODEL.INVALID, None, None, {}), |
| 152 | + (MODEL.DFT, SUBTYPE.INVALID, None, {"type": MODEL.DFT}), |
| 153 | + (MODEL.DFT, SUBTYPE.GGA, FUNCTIONAL.INVALID, EXPECTED_MODEL_BY_PARAMETERS_GGA), |
| 154 | + ], |
| 155 | +) |
| 156 | +def test_get_model_by_parameters(type, subtype, functional, expected): |
| 157 | + result = ModelTreeStandata.get_model_by_parameters(type, subtype, functional) |
| 158 | + assert result == expected |
0 commit comments