Skip to content
This repository was archived by the owner on Jul 20, 2025. It is now read-only.

Commit 355a58e

Browse files
committed
refactor web-app, present settings and another test script
1 parent e571e42 commit 355a58e

File tree

6 files changed

+155
-37
lines changed

6 files changed

+155
-37
lines changed

data/settings.ini.sample

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[mpds_ml_labs]
2+
serve_ui = true
3+
ml_models =
4+
/path_to_models/model_one.pkl
5+
/path_to_models/model_two.pkl

mpds_ml_labs/app.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55

66
from flask import Flask, Blueprint, Response, request, send_from_directory
77

8-
from cors import crossdomain
98
from struct_utils import detect_format, poscar_to_ase, symmetrize, get_formula
109
from cif_utils import cif_to_ase, ase_to_eq_cif
1110
from prediction import ase_to_ml_model, get_legend, load_ml_model
11+
from common import SERVE_UI, ML_MODELS
1212

1313

1414
app_labs = Blueprint('app_labs', __name__)
15-
ml_model = None
15+
static_path = os.path.realpath(os.path.join(os.path.dirname(__file__), '../webassets'))
16+
active_ml_model = None
1617

1718
def fmt_msg(msg, http_code=400):
    """Build a JSON error response of the form {"error":"<msg>"}.

    The payload is produced by the JSON serializer instead of plain string
    interpolation, so messages containing quotes, backslashes or control
    characters still yield a valid JSON document.

    :param msg: error description; it is converted to text before encoding
    :param http_code: HTTP status code of the response (400 by default)
    :return: a flask ``Response`` with the ``application/json`` content type
    """
    # Imported locally so the block does not depend on a file-level import
    import json

    # Compact separators keep the output identical to the former hand-built
    # string for messages that need no escaping
    payload = json.dumps({'error': '%s' % (msg,)}, separators=(',', ':'))
    return Response(payload, content_type='application/json', status=http_code)
@@ -37,20 +38,23 @@ def html_formula(string):
3738
if sub: html_formula += '</sub>'
3839
return html_formula
3940

40-
@app_labs.route('/', methods=['GET'])
41-
def index():
42-
return send_from_directory(os.path.dirname(__file__), 'index.html')
43-
44-
@app_labs.route('/index.css', methods=['GET'])
45-
def style():
46-
return send_from_directory(os.path.dirname(__file__), 'index.css')
47-
48-
@app_labs.route('/player.html', methods=['GET'])
49-
def player():
50-
return send_from_directory(os.path.dirname(__file__), 'player.html')
41+
if SERVE_UI:
42+
@app_labs.route('/', methods=['GET'])
43+
def index():
44+
return send_from_directory(static_path, 'index.html')
45+
@app_labs.route('/index.css', methods=['GET'])
46+
def style():
47+
return send_from_directory(static_path, 'index.css')
48+
@app_labs.route('/player.html', methods=['GET'])
49+
def player():
50+
return send_from_directory(static_path, 'player.html')
51+
52+
@app_labs.after_request
53+
def add_cors_header(response):
    """Allow cross-origin access by setting the CORS header on a response.

    The ``Access-Control-Allow-Origin`` header of the given response is set
    to ``*`` and the same response object is handed back.
    """
    header_name, allowed_origin = 'Access-Control-Allow-Origin', '*'
    response.headers[header_name] = allowed_origin
    return response
5156

5257
@app_labs.route("/predict", methods=['POST'])
53-
@crossdomain(origin='*')
5458
def predict():
5559
if 'structure' not in request.values:
5660
return fmt_msg('Invalid request')
@@ -81,7 +85,7 @@ def predict():
8185
if error:
8286
return fmt_msg(error)
8387

84-
prediction, error = ase_to_ml_model(ase_obj, ml_model)
88+
prediction, error = ase_to_ml_model(ase_obj, active_ml_model)
8589
if error:
8690
return fmt_msg(error)
8791

@@ -105,10 +109,15 @@ def predict():
105109

106110
if __name__ == '__main__':
107111
if sys.argv[1:]:
108-
ml_model = load_ml_model(sys.argv[1:])
109-
print("Loaded models: " + " ".join(sys.argv[1:]))
112+
print("Models to load:\n" + "\n".join(sys.argv[1:]))
113+
active_ml_model = load_ml_model(sys.argv[1:])
114+
115+
elif ML_MODELS:
116+
print("Models to load:\n" + "\n".join(ML_MODELS))
117+
active_ml_model = load_ml_model(ML_MODELS)
118+
110119
else:
111-
print("No model loaded")
120+
print("No models to load")
112121

