
Commit 45afe29

algattik authored and eedorenko committed
Discover scoring model + Add Swagger endpoint to scoring container (#149)
1 parent 65c0793 commit 45afe29

2 files changed: 31 additions & 9 deletions

diabetes_regression/scoring/conda_dependencies.yml

Lines changed: 1 addition & 1 deletion
@@ -34,4 +34,4 @@ dependencies:
 - joblib==0.14.0
 - gunicorn==19.9.0
 - flask==1.1.1
-
+- inference-schema[numpy-support]
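
As background, a minimal standalone sketch of what the inference-schema[numpy-support] extra provides (the names sample_in, sample_out, and echo are illustrative and not part of the repo; the sketch assumes the decorators deserialize a plain-list payload into the declared numpy type when the decorated function is called, as they do when Azure ML invokes run()):

import numpy
from inference_schema.schema_decorators import input_schema, output_schema
from inference_schema.parameter_types.numpy_parameter_type import NumpyParameterType

# Illustrative sample values used only to derive the schema
sample_in = numpy.array([[1.0, 2.0, 3.0]])
sample_out = numpy.array([0.5])


@input_schema('data', NumpyParameterType(sample_in))   # declares the 'data' parameter schema
@output_schema(NumpyParameterType(sample_out))         # declares the response schema
def echo(data):
    # numpy.asarray is defensive: inference-schema is expected to have
    # already converted 'data' into a numpy array matching sample_in
    return numpy.asarray(data).tolist()


print(echo(data=[[1.0, 2.0, 3.0]]))

The numpy-support extra pulls in what NumpyParameterType needs, which is why the dependency is added alongside the score.py changes below.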

diabetes_regression/scoring/score.py

Lines changed: 30 additions & 8 deletions
@@ -23,24 +23,46 @@
 ARISING IN ANY WAY OUT OF THE USE OF THE SOFTWARE CODE, EVEN IF ADVISED OF THE
 POSSIBILITY OF SUCH DAMAGE.
 """
-import json
 import numpy
-from azureml.core.model import Model
 import joblib
+import os
+from inference_schema.schema_decorators \
+    import input_schema, output_schema
+from inference_schema.parameter_types.numpy_parameter_type \
+    import NumpyParameterType
 
 
 def init():
+    # load the model from file into a global object
     global model
 
-    # load the model from file into a global object
-    model_path = Model.get_model_path(
-        model_name="sklearn_regression_model.pkl")
+    # AZUREML_MODEL_DIR is an environment variable created during service
+    # deployment. It contains the path to the folder containing the model.
+    path = os.environ['AZUREML_MODEL_DIR']
+    model_path = None
+    for root, dirs, files in os.walk(path):
+        for file in files:
+            if '.pkl' in file:
+                model_path = os.path.join(path, file)
+    if model_path is None:
+        raise ValueError(".pkl model not found")
     model = joblib.load(model_path)
 
 
-def run(raw_data, request_headers):
-    data = json.loads(raw_data)["data"]
-    data = numpy.array(data)
+input_sample = numpy.array([
+    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+    [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]])
+output_sample = numpy.array([
+    5021.509689995557,
+    3693.645386402646])
+
+
+# Inference_schema generates a schema for your web service
+# It then creates an OpenAPI (Swagger) specification for the web service
+# at http://<scoring_base_url>/swagger.json
+@input_schema('data', NumpyParameterType(input_sample))
+@output_schema(NumpyParameterType(output_sample))
+def run(data, request_headers):
     result = model.predict(data)
 
     # Demonstrate how we can log custom data into the Application Insights
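
With run() decorated, the container serves the generated OpenAPI document next to the scoring route. A hedged client-side sketch of exercising both (the host and the /score path are placeholders for whatever the actual deployment exposes; only the swagger.json path comes from the comment in the diff above):

import json
import requests  # assumed to be available in the client environment

base_url = "http://<scoring_base_url>"     # placeholder, as in the comment above
swagger_uri = base_url + "/swagger.json"   # OpenAPI (Swagger) spec generated by inference-schema
scoring_uri = base_url + "/score"          # assumed scoring route of the deployed service

# Fetch the generated specification
swagger = requests.get(swagger_uri).json()
print(sorted(swagger.keys()))

# Score two rows; the 'data' key matches the name declared in @input_schema,
# and the row shape mirrors input_sample
payload = {"data": [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                    [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]]}
headers = {"Content-Type": "application/json"}
response = requests.post(scoring_uri, data=json.dumps(payload), headers=headers)
print(response.status_code, response.text)

Because the decorator now handles deserialization, run() receives 'data' as a numpy array directly, which is why the json.loads and numpy.array calls were removed from run().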
