Skip to content

Commit ac9dcd1

Browse files
committed
add sagemaker pipelines
1 parent 5b0490a commit ac9dcd1

File tree

5 files changed

+973
-0
lines changed

5 files changed

+973
-0
lines changed

sagemaker_pipelines/paddleocr/__init__.py

Whitespace-only changes.
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""Evaluation script for measuring mean squared error."""
2+
import json
3+
import logging
4+
import pathlib
5+
import pickle
6+
import tarfile
7+
8+
import numpy as np
9+
import pandas as pd
10+
import xgboost
11+
12+
from sklearn.metrics import mean_squared_error
13+
14+
logger = logging.getLogger()
15+
logger.setLevel(logging.INFO)
16+
logger.addHandler(logging.StreamHandler())
17+
18+
19+
20+
if __name__ == "__main__":
21+
logger.debug("Starting evaluation.")
22+
model_path = "/opt/ml/processing/model/model.tar.gz"
23+
with tarfile.open(model_path) as tar:
24+
tar.extractall(path=".")
25+
26+
logger.debug("Loading xgboost model.")
27+
model = pickle.load(open("xgboost-model", "rb"))
28+
29+
logger.debug("Reading test data.")
30+
test_path = "/opt/ml/processing/test/test.csv"
31+
df = pd.read_csv(test_path, header=None)
32+
33+
logger.debug("Reading test data.")
34+
y_test = df.iloc[:, 0].to_numpy()
35+
df.drop(df.columns[0], axis=1, inplace=True)
36+
X_test = xgboost.DMatrix(df.values)
37+
38+
logger.info("Performing predictions against test data.")
39+
predictions = model.predict(X_test)
40+
41+
logger.debug("Calculating mean squared error.")
42+
mse = mean_squared_error(y_test, predictions)
43+
std = np.std(y_test - predictions)
44+
report_dict = {
45+
"regression_metrics": {
46+
"mse": {
47+
"value": mse,
48+
"standard_deviation": std
49+
},
50+
},
51+
}
52+
53+
output_dir = "/opt/ml/processing/evaluation"
54+
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
55+
56+
logger.info("Writing out evaluation report with mse: %f", mse)
57+
evaluation_path = f"{output_dir}/evaluation.json"
58+
with open(evaluation_path, "w") as f:
59+
f.write(json.dumps(report_dict))

0 commit comments

Comments
 (0)