113122
app = Flask(__name__)
114123
app.debug = False

mpds_ml_labs/cif_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def cif_to_ase(cif_string):
4545
else:
4646
return None, 'Absent space group info in CIF'
4747

48+
disordered = False
4849
try:
4950
cellpar = (
5051
float( parsed_cif['_cell_length_a'][0].split('(')[0] ),
@@ -61,9 +62,13 @@ def cif_to_ase(cif_string):
6162
[ char.split('(')[0] for char in parsed_cif['_atom_site_fract_z'] ]
6263
]).astype(np.float)
6364
)
65+
disordered = any([float(occ) != 1 for occ in parsed_cif.get('_atom_site_occupancy', [])])
6466
except:
6567
return None, 'Unexpected non-numerical values occured in CIF'
6668

69+
if disordered:
70+
return None, 'This is disordered structure (not yet supported)'
71+
6772
symbols = parsed_cif.get('_atom_site_type_symbol')
6873
if not symbols:
6974
symbols = parsed_cif.get('_atom_site_label')

mpds_ml_labs/common.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
2+
import os

try:
    from ConfigParser import ConfigParser  # Python 2 module name
except ImportError:
    from configparser import ConfigParser  # Python 3 module name


DATA_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), '../data'))
config = ConfigParser()
config_path = os.path.join(DATA_PATH, 'settings.ini')


def read_settings(parser, path, section='mpds_ml_labs'):
    """Read the ``serve_ui`` and ``ml_models`` settings from an ini file.

    :param parser: a ``ConfigParser`` instance that the file is read into
    :param path: location of the ini file
    :param section: name of the section holding the two options
    :return: tuple ``(serve_ui, ml_models)``: a boolean and a list of paths;
        ``(True, [])`` when the file does not exist, and the respective
        default for an option that is absent from the file
    :raises ValueError: if ``serve_ui`` is not a recognised boolean literal
    """
    serve_ui, ml_models = True, []
    if not os.path.exists(path):
        return serve_ui, ml_models

    parser.read(path)

    if parser.has_option(section, 'serve_ui'):
        # A plain get() returns text, and the string "false" is truthy,
        # so the flag has to be parsed as a boolean explicitly
        serve_ui = parser.getboolean(section, 'serve_ui')

    if parser.has_option(section, 'ml_models'):
        # split() without arguments already trims every item
        # and discards the empty ones
        ml_models = parser.get(section, 'ml_models').split()

    return serve_ui, ml_models


SERVE_UI, ML_MODELS = read_settings(config, config_path)

mpds_ml_labs/struct_utils.py

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11

22
import re
3+
import fractions
34
import cStringIO
45

56
from ase.atoms import Atoms
@@ -87,30 +88,45 @@ def symmetrize(ase_obj, accuracy=1E-03):
8788
return None, 'Unrecognized sites or invalid site symmetry in structure'
8889

8990

90-
FORMULA_SEQUENCE = ['Fr','Cs','Rb','K','Na','Li', 'Be','Mg','Ca','Sr','Ba','Ra', 'Sc','Y','La','Ce','Pr','Nd','Pm','Sm','Eu','Gd','Tb','Dy','Ho','Er','Tm','Yb', 'Ac','Th','Pa','U','Np','Pu', 'Ti','Zr','Hf', 'V','Nb','Ta', 'Cr','Mo','W', 'Fe','Ru','Os', 'Co','Rh','Ir', 'Mn','Tc','Re', 'Ni','Pd','Pt', 'Cu','Ag','Au', 'Zn','Cd','Hg', 'B','Al','Ga','In','Tl', 'Pb','Sn','Ge','Si','C', 'N','P','As','Sb','Bi', 'H', 'Po','Te','Se','S','O', 'At','I','Br','Cl','F', 'He','Ne','Ar','Kr','Xe','Rn']


