Skip to content

Commit 6c56eee

Browse files
dtzareedorenko
authored and committed
Generalize model tag finder, tag model experiment_name (#103)
1 parent 969e6ca commit 6c56eee

File tree

5 files changed

+69
-67
lines changed

5 files changed

+69
-67
lines changed

code/evaluate/evaluate_model.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,17 @@
2424
POSSIBILITY OF SUCH DAMAGE.
2525
"""
2626
import os
27-
from azureml.core import Model, Run, Workspace, Experiment
27+
import sys
28+
from azureml.core import Run, Workspace, Experiment
2829
import argparse
2930
from azureml.core.authentication import ServicePrincipalAuthentication
3031
import traceback
3132

3233
run = Run.get_context()
3334
if (run.id.startswith('OfflineRun')):
3435
from dotenv import load_dotenv
36+
sys.path.append(os.path.abspath("./code/util")) # NOQA: E402
37+
from model_helper import get_model_by_tag
3538
# For local development, set values in this section
3639
load_dotenv()
3740
workspace_name = os.environ.get("WORKSPACE_NAME")
@@ -56,8 +59,10 @@
5659
)
5760
ws = aml_workspace
5861
exp = Experiment(ws, experiment_name)
59-
run_id = "e78b2c27-5ceb-49d9-8e84-abe7aecf37d5"
62+
run_id = "57fee47f-5ae8-441c-bc0c-d4c371f32d70"
6063
else:
64+
sys.path.append(os.path.abspath("./util")) # NOQA: E402
65+
from model_helper import get_model_by_tag
6166
exp = run.experiment
6267
ws = run.experiment.workspace
6368
run_id = 'amlcompute'
@@ -94,16 +99,15 @@
9499
# Parameterize the metrics on which the models should be compared
95100
# Add golden data set on which all the model performance can be evaluated
96101
try:
97-
model_list = Model.list(ws)
98-
if (len(model_list) > 0):
99-
production_model = next(
100-
filter(
101-
lambda x: x.created_time == max(
102-
model.created_time for model in model_list),
103-
model_list,
104-
)
105-
)
106-
production_model_run_id = production_model.run_id
102+
firstRegistration = False
103+
tag_name = 'experiment_name'
104+
105+
model = get_model_by_tag(
106+
model_name, tag_name, exp.name, ws)
107+
108+
if (model is not None):
109+
110+
production_model_run_id = model.run_id
107111

108112
# Get the run history for both production model and
109113
# newly trained model and compare mse
@@ -136,6 +140,7 @@
136140
else:
137141
print("This is the first model, "
138142
"thus it should be registered")
143+
139144
except Exception:
140145
traceback.print_exc(limit=None, file=None, chain=True)
141146
print("Something went wrong trying to evaluate. Exiting.")

code/register/register_model.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def main():
3838
if (run.id.startswith('OfflineRun')):
3939
from dotenv import load_dotenv
4040
sys.path.append(os.path.abspath("./code/util")) # NOQA: E402
41-
from model_helper import get_model_by_build_id
41+
from model_helper import get_model_by_tag
4242
# For local development, set values in this section
4343
load_dotenv()
4444
workspace_name = os.environ.get("WORKSPACE_NAME")
@@ -66,7 +66,7 @@ def main():
6666
run_id = "bd184a18-2ac8-4951-8e78-e290bef3b012"
6767
else:
6868
sys.path.append(os.path.abspath("./util")) # NOQA: E402
69-
from model_helper import get_model_by_build_id
69+
from model_helper import get_model_by_tag
7070
ws = run.experiment.workspace
7171
exp = run.experiment
7272
run_id = 'amlcompute'
@@ -108,7 +108,8 @@ def main():
108108

109109
if (validate):
110110
try:
111-
get_model_by_build_id(model_name, build_id, exp.workspace)
111+
tag_name = 'BuildId'
112+
get_model_by_tag(model_name, tag_name, build_id, exp.workspace)
112113
print("Model was registered for this build.")
113114
except Exception as e:
114115
print(e)
@@ -139,12 +140,13 @@ def register_aml_model(model_name, exp, run_id, build_id: str = 'none'):
139140
model_already_registered(model_name, exp, run_id)
140141
run = Run(experiment=exp, run_id=run_id)
141142
tagsValue = {"area": "diabetes", "type": "regression",
142-
"BuildId": build_id, "run_id": run_id}
143+
"BuildId": build_id, "run_id": run_id,
144+
"experiment_name": exp.name}
143145
else:
144146
run = Run(experiment=exp, run_id=run_id)
145147
if (run is not None):
146-
tagsValue = {"area": "diabetes",
147-
"type": "regression", "run_id": run_id}
148+
tagsValue = {"area": "diabetes", "type": "regression",
149+
"run_id": run_id, "experiment_name": exp.name}
148150
else:
149151
print("A model run for experiment", exp.name,
150152
"matching properties run_id =", run_id,

code/util/model_helper.py

Lines changed: 39 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ def get_current_workspace() -> Workspace:
2222
return experiment.workspace
2323

2424

25-
def _get_model_by_build_id(
25+
def get_model_by_tag(
2626
model_name: str,
27-
build_id: str,
27+
tag_name: str,
28+
tag_value: str,
2829
aml_workspace: Workspace = None
2930
) -> AMLModel:
3031
"""
@@ -34,52 +35,46 @@ def _get_model_by_build_id(
3435
Parameters:
3536
aml_workspace (Workspace): aml.core Workspace that the model lives.
3637
model_name (str): name of the model we are looking for
37-
build_id (str): the build id the model was registered under.
38+
tag_name (str), tag_value (str): name and value of the tag the model was registered under.
3839
3940
Return:
4041
A single aml model from the workspace that matches the name and tag.
4142
"""
42-
# Validate params. cannot be None.
43-
if model_name is None:
44-
raise ValueError("model_name[:str] is required")
45-
if build_id is None:
46-
raise ValueError("build_id[:str] is required")
47-
if aml_workspace is None:
48-
aml_workspace = get_current_workspace()
43+
try:
44+
# Validate params. cannot be None.
45+
if model_name is None:
46+
raise ValueError("model_name[:str] is required")
47+
if tag_name is None:
48+
raise ValueError("tag_name[:str] is required")
49+
if tag_value is None:
50+
raise ValueError("tag[:str] is required")
51+
if aml_workspace is None:
52+
aml_workspace = get_current_workspace()
4953

50-
# get model by tag.
51-
model_list = AMLModel.list(
52-
aml_workspace, name=model_name,
53-
tags=[["BuildId", build_id]], latest=True
54-
)
54+
# get model by tag.
55+
model_list = AMLModel.list(
56+
aml_workspace, name=model_name,
57+
tags=[[tag_name, tag_value]], latest=True
58+
)
5559

56-
# latest should only return 1 model, but if it does, then maybe
57-
# internal code was accidentally changed or the source code has changed.
58-
should_not_happen = ("THIS SHOULD NOT HAPPEN: found more than one model "
59-
"for the latest with {{model_name: {model_name},"
60-
"BuildId: {build_id}. Models found: {model_list}}}")\
61-
.format(model_name=model_name, build_id=build_id,
62-
model_list=model_list)
63-
if len(model_list) > 1:
64-
raise ValueError(should_not_happen)
60+
# latest should only return 1 model; if it returns more,
61+
# then maybe the sdk or source code changed.
62+
should_not_happen = ("Found more than one model "
63+
"for the latest with {{tag_name: {tag_name},"
64+
"tag_value: {tag_value}. "
65+
"Models found: {model_list}}}")\
66+
.format(tag_name=tag_name, tag_value=tag_value,
67+
model_list=model_list)
68+
no_model_found = ("No Model found with {{tag_name: {tag_name} ,"
69+
"tag_value: {tag_value}.}}")\
70+
.format(tag_name=tag_name, tag_value=tag_value)
6571

66-
return model_list
67-
68-
69-
def get_model_by_build_id(
70-
model_name: str,
71-
build_id: str,
72-
aml_workspace: Workspace = None
73-
) -> AMLModel:
74-
"""
75-
Wrapper function for get_model_by_id that throws an error if model is none
76-
"""
77-
model_list = _get_model_by_build_id(model_name, build_id, aml_workspace)
78-
79-
if model_list:
80-
return model_list[0]
81-
82-
no_model_found = ("Model not found with model_name: {model_name} "
83-
"BuildId: {build_id}.")\
84-
.format(model_name=model_name, build_id=build_id)
85-
raise Exception(no_model_found)
72+
if len(model_list) > 1:
73+
raise ValueError(should_not_happen)
74+
if len(model_list) == 1:
75+
return model_list[0]
76+
else:
77+
print(no_model_found)
78+
return None
79+
except Exception:
80+
raise

ml_service/pipelines/run_train_pipeline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
def main():
1111
e = Env()
1212
service_principal = ServicePrincipalAuthentication(
13-
tenant_id=e.tenant_id,
14-
service_principal_id=e.app_id,
15-
service_principal_password=e.app_secret)
13+
tenant_id=e.tenant_id,
14+
service_principal_id=e.app_id,
15+
service_principal_password=e.app_secret)
1616

1717
aml_workspace = Workspace.get(
1818
name=e.workspace_name,

ml_service/util/env_variables.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ class Singleton(object):
77

88
def __new__(class_, *args, **kwargs):
99
if class_ not in class_._instances:
10-
class_._instances[class_] = super(Singleton, class_).__new__(class_, *args, **kwargs) # noqa E501
10+
class_._instances[class_] = super(Singleton, class_).__new__(class_, *args, **kwargs) # noqa E501
1111
return class_._instances[class_]
1212

1313

@@ -23,7 +23,7 @@ def __init__(self):
2323
self._app_secret = os.environ.get("SP_APP_SECRET")
2424
self._vm_size = os.environ.get("AML_COMPUTE_CLUSTER_CPU_SKU")
2525
self._compute_name = os.environ.get("AML_COMPUTE_CLUSTER_NAME")
26-
self._vm_priority = os.environ.get("AML_CLUSTER_PRIORITY", 'lowpriority') # noqa E501
26+
self._vm_priority = os.environ.get("AML_CLUSTER_PRIORITY", 'lowpriority') # noqa E501
2727
self._min_nodes = int(os.environ.get("AML_CLUSTER_MIN_NODES", 0))
2828
self._max_nodes = int(os.environ.get("AML_CLUSTER_MAX_NODES", 4))
2929
self._build_id = os.environ.get("BUILD_BUILDID")

0 commit comments

Comments
 (0)