Skip to content
This repository was archived by the owner on Jul 20, 2025. It is now read-only.

Commit 27fbbd1

Browse files
committed
Include classifier output in tests
1 parent 009a0bd commit 27fbbd1

File tree

2 files changed

+22
-18
lines changed

2 files changed

+22
-18
lines changed

mpds_ml_labs/test_app.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88

99
from mpds_client import MPDSDataRetrieval, APIError
1010

11-
from prediction import prop_semantics
11+
from prediction import prop_models
1212
from struct_utils import detect_format, poscar_to_ase, symmetrize, get_formula, sgn_to_crsystem
1313
from cif_utils import cif_to_ase
14+
from common import API_KEY, API_ENDPOINT
1415

1516

1617
req = httplib2.Http()
17-
client = MPDSDataRetrieval()
18+
client = MPDSDataRetrieval(api_key=API_KEY, endpoint=API_ENDPOINT)
1819

1920
def make_request(address, data={}, httpverb='POST', headers={}):
2021

@@ -59,29 +60,29 @@ def make_request(address, data={}, httpverb='POST', headers={}):
5960
raise RuntimeError(answer['error'])
6061

6162
formulae_categ, lattices_categ = get_formula(ase_obj), sgn_to_crsystem(ase_obj.info['spacegroup'].no)
62-
for prop_id, pdata in prop_semantics.items():
63+
for prop_id, pdata in prop_models.items():
6364
try:
6465
resp = client.get_dataframe({
6566
'formulae': formulae_categ,
6667
'lattices': lattices_categ,
6768
'props': pdata['name']
6869
})
6970
except APIError as e:
70-
prop_semantics[prop_id]['factual'] = None
71+
prop_models[prop_id]['factual'] = None
7172
if e.code == 1:
7273
continue
7374
else:
7475
raise
7576

7677
resp['Value'] = resp['Value'].astype('float64') # to treat values out of bounds given as str
7778
resp = resp[resp['Units'] == pdata['units']]
78-
prop_semantics[prop_id]['factual'] = np.median(resp['Value'])
79+
prop_models[prop_id]['factual'] = np.median(resp['Value'])
7980

8081
for prop_id, pdata in answer['prediction'].items():
8182
print("{0:40} = {1:6}, factual {2:8} (MAE = {3:4}), {4}".format(
82-
prop_semantics[prop_id]['name'],
83-
pdata['value'],
84-
prop_semantics[prop_id]['factual'] or 'absent',
83+
prop_models[prop_id]['name'],
84+
'conductor' if pdata['value'] == 0 and prop_id == 'w' else pdata['value'],
85+
prop_models[prop_id]['factual'] or 'absent',
8586
pdata['mae'],
86-
prop_semantics[prop_id]['units']
87+
prop_models[prop_id]['units']
8788
))

mpds_ml_labs/test_ml.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,27 @@
33

44
from struct_utils import detect_format, poscar_to_ase, symmetrize
55
from cif_utils import cif_to_ase
6-
from prediction import ase_to_ml_model, load_ml_model, prop_semantics
6+
from prediction import ase_to_prediction, load_ml_models, prop_models
77
from common import ML_MODELS, DATA_PATH
88

99

1010
models, structures = [], []
1111

1212
if sys.argv[1:]:
1313
inputs = [f for f in sys.argv[1:] if os.path.isfile(f)]
14-
models, structures = \
15-
[f for f in inputs if f.endswith('.pkl')], [f for f in inputs if not f.endswith('.pkl')]
14+
models, structures = [
15+
f for f in inputs if f.endswith('.pkl')
16+
], [
17+
f for f in inputs if not f.endswith('.pkl')
18+
]
1619

1720
if not models:
1821
models = ML_MODELS
1922

2023
if not structures:
21-
structures = [os.path.join(DATA_PATH, f) for f in os.listdir(DATA_PATH) if os.path.isfile(os.path.join(DATA_PATH, f))]
24+
structures = [os.path.join(DATA_PATH, f) for f in os.listdir(DATA_PATH) if os.path.isfile(os.path.join(DATA_PATH, f)) and 'settings.ini' not in f]
2225

23-
active_ml_model = load_ml_model(models)
26+
active_ml_models = load_ml_models(models)
2427

2528
for fname in structures:
2629
print
@@ -50,15 +53,15 @@
5053
print(error)
5154
continue
5255

53-
prediction, error = ase_to_ml_model(ase_obj, active_ml_model)
56+
prediction, error = ase_to_prediction(ase_obj, active_ml_models)
5457
if error:
5558
print(error)
5659
continue
5760

5861
for prop_id, pdata in prediction.items():
5962
print("{0:40} = {1:6} (MAE = {2:4}), {3}".format(
60-
prop_semantics[prop_id]['name'],
61-
pdata['value'],
63+
prop_models[prop_id]['name'],
64+
'conductor' if pdata['value'] == 0 and prop_id == 'w' else pdata['value'],
6265
pdata['mae'],
63-
prop_semantics[prop_id]['units']
66+
prop_models[prop_id]['units']
6467
))

0 commit comments

Comments
 (0)