
Commit 37db3c9

darrylong and tqtg authored
Enhance serving evaluation endpoints (#595)
* Include metric_user_results in evaluation response, added eval json endpoint
* Remove query from response
* Utilize mapped inversed user id map to get original id in response
* Update serving test case to remove 'query' and add 'user_result' in response
* simplify user ID mapping
* Combined evaluation and evaluation_json endpoints
* Updated abort responses to show plaintext instead of html
* Added unit test cases
* Updated error response for empty data
* Added unit tests for provided data evaluation
* Update app.py
* Update test_app.py

---------

Co-authored-by: Quoc-Tuan Truong <tqtg@users.noreply.github.com>
1 parent 92a94e3 · commit 37db3c9
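Before the diff itself, here is a minimal client sketch against the reworked /evaluate endpoint, assuming a Cornac serving instance on http://localhost:8080 (the host used in the curl example the diff removes). The request fields (metrics, data) and response keys (result, user_result) come from the diff; everything else is illustrative.

# Minimal sketch of a client call against the updated /evaluate endpoint.
import requests

payload = {
    "metrics": ["RMSE()", "NDCG(k=10)"],
    # Optional: evaluate on caller-provided (user, item, rating) triples
    # instead of the stored data/feedback.csv.
    "data": [["930", "795", 5], ["195", "795", 3]],
}
resp = requests.post("http://localhost:8080/evaluate", json=payload)

if resp.status_code == 200:
    body = resp.json()
    print(body["result"])       # metric averages, e.g. RMSE, NDCG@10
    print(body["user_result"])  # per-user scores keyed by original user ID
else:
    # Validation failures now come back as plaintext, not an HTML error page.
    print(resp.status_code, resp.text)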

File tree

cornac/serving/app.py
tests/cornac/serving/test_app.py

2 files changed: +105 −28 lines changed


cornac/serving/app.py

Lines changed: 54 additions & 27 deletions
@@ -26,7 +26,7 @@
 from cornac.metrics import *
 
 try:
-    from flask import Flask, jsonify, request
+    from flask import Flask, jsonify, request, abort, make_response
 except ImportError:
     exit("Flask is required in order to serve models.\n" + "Run: pip3 install Flask")
 

@@ -185,7 +185,6 @@ def add_feedback():
     return jsonify(data), 200
 
 
-# curl -X POST -H "Content-Type: application/json" -d '{"metrics": ["RMSE()", "NDCG(k=10)"]}' "http://localhost:8080/evaluate"
 @app.route("/evaluate", methods=["POST"])
 def evaluate():
     global model, train_set, metric_classnames

@@ -197,20 +196,59 @@ def evaluate():
         return "Unable to evaluate. 'train_set' is not provided", 400
 
     query = request.json
+    validate_query(query)
 
-    query_metrics = query.get("metrics")
-    rating_threshold = query.get("rating_threshold", 1.0)
     exclude_unknowns = (
         query.get("exclude_unknowns", "true").lower() == "true"
     )  # exclude unknown users/items by default, otherwise specified
+
+    if "data" in query:
+        data = query.get("data")
+    else:
+        data = []
+        data_fpath = "data/feedback.csv"
+        if os.path.exists(data_fpath):
+            reader = Reader()
+            data = reader.read(data_fpath, fmt="UIR", sep=",")
+
+    if not data:
+        response = make_response("No feedback has been provided so far. No data available to evaluate the model.")
+        response.status_code = 400
+        abort(response)
+
+    test_set = Dataset.build(
+        data,
+        fmt="UIR",
+        global_uid_map=train_set.uid_map,
+        global_iid_map=train_set.iid_map,
+        exclude_unknowns=exclude_unknowns,
+    )
+
+    return process_evaluation(test_set, query, exclude_unknowns)
+
+
+def validate_query(query):
+    query_metrics = query.get("metrics")
+
+    if not query_metrics:
+        response = make_response("metrics is required")
+        response.status_code = 400
+        abort(response)
+    elif not isinstance(query_metrics, list):
+        response = make_response("metrics must be an array of metrics")
+        response.status_code = 400
+        abort(response)
+
+
+def process_evaluation(test_set, query, exclude_unknowns):
+    global model, train_set
+
+    rating_threshold = query.get("rating_threshold", 1.0)
     user_based = (
         query.get("user_based", "true").lower() == "true"
     )  # user_based evaluation by default, otherwise specified
 
-    if query_metrics is None:
-        return "metrics is required", 400
-    elif not isinstance(query_metrics, list):
-        return "metrics must be an array of metrics", 400
+    query_metrics = query.get("metrics")
 
     # organize metrics
     metrics = []

@@ -226,24 +264,6 @@ def evaluate():
 
     rating_metrics, ranking_metrics = BaseMethod.organize_metrics(metrics)
 
-    # read data
-    data = []
-    data_fpath = "data/feedback.csv"
-    if os.path.exists(data_fpath):
-        reader = Reader()
-        data = reader.read(data_fpath, fmt="UIR", sep=",")
-
-    if not len(data):
-        raise ValueError("No data available to evaluate the model.")
-
-    test_set = Dataset.build(
-        data,
-        fmt="UIR",
-        global_uid_map=train_set.uid_map,
-        global_iid_map=train_set.iid_map,
-        exclude_unknowns=exclude_unknowns,
-    )
-
     # evaluation
     result = BaseMethod.eval(
         model=model,

@@ -258,10 +278,17 @@ def evaluate():
         verbose=False,
     )
 
+    # map user index back into the original user ID
+    metric_user_results = {}
+    for metric, user_results in result.metric_user_results.items():
+        metric_user_results[metric] = {
+            train_set.user_ids[int(k)]: v for k, v in user_results.items()
+        }
+
     # response
     response = {
         "result": result.metric_avg_results,
-        "query": query,
+        "user_result": metric_user_results,
    }
 
     return jsonify(response), 200
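Two details in this diff are worth unpacking. First, validation errors now go through make_response plus abort, which returns the message as text/plain rather than Flask's default HTML error page, so clients (and the new unit tests) can match the body byte-for-byte. Second, result.metric_user_results is keyed by Cornac's internal integer user indices, and the loop above translates those back to the original user IDs via train_set.user_ids, the index-ordered inverse of train_set.uid_map. A toy sketch of that inverse mapping, with stand-in dictionaries in place of the real train_set and invented score values:

# Toy stand-ins: in Cornac, train_set.uid_map maps raw user IDs to integer
# indices, and train_set.user_ids lists the raw IDs in index order.
uid_map = {"930": 0, "195": 1}      # raw ID -> internal index
user_ids = list(uid_map.keys())     # internal index -> raw ID

# Per-user metric scores as the evaluator returns them (values invented).
raw_user_results = {0: 1.25, 1: 0.50}

# Same dict comprehension as in the diff, minus the real objects.
mapped = {user_ids[int(k)]: v for k, v in raw_user_results.items()}
assert mapped == {"930": 1.25, "195": 0.50}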

tests/cornac/serving/test_app.py

Lines changed: 51 additions & 1 deletion
@@ -96,9 +96,10 @@ def test_evaluate_json(client):
     response = client.post('/evaluate', json=json_data)
     # assert response.content_type == 'application/json'
     assert response.status_code == 200
-    assert len(response.json['query']['metrics']) == 2
     assert 'RMSE' in response.json['result']
     assert 'Recall@5' in response.json['result']
+    assert 'RMSE' in response.json['user_result']
+    assert 'Recall@5' in response.json['user_result']
 
 
 def test_evalulate_incorrect_get(client):

@@ -110,3 +111,52 @@ def test_evalulate_incorrect_post(client):
     response = client.post('/evaluate')
     assert response.status_code == 415 # bad request, expect json
 
+
+def test_evaluate_missing_metrics(client):
+    json_data = {
+        'metrics': []
+    }
+    response = client.post('/evaluate', json=json_data)
+    assert response.status_code == 400
+    assert response.data == b'metrics is required'
+
+
+def test_evaluate_not_list_metrics(client):
+    json_data = {
+        'metrics': 'RMSE()'
+    }
+    response = client.post('/evaluate', json=json_data)
+    assert response.status_code == 400
+    assert response.data == b'metrics must be an array of metrics'
+
+
+def test_recommend_missing_uid(client):
+    response = client.get('/recommend?k=5')
+    assert response.status_code == 400
+    assert response.data == b'uid is required'
+
+
+def test_evaluate_use_data(client):
+    json_data = {
+        'metrics': ['RMSE()', 'Recall(k=5)'],
+        'data': [['930', '795', 5], ['195', '795', 3]]
+    }
+    response = client.post('/evaluate', json=json_data)
+    # assert response.content_type == 'application/json'
+    assert response.status_code == 200
+    assert 'RMSE' in response.json['result']
+    assert 'Recall@5' in response.json['result']
+    assert 'RMSE' in response.json['user_result']
+    assert 'Recall@5' in response.json['user_result']
+
+
+def test_evaluate_use_data_empty(client):
+    json_data = {
+        'metrics': ['RMSE()', 'Recall(k=5)'],
+        'data': []
+    }
+    response = client.post('/evaluate', json=json_data)
+    assert response.status_code == 400
+    assert response.data == b"No feedback has been provided so far. No data available to evaluate the model."
+
+
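These tests all rely on a client fixture that the diff does not define. A hypothetical conftest.py sketch of what such a fixture could look like; the import path and app setup here are assumptions for illustration, not the repository's actual code:

# Hypothetical fixture: a Flask test client for the serving app.
import pytest
from cornac.serving.app import app  # assumed import path

@pytest.fixture
def client():
    app.config["TESTING"] = True
    with app.test_client() as client:  # Flask's built-in test client
        yield client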
