Skip to content

Commit 370ecaa

Browse files
Type Fix for dsl folder (#33644)
* fix dsl _utils.py * fix pipeline_decorator * fix misc * fix pipeline_component_builder * fix misc - 2 * update dynamic * fix misc - 3 * fix misc - 4 * add type ignore and bug item number * update pyproject.toml - 1 * remove cast * unexclude dsl folder
1 parent 8f9c6b8 commit 370ecaa

15 files changed

+171
-130
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/dsl/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# ---------------------------------------------------------
44

5-
__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
5+
__path__ = __import__("pkgutil").extend_path(__path__, __name__)
66

77
from azure.ai.ml.dsl._pipeline_decorator import pipeline
88

sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_component_func.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44

55
# pylint: disable=protected-access
66

7-
from typing import Callable, Mapping
7+
from typing import Any, Callable, List, Mapping
88

99
from azure.ai.ml.dsl._dynamic import KwParameter, create_kw_function_from_parameters
1010
from azure.ai.ml.entities import Component as ComponentEntity
1111
from azure.ai.ml.entities._builders import Command
1212
from azure.ai.ml.entities._component.datatransfer_component import DataTransferImportComponent
1313

1414

15-
def get_dynamic_input_parameter(inputs: Mapping):
15+
def get_dynamic_input_parameter(inputs: Mapping) -> List:
1616
"""Return the dynamic parameter of the definition's input ports.
1717
1818
:param inputs: The mapping of input names to input objects.
@@ -31,7 +31,7 @@ def get_dynamic_input_parameter(inputs: Mapping):
3131
]
3232

3333

34-
def get_dynamic_source_parameter(source):
34+
def get_dynamic_source_parameter(source: Any) -> List:
3535
"""Return the dynamic parameter of the definition's source port.
3636
3737
:param source: The source object.
@@ -49,7 +49,7 @@ def get_dynamic_source_parameter(source):
4949
]
5050

5151

52-
def to_component_func(entity: ComponentEntity, component_creation_func) -> Callable[..., Command]:
52+
def to_component_func(entity: ComponentEntity, component_creation_func: Callable) -> Callable[..., Command]:
5353
"""Convert a ComponentEntity to a callable component function.
5454
5555
:param entity: The ComponentEntity to convert.
@@ -97,6 +97,7 @@ def to_component_func(entity: ComponentEntity, component_creation_func) -> Calla
9797
flattened_group_keys=flattened_group_keys,
9898
)
9999

100-
dynamic_func._func_calling_example = example
101-
dynamic_func._has_parameters = bool(all_params)
100+
# Bug Item number: 2883188
101+
dynamic_func._func_calling_example = example # type: ignore
102+
dynamic_func._has_parameters = bool(all_params) # type: ignore
102103
return dynamic_func

sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_do_while.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
# ---------------------------------------------------------
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# ---------------------------------------------------------
4+
from typing import Dict, Optional, Union
5+
6+
from azure.ai.ml.entities._builders import Command
47
from azure.ai.ml.entities._builders.do_while import DoWhile
8+
from azure.ai.ml.entities._builders.pipeline import Pipeline
9+
from azure.ai.ml.entities._inputs_outputs import Output
510
from azure.ai.ml.entities._job.pipeline._io import NodeOutput
611

712

8-
def do_while(body, mapping, max_iteration_count: int, condition=None):
13+
def do_while(
14+
body: Union[Pipeline, Command], mapping: Dict, max_iteration_count: int, condition: Optional[Output] = None
15+
) -> DoWhile:
916
"""Build a do_while node by specifying the loop body, output-input mapping, and termination condition.
1017
1118
.. remarks::
@@ -63,7 +70,7 @@ def pipeline_with_do_while_node():
6370
)
6471
do_while_node.set_limits(max_iteration_count=max_iteration_count)
6572

66-
def _infer_and_update_body_input_from_mapping():
73+
def _infer_and_update_body_input_from_mapping() -> None:
6774
# pylint: disable=protected-access
6875
for source_output, body_input in mapping.items():
6976
# handle case that mapping key is a NodeOutput

sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_dynamic.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55
import types
66
from inspect import Parameter, Signature
7-
from typing import Callable, Dict, Sequence
7+
from typing import Any, Callable, Dict, Sequence, cast
88

