This repository was archived by the owner on Jul 20, 2025. It is now read-only.

Commit 009a0bd

Refactoring and improving the README
1 parent b117c86 commit 009a0bd

6 files changed: +178 -101 lines changed

README.md

Lines changed: 16 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -11,12 +11,16 @@ Live demo
1111
Rationale
1212
------
1313

14-
This is the proof of concept, how a relatively unsophisticated statistical model (namely, _random forest regressor_) trained on the large MPDS dataset predicts a set of physical properties from the only crystalline structure. Similarly to _ab initio_, this method could be called _ab datum_. (Note however that the simulation of physical properties with a comparable precision normally takes days, weeks or even months, whereas the present prediction method takes less than a second.) A crystal structure in either CIF or POSCAR format is required. The following physical properties are predicted:
14+
This is a proof of concept showing how a relatively unsophisticated statistical model (namely, a _random forest regressor_) trained on the large MPDS dataset predicts a set of physical properties from the crystalline structure alone. By analogy with _ab initio_, this method could be called _ab datum_. (Note, however, that simulating physical properties with comparable precision normally takes days, weeks, or even months, whereas the present method takes less than a second!) A crystal structure in either CIF or POSCAR format is required. The following physical properties are predicted:
1515

1616
- isothermal bulk modulus
1717
- enthalpy of formation
1818
- heat capacity at constant pressure
1919
- melting temperature
20+
- Debye temperature
21+
- Seebeck coefficient
22+
- linear thermal expansion coefficient
23+
- band gap (or its absence, _i.e._ whether a crystal is a conductor or an insulator)
2024

2125
Installation
2226
------
@@ -34,24 +38,26 @@ Currently only *Python 2* is supported (*Python 3* support is coming).
3438
Preparation
3539
------
3640

37-
The model is trained on the MPDS data using the MPDS API and the script `ml_mpds.py`. Some subset of the full MPDS data is opened and possible to obtain via MPDS API [for free](https://mpds.io/open-data-api).
41+
The model is trained on the MPDS data using the MPDS API and the scripts `train_regressor.py` and `train_classifier.py`. A subset of the full MPDS data is open and can be obtained via the MPDS API [for free](https://mpds.io/open-data-api).
3842

3943
Architecture and usage
4044
------
4145

42-
This is the client-server application. The client is not required although, and it is possible to employ the server code as a standalone command-line application. The client is used for a convenience only. The client and the server communicate using HTTP. Any client able to execute HTTP requests is supported, be it a `curl` command-line client or rich web-browser user interface. As an example of the latter, a simple HTML5 app `index.html` is supplied. Server part is a Flask app, loading the pre-trained ML models:
46+
The code can be used either as a standalone command-line application or as a client-server application. In the latter case, the client and the server communicate over HTTP, and any client able to execute HTTP requests is supported, be it the `curl` command-line client or a rich web-browser user interface. As an example of the latter, a simple HTML5 app `index.html` is supplied in the `webassets` folder. The server part is a Flask app:
4347

4448
```python
45-
python index.py /tmp/path_to_model_one /tmp/path_to_model_two
49+
python mpds_ml_labs/app.py
4650
```
4751

48-
Web-browser user interface is then available under `http://localhost:5000`. To serve the requests the development Flask server is used. Therefore an _AS-IS_ deployment in an online environment without the suitable WSGI container is highly discouraged. Serving of the ML models is very simple. For the production environments under high load it is recommended to follow e.g. [TensorFlow Serving](https://www.tensorflow.org/serving).
52+
The web-browser user interface is then available at `http://localhost:5000`. By default, the development Flask server is used to serve the requests. Therefore an _AS-IS_ deployment in an online environment without a suitable WSGI container is **highly discouraged**. For production environments under high load it is recommended to use something like [TensorFlow Serving](https://www.tensorflow.org/serving).
4953

5054

5155
Used descriptor and model details
5256
------
5357

54-
The term _descriptor_ stands for the compact information-rich representation, allowing the convenient mathematical treatment of the encoded complex data (_i.e._ crystalline structure). Any crystalline structure is populated to a certain relatively big fixed volume of minimum one cubic nanometer. Then the descriptor is constructed using the periodic numbers of atoms and the lengths of their radius-vectors. The details are in the file `prediction.py`. As a machine-learning model an ensemble of decision trees ([random forest regressor](http://scikit-learn.org/stable/modules/ensemble.html)) is used, as implemented in [scikit-learn](http://scikit-learn.org) Python machine-learning toolkit. The whole MPDS dataset is used for training. In order to estimate the prediction quality, the metrics of _mean absolute error_ and _R2 coefficient of determination_ are used. The evaluation process is repeated at least 30 times to achieve a statistical reliability.
58+
The term _descriptor_ stands for a compact, information-rich representation that allows convenient mathematical treatment of the encoded complex data (_i.e._ the crystalline structure). Any crystalline structure is expanded to a relatively large fixed volume of at least one cubic nanometer. Then the descriptor is constructed using the periodic numbers of atoms and the lengths of their radius-vectors. The details are in the file `mpds_ml_labs/prediction.py`.
59+
60+
As the machine-learning model, an ensemble of decision trees ([random forest regressor](http://scikit-learn.org/stable/modules/ensemble.html)) is used, as implemented in the [scikit-learn](http://scikit-learn.org) Python machine-learning toolkit. The whole MPDS dataset can be used for training. To estimate the prediction quality of the _regressor_ model, the _mean absolute error_ and the _R2 coefficient of determination_ are used. To estimate the prediction quality of the _classifier_ model (binary case), the simple error percentage is used (`(false positives + false negatives)/all outcomes`). The evaluation process is repeated at least 30 times to achieve statistical reliability.
5561

5662
API
5763
------
@@ -91,5 +97,7 @@ License
9197
Citation
9298
------
9399

94-
Please feel free to cite:
95-
- Blokhin E, Villars P, PAULING FILE and MPDS materials data infrastructure, in preparation, 2018
100+
[![DOI](https://zenodo.org/badge/110734326.svg)](https://zenodo.org/badge/latestdoi/110734326)
101+
102+
Also please feel free to cite:
103+
- Blokhin E, Villars P, PAULING FILE and MPDS materials data infrastructure, in preparation, **2018**
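
Note on the quality metrics named in the README's model-details paragraph: written out in plain Python, they could look roughly like the sketch below. This is an illustration only, not the repository's code, which calls scikit-learn's `mean_absolute_error`, `r2_score` and `confusion_matrix`. The function names `mae`, `r2` and `error_fraction` are hypothetical.

```python
def mae(y_true, y_pred):
    """Mean absolute error: average of |target - prediction|."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / float(len(y_true))


def r2(y_true, y_pred):
    """Coefficient of determination, 1 - SS_res / SS_tot.

    Assumes y_true is not constant (otherwise SS_tot is zero).
    """
    mean = sum(y_true) / float(len(y_true))
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


def error_fraction(y_true, y_pred):
    """(false positives + false negatives) / all outcomes, for 0/1 labels."""
    wrong = sum(1 for t, p in zip(y_true, y_pred) if t != p)
    return wrong / float(len(y_true))
```

For binary labels, every mismatch is either a false positive or a false negative, so counting mismatches gives the same numerator as summing the two off-diagonal cells of the confusion matrix.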

data/settings.ini.sample

Lines changed: 2 additions & 0 deletions
@@ -3,3 +3,5 @@ serve_ui = true
33
ml_models =
44
/path_to_models/model_one.pkl
55
/path_to_models/model_two.pkl
6+
api_key =
7+
api_endpoint =

mpds_ml_labs/app.py

Lines changed: 5 additions & 5 deletions
@@ -7,13 +7,13 @@
77

88
from struct_utils import detect_format, poscar_to_ase, symmetrize, get_formula
99
from cif_utils import cif_to_ase, ase_to_eq_cif
10-
from prediction import ase_to_ml_model, get_legend, load_ml_model
10+
from prediction import ase_to_prediction, get_legend, load_ml_models
1111
from common import SERVE_UI, ML_MODELS
1212

1313

1414
app_labs = Blueprint('app_labs', __name__)
1515
static_path = os.path.realpath(os.path.join(os.path.dirname(__file__), '../webassets'))
16-
active_ml_model = None
16+
active_ml_models = None
1717

1818
def fmt_msg(msg, http_code=400):
1919
return Response('{"error":"%s"}' % msg, content_type='application/json', status=http_code)
@@ -85,7 +85,7 @@ def predict():
8585
if error:
8686
return fmt_msg(error)
8787

88-
prediction, error = ase_to_ml_model(ase_obj, active_ml_model)
88+
prediction, error = ase_to_prediction(ase_obj, active_ml_models)
8989
if error:
9090
return fmt_msg(error)
9191

@@ -110,11 +110,11 @@ def predict():
110110
if __name__ == '__main__':
111111
if sys.argv[1:]:
112112
print("Models to load:\n" + "\n".join(sys.argv[1:]))
113-
active_ml_model = load_ml_model(sys.argv[1:])
113+
active_ml_models = load_ml_models(sys.argv[1:])
114114

115115
elif ML_MODELS:
116116
print("Models to load:\n" + "\n".join(ML_MODELS))
117-
active_ml_model = load_ml_model(ML_MODELS)
117+
active_ml_models = load_ml_models(ML_MODELS)
118118

119119
else:
120120
print("No models to load")
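
Note on model file names: the paths taken from the command line or from `ML_MODELS` are passed to `load_ml_models` (in `mpds_ml_labs/prediction.py` in this commit), which infers the property from the file name. The sketch below isolates that rule so it can be checked on its own. The helper name `detect_prop_id`, the example paths and the property id `z` are hypothetical, and `os.path.basename` is used here in place of the original `file_name.split(os.sep)[-1]`.

```python
import os


def detect_prop_id(file_name, known_ids, position):
    """
    Apply the naming rule seen in load_ml_models: a basename such as
    "mlw_anything.pkl" maps to property "w" when "w" is a known id;
    otherwise the 1-based position in the list becomes the key.
    """
    basename = os.path.basename(file_name)
    if (basename.startswith('ml')
            and basename[3:4] == '_'
            and basename[2:3] in known_ids):
        return basename[2:3]
    return str(position)


known = {'w', 'z'}  # a set of one-letter property ids (illustrative)
detected = [
    detect_prop_id(path, known, n)
    for n, path in enumerate(['/tmp/mlw_model.pkl', '/tmp/model_one.pkl'], start=1)
]
```

Under this rule a file that does not follow the `ml<id>_` pattern is keyed by its position and later shown as an "Unspecified property" by `get_legend`.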

mpds_ml_labs/common.py

Lines changed: 11 additions & 4 deletions
@@ -9,11 +9,18 @@
99

1010
if os.path.exists(config_path):
1111
config.read(config_path)
12+
1213
SERVE_UI = config.get('mpds_ml_labs', 'serve_ui')
13-
ML_MODELS = [path.strip() for path in filter(
14-
None,
15-
config.get('mpds_ml_labs', 'ml_models').split()
16-
)]
14+
ML_MODELS = config.get('mpds_ml_labs', 'ml_models') or ''
15+
API_KEY = config.get('mpds_ml_labs', 'api_key')
16+
API_ENDPOINT = config.get('mpds_ml_labs', 'api_endpoint')
17+
18+
ML_MODELS = [
19+
path.strip() for path in filter(None, ML_MODELS.split())
20+
]
21+
1722
else:
1823
SERVE_UI = True
1924
ML_MODELS = []
25+
API_KEY = None
26+
API_ENDPOINT = None
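
Note on reading the settings: the following is a minimal sketch of how the sample file could be parsed, assuming Python 3's `configparser` (the diff does not show which parser module the repository imports) and assuming the model paths are indented as continuation lines. The `read_settings` helper and the `SAMPLE` text are hypothetical; the section name `mpds_ml_labs` is taken from `common.py`.

```python
import configparser

SAMPLE = """
[mpds_ml_labs]
serve_ui = true
ml_models =
    /path_to_models/model_one.pkl
    /path_to_models/model_two.pkl
api_key =
api_endpoint =
"""


def read_settings(text):
    config = configparser.ConfigParser()
    config.read_string(text)
    section = 'mpds_ml_labs'
    return {
        # getboolean() converts "true"/"false" into a real bool
        'serve_ui': config.getboolean(section, 'serve_ui', fallback=True),
        # a multi-line value is split on whitespace, one path per item
        'ml_models': config.get(section, 'ml_models', fallback='').split(),
        # empty or missing options become None
        'api_key': config.get(section, 'api_key', fallback='') or None,
        'api_endpoint': config.get(section, 'api_endpoint', fallback='') or None,
    }


settings = read_settings(SAMPLE)
```

Two design points may be worth considering. First, `fallback=` keeps older settings files without `api_key` or `api_endpoint` from raising an error. Second, plain `config.get` returns a string, so a value such as `false` would still be truthy if `SERVE_UI` were tested directly; how `SERVE_UI` is consumed is not visible in this diff.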

mpds_ml_labs/prediction.py

Lines changed: 123 additions & 24 deletions
@@ -5,12 +5,16 @@
55

66
import numpy as np
77

8+
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
9+
from sklearn.model_selection import train_test_split
10+
from sklearn.metrics import mean_absolute_error, r2_score, confusion_matrix
811

9-
prop_semantics = {
12+
13+
prop_models = {
1014
'w': {
11-
'name': 'band gap for direct transition',
15+
'name': 'band gap',
1216
'units': 'eV',
13-
'symbol': 'e<sub>dir.</sub>',
17+
'symbol': 'e<sub>dir. or indir.</sub>',
1418
'rounding': 1,
1519
'interval': [0.01, 20]
1620
},
@@ -124,35 +128,35 @@ def get_descriptor(ase_obj, kappa=None, overreach=False):
124128
return np.array(DV).flatten()
125129

126130

127-
def load_ml_model(prop_model_files):
128-
ml_model = {}
131+
def load_ml_models(prop_model_files):
132+
ml_models = {}
129133
for n, file_name in enumerate(prop_model_files, start=1):
130134
if not os.path.exists(file_name):
131135
print("No file %s" % file_name)
132136
continue
133137

134138
basename = file_name.split(os.sep)[-1]
135-
if basename.startswith('ml') and basename[3:4] == '_' and basename[2:3] in prop_semantics:
139+
if basename.startswith('ml') and basename[3:4] == '_' and basename[2:3] in prop_models:
136140
prop_id = basename[2:3]
137-
print("Detected property %s in file %s" % (prop_semantics[prop_id]['name'], basename))
141+
print("Detected property %s in file %s" % (prop_models[prop_id]['name'], basename))
138142
else:
139143
prop_id = str(n)
140144
print("No property name detected in file %s" % basename)
141145

142146
with open(file_name, 'rb') as f:
143147
model = cPickle.load(f)
144148
if hasattr(model, 'predict') and hasattr(model, 'metadata'):
145-
ml_model[prop_id] = model
149+
ml_models[prop_id] = model
146150
print("Model metadata: %s" % model.metadata)
147151

148-
print("Loaded property models: %s" % len(ml_model))
149-
return ml_model
152+
print("Loaded property models: %s" % len(ml_models))
153+
return ml_models
150154

151155

152156
def get_legend(pred_dict):
153157
legend = {}
154158
for key in pred_dict.keys():
155-
legend[key] = prop_semantics.get(key, {
159+
legend[key] = prop_models.get(key, {
156160
'name': 'Unspecified property ' + str(key),
157161
'units': 'arb.u.',
158162
'symbol': 'P' + str(key),
@@ -161,32 +165,127 @@ def get_legend(pred_dict):
161165
return legend
162166

163167

164-
def ase_to_ml_model(ase_obj, ml_model):
168+
def ase_to_prediction(ase_obj, ml_models):
169+
"""
170+
Execute all the regressor models against a given structure descriptor;
171+
the results of the "w" regressor model will depend on the
172+
output of the binary classifier model
173+
"""
165174
result = {}
166175
descriptor = get_descriptor(ase_obj, overreach=True)
167176
d_dim = len(descriptor)
177+
should_invoke_clfr = 'w' in prop_models.keys()
178+
179+
# testing
180+
if not ml_models:
181+
result = {prop_id: {'value': 42, 'mae': 0, 'r2': 0} for prop_id in prop_models.keys()}
168182

169-
if not ml_model: # testing
170-
return {prop_id: {'value': 42, 'mae': 0, 'r2': 0} for prop_id in prop_semantics.keys()}, None
183+
if should_invoke_clfr:
184+
result['w'] = {'value': 0, 'mae': 0, 'r2': 0}
171185

172-
for prop_id, regr in ml_model.items(): # production
186+
# production
187+
for prop_id, model in ml_models.items():
173188

174-
if d_dim < regr.n_features_:
189+
if d_dim < model.n_features_:
175190
continue
176-
elif d_dim > regr.n_features_:
177-
d_input = descriptor[:regr.n_features_]
191+
elif d_dim > model.n_features_:
192+
d_input = descriptor[:model.n_features_]
178193
else:
179194
d_input = descriptor[:]
180195

181196
try:
182-
prediction = regr.predict([d_input])[0]
197+
prediction = model.predict([d_input])[0]
183198
except Exception as e:
184199
return None, str(e)
185200

186-
result[prop_id] = {
187-
'value': round(prediction, prop_semantics[prop_id]['rounding']),
188-
'mae': round(regr.metadata['mae'], prop_semantics[prop_id]['rounding']),
189-
'r2': regr.metadata['r2']
190-
}
201+
# classifier
202+
if model.metadata.get('error_percentage'):
203+
204+
if should_invoke_clfr:
205+
206+
if prediction == 0:
207+
result['w'] = {'value': 0, 'mae': 0, 'r2': 0}
208+
209+
# regressor
210+
else:
211+
if prop_id not in prop_models or \
212+
(prop_id == 'w' and prop_id in result):
213+
continue
214+
215+
result[prop_id] = {
216+
'value': round(prediction, prop_models[prop_id]['rounding']),
217+
'mae': round(model.metadata['mae'], prop_models[prop_id]['rounding']),
218+
'r2': model.metadata['r2']
219+
}
191220

192221
return result, None
222+
223+
224+
def get_regr(a=None, b=None):
225+
226+
if not a: a = 100
227+
if not b: b = 2
228+
229+
return RandomForestRegressor(
230+
n_estimators=a,
231+
max_features=b,
232+
max_depth=None,
233+
min_samples_split=2, # recommended value
234+
min_samples_leaf=5, # recommended value
235+
bootstrap=True, # recommended value
236+
n_jobs=-1
237+
)
238+
239+
240+
def get_clfr(a=None, b=None):
241+
242+
if not a: a = 100
243+
if not b: b = 2
244+
245+
return RandomForestClassifier(
246+
n_estimators=a,
247+
max_features=b,
248+
max_depth=None,
249+
min_samples_split=2, # recommended value
250+
min_samples_leaf=5, # recommended value
251+
bootstrap=True, # recommended value
252+
n_jobs=-1
253+
)
254+
255+
256+
def estimate_regr_quality(algo, args, values, attempts=30, nsamples=0.33):
257+
258+
results = []
259+
260+
for _ in range(attempts):
261+
X_train, X_test, y_train, y_test = train_test_split(args, values, test_size=nsamples)
262+
algo.fit(X_train, y_train)
263+
264+
prediction = algo.predict(X_test)
265+
266+
mae = mean_absolute_error(y_test, prediction)
267+
r2 = r2_score(y_test, prediction)
268+
results.append([mae, r2])
269+
270+
results = list(map(list, zip(*results))) # transpose
271+
272+
avg_mae = np.median(results[0])
273+
avg_r2 = np.median(results[1])
274+
return avg_mae, avg_r2
275+
276+
277+
def estimate_clfr_quality(algo, args, values, attempts=30, nsamples=0.33):
278+
279+
results = []
280+
281+
for _ in range(attempts):
282+
X_train, X_test, y_train, y_test = train_test_split(args, values, test_size=nsamples)
283+
algo.fit(X_train, y_train)
284+
285+
prediction = algo.predict(X_test)
286+
287+
tn, fp, fn, tp = confusion_matrix(y_test, prediction).ravel()
288+
error_percentage = float(fp + fn) / (tn + fp + fn + tp)  # float() guards against integer division under Python 2
289+
results.append(error_percentage)
290+
291+
return np.median(results)
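
Note on the evaluation helpers: `estimate_regr_quality` and `estimate_clfr_quality` share one pattern: repeat a random train/test split, score each run, and report the median. The dependency-free sketch below shows that pattern in isolation, assuming Python 3. `MeanPredictor` is a hypothetical stand-in for the random forest, and `estimate_quality`, `test_share` and `seed` are illustrative names that do not appear in the commit.

```python
import random
from statistics import median


class MeanPredictor(object):
    """Hypothetical stand-in model: always predicts the training mean."""

    def fit(self, X, y):
        self.mean_ = sum(y) / float(len(y))
        return self

    def predict(self, X):
        return [self.mean_ for _ in X]


def estimate_quality(algo, args, values, attempts=30, test_share=0.33, seed=0):
    """Median of the mean absolute error over repeated random splits."""
    rng = random.Random(seed)  # fixed seed keeps the sketch reproducible
    indices = list(range(len(args)))
    n_test = max(1, int(round(len(indices) * test_share)))
    scores = []
    for _ in range(attempts):
        rng.shuffle(indices)
        test, train = indices[:n_test], indices[n_test:]
        algo.fit([args[i] for i in train], [values[i] for i in train])
        predicted = algo.predict([args[i] for i in test])
        errors = [abs(values[i] - p) for i, p in zip(test, predicted)]
        scores.append(sum(errors) / float(len(errors)))
    return median(scores)


X = [[float(i)] for i in range(20)]
constant_y = [5.0] * 20  # the mean predictor is exact on a constant target
exact = estimate_quality(MeanPredictor(), X, constant_y)
```

The median is less sensitive than the mean to an occasional unlucky split. Note that the variables in the commit are named `avg_mae` and `avg_r2` although they hold medians.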
