Skip to content
This repository was archived by the owner on Jul 20, 2025. It is now read-only.

Commit 47a302e

Browse files
committed
Re-organize examples as test_TASK_TYPE.py
1 parent 7e916a9 commit 47a302e

File tree

5 files changed

+59
-22
lines changed

5 files changed

+59
-22
lines changed

mpds_ml_labs/common.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11

22
import os
33
from ConfigParser import ConfigParser
4+
from urllib import urlencode
5+
6+
import ujson as json
7+
import pg8000
48

59

610
DATA_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), '../data'))
@@ -33,7 +37,6 @@
3337

3438

3539
def connect_database():
36-
import pg8000
3740

3841
assert KNN_TABLE
3942

@@ -47,3 +50,17 @@ def connect_database():
4750
cursor = connection.cursor()
4851

4952
return cursor, connection
53+
54+
55+
def make_request(req, address, data={}, httpverb='POST', headers={}):
56+
57+
address += '?' + urlencode(data)
58+
59+
if httpverb == 'GET':
60+
response, content = req.request(address, httpverb, headers=headers)
61+
62+
else:
63+
headers.update({'Content-type': 'application/x-www-form-urlencoded'})
64+
response, content = req.request(address, httpverb, headers=headers, body=urlencode(data))
65+
66+
return json.loads(content)

mpds_ml_labs/test_design_client.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
2+
import time
3+
import httplib2
4+
import ujson as json
5+
6+
from common import make_request
7+
8+
9+
remote = httplib2.Http()
10+
11+
LABS_SERVER_ADDR = 'http://127.0.0.1:5000/design'
12+
13+
# NB. mind prediction_ranges.prediction_ranges
14+
sample = {
15+
'z': [200, 265],
16+
'y': [-325, -250],
17+
'x': [11, 28],
18+
'k': [150, 225],
19+
'w': [1, 3],
20+
'm': [2000, 2700],
21+
'd': [175, 1100],
22+
't': [-0.5, 3]
23+
}
24+
25+
if __name__ == '__main__':
26+
starttime = time.time()
27+
28+
answer = make_request(remote, LABS_SERVER_ADDR, {'numerics': json.dumps(sample)})
29+
if 'error' in answer:
30+
raise RuntimeError(answer['error'])
31+
32+
print(answer['vis_cif'])
33+
print("Done in %1.2f sc" % (time.time() - starttime))
File renamed without changes.

mpds_ml_labs/test_app.py renamed to mpds_ml_labs/test_props_client.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,24 @@
11

22
import sys
33
import time
4-
from urllib import urlencode
54

65
import httplib2
7-
import ujson as json
86
import numpy as np
9-
107
from mpds_client import MPDSDataRetrieval, APIError
118

129
from prediction import prop_models
1310
from struct_utils import detect_format, poscar_to_ase, refine, get_formula, sgn_to_crsystem
1411
from cif_utils import cif_to_ase
15-
from common import API_KEY, API_ENDPOINT
12+
from common import API_KEY, API_ENDPOINT, make_request
1613

1714

18-
req = httplib2.Http()
15+
remote = httplib2.Http()
1916
client = MPDSDataRetrieval(api_key=API_KEY, endpoint=API_ENDPOINT)
20-
ARITY = {1: 'unary', 2: 'binary', 3: 'ternary', 4: 'quaternary', 5: 'quinary'}
21-
22-
def make_request(address, data={}, httpverb='POST', headers={}):
2317

24-
address += '?' + urlencode(data)
18+
LABS_SERVER_ADDR = 'http://127.0.0.1:5000/predict'
2519

26-
if httpverb == 'GET':
27-
response, content = req.request(address, httpverb, headers=headers)
28-
29-
else:
30-
headers.update({'Content-type': 'application/x-www-form-urlencoded'})
31-
response, content = req.request(address, httpverb, headers=headers, body=urlencode(data))
20+
ARITY = {1: 'unary', 2: 'binary', 3: 'ternary', 4: 'quaternary', 5: 'quinary'}
3221

33-
return json.loads(content)
3422

3523
if __name__ == '__main__':
3624

@@ -75,7 +63,7 @@ def make_request(address, data={}, httpverb='POST', headers={}):
7563
'lattices': sgn_to_crsystem(ase_obj.info['spacegroup'].no)
7664
}
7765

78-
answer = make_request('http://127.0.0.1:5000/predict', {'structure': structure})
66+
answer = make_request(remote, LABS_SERVER_ADDR, {'structure': structure})
7967
if 'error' in answer:
8068
raise RuntimeError(answer['error'])
8169

@@ -85,10 +73,9 @@ def make_request(address, data={}, httpverb='POST', headers={}):
8573
resp = client.get_dataframe(tpl_query)
8674
except APIError as e:
8775
prop_models[prop_id]['factual'] = None
88-
if e.code == 1:
89-
continue
90-
else:
91-
raise
76+
if e.code != 1:
77+
print("While checking against the MPDS an error %s occured" % e.code)
78+
continue
9279

9380
resp['Value'] = resp['Value'].astype('float64') # to treat values out of bounds given as str
9481
resp = resp[resp['Units'] == pdata['units']]
File renamed without changes.

0 commit comments

Comments
 (0)