Skip to content

Commit 675bd24

Browse files
mjaowJun Min
andauthored
support mdc local mode (Azure#30357)
* update * add null check * add doc * initialize env var and volumes as empty dict * use dict update * use update --------- Co-authored-by: Jun Min <[email protected]>
1 parent a26a576 commit 675bd24

File tree

3 files changed

+208
-1
lines changed

3 files changed

+208
-1
lines changed
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
4+
import json
5+
import os.path
6+
from pathlib import Path
7+
8+
from azure.ai.ml.entities._deployment.data_collector import DataCollector
9+
10+
11+
class MdcConfigResolver(object):
12+
13+
"""Represents the contents of mdc config and handles writing the mdc configuration to User's system.
14+
15+
:param data_collector: model data collector entity
16+
:type data_collector: DataCollector
17+
"""
18+
19+
def __init__(
20+
self,
21+
data_collector: DataCollector,
22+
):
23+
self.environment_variables = {}
24+
self.volumes = {}
25+
self.mdc_config = None
26+
self.config_path = "/etc/mdc-config.json"
27+
self.local_config_name = "mdc-config.json"
28+
self._construct(data_collector)
29+
30+
def _construct(self, data_collector: DataCollector) -> None:
31+
"""Internal use only.
32+
33+
Constructs the mdc configuration based on entity.
34+
"""
35+
if not data_collector.collections:
36+
return
37+
38+
if len(data_collector.collections) <= 0:
39+
return
40+
41+
sampling_percentage = int(data_collector.sampling_rate * 100) if data_collector.sampling_rate else 100
42+
43+
self.mdc_config = {"collections": {}, "runMode": "local"}
44+
custom_logging_enabled = False
45+
for k, v in data_collector.collections.items():
46+
if v.enabled and v.enabled.lower() == "true":
47+
lower_k = k.lower()
48+
49+
if lower_k not in ("request", "response"):
50+
custom_logging_enabled = True
51+
52+
self.mdc_config["collections"][lower_k] = {
53+
"enabled": True,
54+
"sampling_percentage": int(v.sampling_rate * 100) if v.sampling_rate else sampling_percentage,
55+
}
56+
57+
if not custom_logging_enabled:
58+
self.mdc_config = None
59+
return
60+
61+
if data_collector.request_logging and data_collector.request_logging.capture_headers:
62+
self.mdc_config["captureHeaders"] = data_collector.request_logging.capture_headers
63+
64+
def write_file(self, directory_path: str) -> None:
65+
"""Writes this mdc configuration to a file in provided directory.
66+
67+
:param directory_path: absolute path of local directory to write Dockerfile.
68+
:type directory_path: str
69+
"""
70+
if not self.mdc_config:
71+
return
72+
73+
mdc_setting_path = str(Path(directory_path, self.local_config_name).resolve())
74+
with open(mdc_setting_path, "w") as f:
75+
d = json.dumps(self.mdc_config)
76+
f.write(f"{d}")
77+
78+
self.environment_variables = {"AZUREML_MDC_CONFIG_PATH": self.config_path}
79+
local_path = os.path.join(directory_path, self.local_config_name)
80+
81+
self.volumes = {f"{local_path}:{self.config_path}:z": {local_path: {"bind": self.config_path}}}

sdk/ml/azure-ai-ml/azure/ai/ml/operations/_local_deployment_helper.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
get_deployment_json_from_container,
2020
get_status_from_container,
2121
)
22+
from azure.ai.ml._local_endpoints.mdc_config_resolver import MdcConfigResolver
2223
from azure.ai.ml._local_endpoints.validators.code_validator import get_code_configuration_artifacts
2324
from azure.ai.ml._local_endpoints.validators.environment_validator import get_environment_artifacts
2425
from azure.ai.ml._local_endpoints.validators.model_validator import get_model_artifacts
@@ -259,6 +260,17 @@ def _create_deployment(
259260
**image_context.environment,
260261
**user_environment_variables,
261262
}
263+
264+
volumes = {}
265+
volumes.update(image_context.volumes)
266+
267+
if deployment.data_collector:
268+
mdc_config = MdcConfigResolver(deployment.data_collector)
269+
mdc_config.write_file(deployment_directory_path)
270+
271+
environment_variables.update(mdc_config.environment_variables)
272+
volumes.update(mdc_config.volumes)
273+
262274
# Determine whether we need to use local context or downloaded context
263275
build_directory = downloaded_build_context if downloaded_build_context else deployment_directory
264276
self._docker_client.create_deployment(
@@ -270,7 +282,7 @@ def _create_deployment(
270282
dockerfile_path=None if is_byoc else dockerfile.local_path,
271283
conda_source_path=yaml_env_conda_file_path,
272284
conda_yaml_contents=yaml_env_conda_file_contents,
273-
volumes=image_context.volumes,
285+
volumes=volumes,
274286
environment=environment_variables,
275287
azureml_port=inference_config.scoring_route.port if is_byoc else LocalEndpointConstants.DOCKER_PORT,
276288
local_endpoint_mode=local_endpoint_mode,
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
4+
5+
import pytest
6+
7+
from azure.ai.ml._local_endpoints.mdc_config_resolver import MdcConfigResolver
8+
from azure.ai.ml.entities._deployment.data_collector import DataCollector
9+
from azure.ai.ml.entities._deployment.deployment_collection import DeploymentCollection
10+
from azure.ai.ml.entities._deployment.request_logging import RequestLogging
11+
12+
13+
@pytest.mark.unittest
14+
class TestMdcConfigResolver:
15+
def test_resolve_mdc_config(self):
16+
resolver = MdcConfigResolver(
17+
data_collector=DataCollector(
18+
collections={
19+
"inputs": DeploymentCollection(enabled="true", sampling_rate=0.8),
20+
"outputs": DeploymentCollection(enabled="true", sampling_rate=0.7),
21+
"request": DeploymentCollection(enabled="true", sampling_rate=0.6),
22+
"Response": DeploymentCollection(enabled="true", sampling_rate=0.5),
23+
},
24+
request_logging=RequestLogging(capture_headers=["aaa", "bbb"]),
25+
)
26+
)
27+
28+
mdc_config = {
29+
"collections": {
30+
"inputs": {"enabled": True, "sampling_percentage": 80},
31+
"outputs": {"enabled": True, "sampling_percentage": 70},
32+
"request": {"enabled": True, "sampling_percentage": 60},
33+
"response": {"enabled": True, "sampling_percentage": 50},
34+
},
35+
"runMode": "local",
36+
"captureHeaders": ["aaa", "bbb"],
37+
}
38+
39+
assert mdc_config == resolver.mdc_config
40+
assert {} == resolver.environment_variables
41+
assert {} == resolver.volumes
42+
43+
def test_resolve_mdc_config_global_sampling_rate(self):
44+
resolver = MdcConfigResolver(
45+
data_collector=DataCollector(
46+
collections={
47+
"inputs": DeploymentCollection(enabled="true"),
48+
"outputs": DeploymentCollection(enabled="true"),
49+
"request": DeploymentCollection(enabled="true"),
50+
"Response": DeploymentCollection(enabled="true"),
51+
},
52+
request_logging=RequestLogging(capture_headers=["aaa", "bbb"]),
53+
sampling_rate=0.9,
54+
)
55+
)
56+
57+
mdc_config = {
58+
"collections": {
59+
"inputs": {"enabled": True, "sampling_percentage": 90},
60+
"outputs": {"enabled": True, "sampling_percentage": 90},
61+
"request": {"enabled": True, "sampling_percentage": 90},
62+
"response": {"enabled": True, "sampling_percentage": 90},
63+
},
64+
"runMode": "local",
65+
"captureHeaders": ["aaa", "bbb"],
66+
}
67+
68+
assert mdc_config == resolver.mdc_config
69+
assert {} == resolver.environment_variables
70+
assert {} == resolver.volumes
71+
72+
def test_resolve_mdc_config_collections_disabled(self):
73+
resolver = MdcConfigResolver(
74+
data_collector=DataCollector(
75+
collections={
76+
"inputs": DeploymentCollection(enabled="false"),
77+
"outputs": DeploymentCollection(enabled="false"),
78+
"request": DeploymentCollection(enabled="false"),
79+
"Response": DeploymentCollection(enabled="false"),
80+
},
81+
request_logging=RequestLogging(capture_headers=["aaa", "bbb"]),
82+
sampling_rate=0.9,
83+
)
84+
)
85+
86+
assert not resolver.mdc_config
87+
resolver.write_file("/mnt")
88+
assert {} == resolver.environment_variables
89+
assert {} == resolver.volumes
90+
91+
def test_resolve_mdc_config_no_collections(self):
92+
resolver = MdcConfigResolver(data_collector=DataCollector(collections={}))
93+
94+
assert not resolver.mdc_config
95+
resolver.write_file("/mnt")
96+
assert {} == resolver.environment_variables
97+
assert {} == resolver.volumes
98+
99+
def test_resolve_mdc_config_no_custom_logging(self):
100+
resolver = MdcConfigResolver(
101+
data_collector=DataCollector(
102+
collections={
103+
"request": DeploymentCollection(enabled="true"),
104+
"Response": DeploymentCollection(enabled="true"),
105+
},
106+
request_logging=RequestLogging(capture_headers=["aaa", "bbb"]),
107+
sampling_rate=0.9,
108+
)
109+
)
110+
111+
assert not resolver.mdc_config
112+
resolver.write_file("/mnt")
113+
assert {} == resolver.environment_variables
114+
assert {} == resolver.volumes

0 commit comments

Comments
 (0)