def get_formula(ase_obj, find_gcd=True):
    """Compose a chemical formula string from the symbols of a structure.

    Elements are written in the order given by FORMULA_SEQUENCE; symbols
    absent from that list follow afterwards, sorted alphabetically so the
    result is deterministic. A content of one is omitted from the string.

    :param ase_obj: object providing ``get_chemical_symbols()``, which
        returns one symbol per atom
    :param find_gcd: if true, all element contents are divided by their
        greatest common divisor
    :return: the formula string; an empty string for a structure
        without atoms
    """
    # Local imports keep the block independent of the file-level imports
    from collections import Counter
    from functools import reduce
    try:
        from math import gcd
    except ImportError:  # Python 2 has no math.gcd
        from fractions import gcd

    parsed_formula = Counter(ase_obj.get_chemical_symbols())
    if not parsed_formula:
        # reduce() over an empty sequence would raise a TypeError
        return ''

    if find_gcd:
        divisor = reduce(gcd, parsed_formula.values())
        if divisor > 1:
            # Floor division is exact here, since the divisor is
            # a common divisor of every content
            parsed_formula = {el: content // divisor
                for el, content in parsed_formula.items()}

    known = [el for el in FORMULA_SEQUENCE if el in parsed_formula]
    unknown = sorted(el for el in parsed_formula if el not in FORMULA_SEQUENCE)

    return ''.join(
        el + ('' if parsed_formula[el] == 1 else str(parsed_formula[el]))
        for el in known + unknown
    )
116+
117+
118+
def sgn_to_crsystem(number):
    """Translate a space group number into the name of a crystal system.

    :param number: space group number
    :return: 'cubic', 'hexagonal', 'trigonal', 'tetragonal', 'orthorhombic'
        or 'monoclinic' for the inclusive ranges listed below; every other
        value, including those outside 1..230, yields 'triclinic'
    """
    boundaries = (
        (195, 230, 'cubic'),
        (168, 194, 'hexagonal'),
        (143, 167, 'trigonal'),
        (75, 142, 'tetragonal'),
        (16, 74, 'orthorhombic'),
        (3, 15, 'monoclinic'),
    )
    for lowest, highest, crystal_system in boundaries:
        if lowest <= number <= highest:
            return crystal_system

    return 'triclinic'

mpds_ml_labs/test_ml.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
2+
# Command-line check of the ML models: every given (or default) structure
# file is parsed, symmetrized and passed to the loaded models, and the
# predicted properties are printed.
#
# Arguments are file paths: those ending with ".pkl" are taken as models,
# all other existing files as structures. Without models the ML_MODELS
# setting is used; without structures all files found in DATA_PATH are used.
import os
import sys

from struct_utils import detect_format, poscar_to_ase, symmetrize
from cif_utils import cif_to_ase
from prediction import ase_to_ml_model, load_ml_model, human_names
from common import ML_MODELS, DATA_PATH


models, structures = [], []

if sys.argv[1:]:
    # Arguments which are not existing files are silently ignored
    inputs = [f for f in sys.argv[1:] if os.path.isfile(f)]
    models = [f for f in inputs if f.endswith('.pkl')]
    structures = [f for f in inputs if not f.endswith('.pkl')]

if not models:
    models = ML_MODELS

if not structures:
    # Sorted to process the files in a reproducible order
    structures = [
        os.path.join(DATA_PATH, f) for f in sorted(os.listdir(DATA_PATH))
        if os.path.isfile(os.path.join(DATA_PATH, f))
    ]

active_ml_model = load_ml_model(models)

for fname in structures:
    # print('') emits an empty line under both Python 2 and Python 3,
    # whereas a bare "print" is a no-op expression under Python 3
    print('')
    print(fname)

    # The context manager closes the file handle right after reading
    with open(fname) as handle:
        structure = handle.read()

    fmt = detect_format(structure)

    if fmt == 'cif':
        ase_obj, error = cif_to_ase(structure)
    elif fmt == 'poscar':
        ase_obj, error = poscar_to_ase(structure)
    else:
        ase_obj, error = None, 'Error: %s is not a crystal structure' % fname

    if error:
        print(error)
        continue

    ase_obj, error = symmetrize(ase_obj)
    if error:
        print(error)
        continue

    prediction, error = ase_to_ml_model(ase_obj, active_ml_model)
    if error:
        print(error)
        continue

    for prop_id, pdata in prediction.items():
        print("{0:40} = {1:6} (MAE = {2:4}), {3}".format(
            human_names[prop_id]['name'],
            pdata['value'],
            pdata['mae'],
            human_names[prop_id]['units']
        ))

0 commit comments

Comments
 (0)