Skip to content

Commit ac12f80

Browse files
authored
fix: avoid OSError in LocalPathField deserialization (Azure#28804)
1 parent a534418 commit ac12f80

File tree

3 files changed

+34
-20
lines changed

3 files changed

+34
-20
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -131,14 +131,20 @@ def _resolve_path(self, value) -> Path:
131131
"""Resolve path to absolute path based on base_path in context.
132132
Will resolve the path if it's already an absolute path.
133133
"""
134-
result = Path(value)
135-
base_path = Path(self.context[BASE_PATH_CONTEXT_KEY])
136-
if not result.is_absolute():
137-
result = base_path / result
138134
try:
139-
return result.resolve()
135+
result = Path(value)
136+
base_path = Path(self.context[BASE_PATH_CONTEXT_KEY])
137+
if not result.is_absolute():
138+
result = base_path / result
139+
140+
# for non-path string like "azureml:/xxx", OSError can be raised in either
141+
# resolve() or is_dir() or is_file()
142+
result = result.resolve()
143+
if (self._allow_dir and result.is_dir()) or (self._allow_file and result.is_file()):
144+
return result
140145
except OSError:
141146
raise self.make_error("invalid_path")
147+
raise self.make_error("path_not_exist", path=result.as_posix(), allow_type=self.allowed_path_type)
142148

143149
@property
144150
def allowed_path_type(self) -> str:
@@ -154,16 +160,12 @@ def _validate(self, value):
154160

155161
if value is None:
156162
return
157-
path = self._resolve_path(value)
158-
if (self._allow_dir and path.is_dir()) or (self._allow_file and path.is_file()):
159-
return
160-
raise self.make_error("path_not_exist", path=path.as_posix(), allow_type=self.allowed_path_type)
163+
self._resolve_path(value)
161164

162165
def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[str]:
163166
# do not block serializing None even if required or not allow_none.
164167
if value is None:
165168
return None
166-
self._validate(value)
167169
# always dump path as absolute path in string as base_path will be dropped after serialization
168170
return super(LocalPathField, self)._serialize(self._resolve_path(value).as_posix(), attr, obj, **kwargs)
169171

sdk/ml/azure-ai-ml/azure/ai/ml/operations/_online_deployment_operations.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def begin_create_or_update(
162162
operation_config=self._operation_config,
163163
)
164164
if deployment.data_collector:
165-
self._register_collection_data_assets(deployment= deployment)
165+
self._register_collection_data_assets(deployment=deployment)
166166

167167
upload_dependencies(deployment, orchestrators)
168168
try:
@@ -351,16 +351,13 @@ def _register_collection_data_assets(self, deployment: OnlineDeployment) -> None
351351
for collection in deployment.data_collector.collections:
352352
data_name = deployment.endpoint_name + "-" + deployment.name + "-" + collection
353353
data_object = Data(
354-
name = data_name,
355-
path = deployment.data_collector.destination.path
354+
name=data_name,
355+
path=deployment.data_collector.destination.path
356356
if deployment.data_collector.destination and deployment.data_collector.destination.path
357357
else DEFAULT_MDC_PATH,
358-
is_anonymous= True
359-
)
358+
is_anonymous=True,
359+
)
360360
result = self._all_operations._all_operations[AzureMLResourceType.DATA].create_or_update(data_object)
361361
deployment.data_collector.collections[collection].data = DataAsset(
362-
data_id = result.id,
363-
path = result.path,
364-
name = result.name,
365-
version = result.version
366-
)
362+
data_id=result.id, path=result.path, name=result.name, version=result.version
363+
)

sdk/ml/azure-ai-ml/tests/component/unittests/test_component_schema.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,21 @@ def test_dump_with_non_existent_base_path(self):
328328
component_entity._base_path = "/non/existent/path"
329329
component_entity._to_dict()
330330

331+
def test_arm_code_from_rest_object(self):
332+
arm_code = (
333+
"azureml:/subscriptions/xxx/resourceGroups/xxx/providers/Microsoft.MachineLearningServices/"
334+
"workspaces/zzz/codes/90b33c11-365d-4ee4-aaa1-224a042deb41/versions/1"
335+
)
336+
yaml_path = "./tests/test_configs/components/helloworld_component.yml"
337+
yaml_component = load_component(yaml_path)
338+
339+
from azure.ai.ml.entities import Component
340+
341+
rest_object = yaml_component._to_rest_object()
342+
rest_object.properties.component_spec["code"] = arm_code
343+
component = Component._from_rest_object(rest_object)
344+
assert component.code == arm_code[8:]
345+
331346

332347
@pytest.mark.timeout(_COMPONENT_TIMEOUT_SECOND)
333348
@pytest.mark.unittest

0 commit comments

Comments
 (0)