forked from liupeirong/MLOpsManufacturing
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbuild_training_pipeline.py
More file actions
120 lines (106 loc) · 4.22 KB
/
build_training_pipeline.py
File metadata and controls
120 lines (106 loc) · 4.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from azureml.pipeline.core.graph import PipelineParameter
from azureml.pipeline.steps import PythonScriptStep
from azureml.pipeline.core import Pipeline, PipelineData
from azureml.core import Workspace
from azureml.core.runconfig import RunConfiguration
from ml_service.util.attach_compute import get_compute
from ml_service.util.env_variables import Env
from ml_service.util.manage_environment import get_environment
def main():
    """Build and publish the Azure ML model training pipeline.

    Steps performed:
      1. Connect to the Azure ML workspace named in the environment config.
      2. Attach (or create) the training compute cluster.
      3. Build a reusable training environment and run configuration.
      4. Define three chained PythonScriptSteps: train -> evaluate -> register.
      5. Validate and publish the pipeline, versioned by the CI build id.

    All configuration comes from environment variables via ``Env``; this
    function has no parameters and no return value — its effect is the
    published pipeline in the workspace (plus progress printed to stdout).
    """
    e = Env()

    # Get Azure machine learning workspace
    aml_workspace = Workspace.get(
        name=e.workspace_name,
        subscription_id=e.subscription_id,
        resource_group=e.resource_group,
    )
    print(f"get_workspace:{aml_workspace}")

    # Get Azure machine learning cluster. get_compute raises on failure,
    # so a None check here is purely informational.
    aml_compute = get_compute(aml_workspace, e.compute_name, e.vm_size)
    if aml_compute is not None:
        print(f"aml_compute:{aml_compute}")

    # Create a reusable Azure ML environment; rebuild_env forces a fresh
    # image build instead of reusing the registered environment.
    environment = get_environment(
        aml_workspace,
        e.aml_env_name,
        conda_dependencies_file=e.aml_env_train_conda_dep_file,
        create_new=e.rebuild_env,
    )
    run_config = RunConfiguration()
    run_config.environment = environment

    # Fall back to the workspace default datastore when none is configured.
    if e.datastore_name:
        datastore_name = e.datastore_name
    else:
        datastore_name = aml_workspace.get_default_datastore().name
    run_config.environment.environment_variables["DATASTORE_NAME"] = datastore_name  # NOQA: E501

    # datastore and dataset names are fixed for this pipeline, however
    # data_file_path can be specified for registering new versions of dataset
    # Note that AML pipeline parameters don't take empty string as default, "" won't work  # NOQA: E501
    model_name_param = PipelineParameter(name="model_name", default_value=e.model_name)  # NOQA: E501
    data_file_path_param = PipelineParameter(name="data_file_path", default_value="nopath")  # NOQA: E501
    ml_params = PipelineParameter(name="ml_params", default_value="default")  # NOQA: E501

    # Create a PipelineData to pass the trained model between steps
    pipeline_data = PipelineData(
        "pipeline_data", datastore=aml_workspace.get_default_datastore()
    )

    train_step = PythonScriptStep(
        name="Train Model",
        script_name="train/train_aml.py",
        compute_target=aml_compute,
        source_directory=e.sources_directory_train,
        outputs=[pipeline_data],
        arguments=[
            "--model_name", model_name_param,
            "--step_output", pipeline_data,
            "--data_file_path", data_file_path_param,
            "--dataset_name", e.processed_dataset_name,
            "--datastore_name", datastore_name,
            "--ml_params", ml_params,
        ],
        runconfig=run_config,
        # Training may be reused when inputs are unchanged.
        allow_reuse=True,
    )
    print("Step Train created")

    evaluate_step = PythonScriptStep(
        name="Evaluate Model ",
        script_name="evaluate/evaluate_model.py",
        compute_target=aml_compute,
        source_directory=e.sources_directory_train,
        arguments=[
            "--model_name", model_name_param,
            "--ml_params", ml_params,
        ],
        runconfig=run_config,
        # Evaluation must always run against the current registered model.
        allow_reuse=False,
    )
    print("Step Evaluate created")

    register_step = PythonScriptStep(
        name="Register Model ",
        script_name="register/register_model.py",
        compute_target=aml_compute,
        source_directory=e.sources_directory_train,
        inputs=[pipeline_data],
        arguments=[
            "--model_name", model_name_param,
            "--step_input", pipeline_data,
            "--ml_params", ml_params,
        ],
        runconfig=run_config,
        allow_reuse=False,
    )
    print("Step Register created")

    # Enforce sequential ordering: train -> evaluate -> register.
    evaluate_step.run_after(train_step)
    register_step.run_after(evaluate_step)
    steps = [train_step, evaluate_step, register_step]

    train_pipeline = Pipeline(workspace=aml_workspace, steps=steps)
    # BUGFIX: the original code had `train_pipeline._set_experiment_name`,
    # a bare attribute access with no call — a silent no-op. Removed rather
    # than called: it is a private API and the experiment name is supplied
    # at submission time anyway.
    train_pipeline.validate()
    published_pipeline = train_pipeline.publish(
        name=e.training_pipeline_name,
        description="Model training/retraining pipeline",
        version=e.build_id,
    )
    print(f"Published pipeline: {published_pipeline.name}")
    print(f"for build {published_pipeline.version}")
# Script entry point: build and publish the training pipeline when run
# directly (e.g. from a CI/CD build agent).
if __name__ == "__main__":
    main()