99
from azure.ai.ml.entities import Component
1010
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, UnexpectedKeywordError, ValidationException
@@ -26,7 +26,9 @@ class KwParameter(Parameter):
2626
:type _optional: bool
2727
"""
2828

29-
def __init__(self, name, default, annotation=Parameter.empty, _type="str", _optional=False) -> None:
29+
def __init__(
30+
self, name: str, default: Any, annotation: Any = Parameter.empty, _type: str = "str", _optional: bool = False
31+
) -> None:
3032
super().__init__(name, Parameter.KEYWORD_ONLY, default=default, annotation=annotation)
3133
self._type = _type
3234
self._optional = _optional
@@ -54,23 +56,25 @@ def _replace_function_name(func: types.FunctionType, new_name: str) -> types.Fun
5456
else:
5557
# Before python<3.8, replace is not available, we can only initialize the code as following.
5658
# https://github.com/python/cpython/blob/v3.7.8/Objects/codeobject.c#L97
57-
code = types.CodeType(
59+
60+
# Bug Item number: 2881688
61+
code = types.CodeType( # type: ignore
5862
code_template.co_argcount,
5963
code_template.co_kwonlyargcount,
6064
code_template.co_nlocals,
6165
code_template.co_stacksize,
6266
code_template.co_flags,
63-
code_template.co_code,
64-
code_template.co_consts,
67+
code_template.co_code, # type: ignore
68+
code_template.co_consts, # type: ignore
6569
code_template.co_names,
6670
code_template.co_varnames,
67-
code_template.co_filename,
71+
code_template.co_filename, # type: ignore
6872
new_name, # Use the new name for the new code object.
69-
code_template.co_firstlineno,
70-
code_template.co_lnotab,
73+
code_template.co_firstlineno, # type: ignore
74+
code_template.co_lnotab, # type: ignore
7175
# The following two values are required for closures.
72-
code_template.co_freevars,
73-
code_template.co_cellvars,
76+
code_template.co_freevars, # type: ignore
77+
code_template.co_cellvars, # type: ignore
7478
)
7579
# Initialize a new function with the code object and the new name, see the following ref for more details.
7680
# https://github.com/python/cpython/blob/4901fe274bc82b95dc89bcb3de8802a3dfedab32/Objects/clinic/funcobject.c.h#L30
@@ -89,7 +93,7 @@ def _replace_function_name(func: types.FunctionType, new_name: str) -> types.Fun
8993

9094

9195
# pylint: disable-next=docstring-missing-param
92-
def _assert_arg_valid(kwargs: dict, keys: list, func_name: str):
96+
def _assert_arg_valid(kwargs: dict, keys: list, func_name: str) -> None:
9397
"""Assert the arg keys are all in keys."""
9498
# pylint: disable=protected-access
9599
# validate component input names
@@ -114,7 +118,7 @@ def _assert_arg_valid(kwargs: dict, keys: list, func_name: str):
114118
kwargs[lower2original_parameter_names[key.lower()]] = kwargs.pop(key)
115119

116120

117-
def _update_dct_if_not_exist(dst: Dict, src: Dict):
121+
def _update_dct_if_not_exist(dst: Dict, src: Dict) -> None:
118122
"""Computes the union of `src` and `dst`, in-place within `dst`
119123
120124
If a key exists in `dst` and `src` the value in `dst` is preserved
@@ -162,17 +166,18 @@ def create_kw_function_from_parameters(
162166
)
163167
default_kwargs = {p.name: p.default for p in parameters}
164168

165-
def f(**kwargs):
169+
def f(**kwargs: Any) -> Any:
166170
# We need to make sure all keys of kwargs are valid.
167171
# Merge valid group keys with original keys.
168172
_assert_arg_valid(kwargs, [*list(default_kwargs.keys()), *flattened_group_keys], func_name=func_name)
169173
# We need to put the default args to the kwargs before invoking the original function.
170174
_update_dct_if_not_exist(kwargs, default_kwargs)
171175
return func(**kwargs)
172176

173-
f = _replace_function_name(f, func_name)
177+
f = _replace_function_name(cast(types.FunctionType, f), func_name)
174178
# Set the signature so jupyter notebook could have param hint by calling inspect.signature()
175-
f.__signature__ = Signature(parameters)
179+
# Bug Item number: 2883223
180+
f.__signature__ = Signature(parameters) # type: ignore
176181
# Set doc/name/module to make sure help(f) shows following expected result.
177182
# Expected help(f):
178183
#
@@ -183,5 +188,5 @@ def f(**kwargs):
183188
f.__doc__ = documentation # Set documentation to update FUNC_DOC in help.
184189
# Set module = None to avoid showing the sentence `in module 'azure.ai.ml.component._dynamic' in help.`
185190
# See https://github.com/python/cpython/blob/2145c8c9724287a310bc77a2760d4f1c0ca9eb0c/Lib/pydoc.py#L1757
186-
f.__module__ = None
191+
f.__module__ = None # type: ignore
187192
return f

sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_fl_scatter_gather_node.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
import importlib
66

77
# pylint: disable=protected-access
8-
from typing import Dict, List, Optional, Union
8+
from typing import Any, Dict, List, Optional, Union
99

1010
from azure.ai.ml._utils._experimental import experimental
1111
from azure.ai.ml.entities import CommandComponent, PipelineJob
1212
from azure.ai.ml.entities._assets.federated_learning_silo import FederatedLearningSilo
1313
from azure.ai.ml.entities._builders.fl_scatter_gather import FLScatterGather
1414

1515

16-
def _check_for_import(package_name):
16+
def _check_for_import(package_name: str) -> None:
1717
try:
1818
# pylint: disable=unused-import
1919
importlib.import_module(package_name)
@@ -31,16 +31,16 @@ def fl_scatter_gather(
3131
silo_configs: List[FederatedLearningSilo],
3232
silo_component: Union[PipelineJob, CommandComponent],
3333
aggregation_component: Union[PipelineJob, CommandComponent],
34-
aggregation_compute: str = None,
35-
aggregation_datastore: str = None,
34+
aggregation_compute: Optional[str] = None,
35+
aggregation_datastore: Optional[str] = None,
3636
shared_silo_kwargs: Optional[Dict] = None,
3737
aggregation_kwargs: Optional[Dict] = None,
3838
silo_to_aggregation_argument_map: Optional[Dict] = None,
3939
aggregation_to_silo_argument_map: Optional[Dict] = None,
4040
max_iterations: int = 1,
4141
_create_default_mappings_if_needed: bool = False,
42-
**kwargs,
43-
):
42+
**kwargs: Any,
43+
) -> FLScatterGather:
4444
"""A federated learning scatter-gather subgraph node.
4545
4646
It's assumed that this will be used inside of a `@pipeline`-decorated function in order to create a subgraph which

sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_group_decorator.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# Attribute on customized group class to mark a value type as a group of inputs/outputs.
88
import _thread
99
import functools
10-
from typing import Any, Callable, Dict, List, Type, TypeVar, Union
10+
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
1111

1212
from azure.ai.ml import Input, Output
1313
from azure.ai.ml.constants._component import IOConstants
@@ -145,12 +145,12 @@ def parent_func(params: ParentClass):
145145

146146
def _create_fn(
147147
name: str,
148-
args: List[str],
149-
body: List[str],
148+
args: Union[List, str],
149+
body: Union[List, str],
150150
*,
151-
globals: Dict[str, Any] = None,
152-
locals: Dict[str, Any] = None,
153-
return_type: Type[T2],
151+
globals: Optional[Dict[str, Any]] = None,
152+
locals: Optional[Dict[str, Any]] = None,
153+
return_type: Optional[Type[T2]],
154154
) -> Callable[..., T2]:
155155
"""To generate function in class.
156156
@@ -188,9 +188,10 @@ def _create_fn(
188188
txt = f" def {name}({args}){return_annotation}:\n{body}"
189189
local_vars = ", ".join(locals.keys())
190190
txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"
191-
ns = {}
191+
ns: Dict = {}
192192
exec(txt, globals, ns) # pylint: disable=exec-used # nosec
193-
return ns["__create_fn__"](**locals)
193+
res: Callable = ns["__create_fn__"](**locals)
194+
return res
194195

195196
def _create_init_fn( # pylint: disable=unused-argument
196197
cls: Type[T], fields: Dict[str, Union[Annotation, Input, Output]]
@@ -207,7 +208,7 @@ def _create_init_fn( # pylint: disable=unused-argument
207208

208209
# Reference code link:
209210
# https://github.com/python/cpython/blob/17b16e13bb444001534ed6fccb459084596c8bcf/Lib/dataclasses.py#L523
210-
def _get_data_type_from_annotation(anno: Input):
211+
def _get_data_type_from_annotation(anno: Any) -> Any:
211212
if isinstance(anno, GroupInput):
212213
return anno._group_class
213214
# keep original annotation for Outputs
@@ -220,7 +221,7 @@ def _get_data_type_from_annotation(anno: Input):
220221
# otherwise, keep original annotation
221222
return anno
222223

223-
def _get_default(key):
224+
def _get_default(key: str) -> Any:
224225
# will set None as default value when default not exist so won't need to reorder the init params
225226
val = fields[key]
226227
if hasattr(val, "default"):
@@ -254,20 +255,22 @@ def _create_repr_fn(fields: Dict[str, Union[Annotation, Input, Output]]) -> Call
254255
# https://github.com/python/cpython/blob/17b16e13bb444001534ed6fccb459084596c8bcf/Lib/dataclasses.py#L582
255256
fn = _create_fn(
256257
"__repr__",
257-
("self",),
258+
[
259+
"self",
260+
],
258261
['return self.__class__.__qualname__ + f"(' + ", ".join([f"{k}={{self.{k}!r}}" for k in fields]) + ')"'],
259262
return_type=str,
260263
)
261264

262265
# This function's logic is copied from "recursive_repr" function in
263266
# reprlib module to avoid dependency.
264-
def _recursive_repr(user_function):
267+
def _recursive_repr(user_function: Any) -> Any:
265268
# Decorator to make a repr function return "..." for a recursive
266269
# call.
267270
repr_running = set()
268271

269272
@functools.wraps(user_function)
270-
def wrapper(self):
273+
def wrapper(self: Any) -> Any:
271274
key = id(self), _thread.get_ident()
272275
if key in repr_running:
273276
return "..."
@@ -280,7 +283,8 @@ def wrapper(self):
280283

281284
return wrapper
282285

283-
return _recursive_repr(fn)
286+
res: Callable = _recursive_repr(fn)
287+
return res
284288

285289
def _process_class(cls: Type[T], all_fields: Dict[str, Union[Annotation, Input, Output]]) -> Type[T]:
286290
setattr(cls, "__init__", _create_init_fn(cls, all_fields))

sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_load_import.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44

55
# pylint: disable=protected-access
66

7-
from typing import Callable
7+
from typing import Any, Callable
88

99
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY
1010
from azure.ai.ml.entities._builders import Command
1111
from azure.ai.ml.entities._job.pipeline._component_translatable import ComponentTranslatableMixin
1212

1313

1414
# pylint: disable=unused-argument
15-
def to_component(*, job: ComponentTranslatableMixin, **kwargs) -> Callable[..., Command]:
15+
def to_component(*, job: ComponentTranslatableMixin, **kwargs: Any) -> Callable[..., Command]:
1616
"""Translate a job object to a component function, provided job should be able to translate to a component.
1717
1818
For example:
@@ -41,4 +41,5 @@ def to_component(*, job: ComponentTranslatableMixin, **kwargs) -> Callable[...,
4141

4242
# set default base path as "./". Because if code path is relative path and base path is None, will raise error when
4343
# get arm id of Code
44-
return job._to_component(context={BASE_PATH_CONTEXT_KEY: Path("./")})
44+
res: Callable = job._to_component(context={BASE_PATH_CONTEXT_KEY: Path("./")})
45+
return res

sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_mldesigner/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,19 @@
88
original function/module names the same as before, otherwise mldesigner will be broken by this change.
99
"""
1010

11-
__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
11+
__path__ = __import__("pkgutil").extend_path(__path__, __name__)
1212

13-
from azure.ai.ml.entities._component.component_factory import component_factory
14-
from azure.ai.ml.entities._job.pipeline._load_component import _generate_component_function
15-
from azure.ai.ml.entities._inputs_outputs import _get_param_with_standard_annotation
13+
from azure.ai.ml._internal.entities import InternalComponent
1614
from azure.ai.ml._internal.entities._additional_includes import InternalAdditionalIncludes
1715
from azure.ai.ml._utils._asset_utils import get_ignore_file
1816
from azure.ai.ml._utils.utils import try_enable_internal_components
19-
from azure.ai.ml._internal.entities import InternalComponent
2017
from azure.ai.ml.dsl._condition import condition
2118
from azure.ai.ml.dsl._do_while import do_while
22-
from azure.ai.ml.dsl._parallel_for import parallel_for, ParallelFor
2319
from azure.ai.ml.dsl._group_decorator import group
20+
from azure.ai.ml.dsl._parallel_for import ParallelFor, parallel_for
21+
from azure.ai.ml.entities._component.component_factory import component_factory
22+
from azure.ai.ml.entities._inputs_outputs import _get_param_with_standard_annotation
23+
from azure.ai.ml.entities._job.pipeline._load_component import _generate_component_function
2424

2525
from ._constants import V1_COMPONENT_TO_NODE
2626

sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_overrides_definition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# ---------------------------------------------------------
44

5-
from typing import Mapping
5+
from typing import Mapping, Optional
66

77

88
class OverrideDefinition(dict):
@@ -11,7 +11,7 @@ class OverrideDefinition(dict):
1111

1212
def get_override_definition_from_schema(
1313
schema: str, # pylint: disable=unused-argument
14-
) -> Mapping[str, OverrideDefinition]:
14+
) -> Optional[Mapping[str, OverrideDefinition]]:
1515
"""Ger override definition from a json schema.
1616
1717
:param schema: Json schema of component job.

0 commit comments

Comments
 (0)