diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index eb8164f2b0..39d0135f44 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -17,6 +17,7 @@ from mashumaro.codecs.json import JSONEncoder from rich.progress import Progress, TextColumn, TimeElapsedColumn from typing_extensions import get_origin +from typing_inspect import is_optional_type from flytekit import Annotations, FlyteContext, FlyteContextManager, Labels, LaunchPlan, Literal, WorkflowExecutionPhase from flytekit.clis.sdk_in_container.helpers import ( @@ -696,9 +697,16 @@ def _run(*args, **kwargs): inputs = {} for input_name, v in entity.python_interface.inputs_with_defaults.items(): processed_click_value = kwargs.get(input_name) + skip_default_value_selection = False + if ( + is_optional_type(v[0]) + and isinstance(processed_click_value, str) + and processed_click_value.lower() == "none" + ): + processed_click_value = None + skip_default_value_selection = True optional_v = False - skip_default_value_selection = False if processed_click_value is None and isinstance(v, typing.Tuple): if entity_type == "workflow" and hasattr(v[0], "__args__"): origin_base_type = get_origin(v[0]) @@ -730,6 +738,8 @@ def _run(*args, **kwargs): inputs[input_name] = processed_click_value if processed_click_value is None and v[0] == bool: inputs[input_name] = False + if processed_click_value is None and is_optional(v[0]): + inputs[input_name] = None if not run_level_params.is_remote: with FlyteContextManager.with_context(_update_flyte_context(run_level_params)): diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py index 056fa2db61..0efc821a75 100644 --- a/flytekit/interaction/click_types.py +++ b/flytekit/interaction/click_types.py @@ -19,6 +19,7 @@ from dataclasses_json import DataClassJsonMixin, dataclass_json from packaging.version import Version from pytimeparse import parse +from typing_inspect import is_optional_type from flytekit import BlobType, FlyteContext, Literal, LiteralType, StructuredDataset from flytekit.core.artifact import ArtifactQuery @@ -572,6 +573,8 @@ def convert( if not self._is_remote: return value + if is_optional_type(self._python_type) and isinstance(value, str) and value.lower() == "none": + value = None lit = TypeEngine.to_literal(self._flyte_ctx, value, self._python_type, self._literal_type) return lit except click.BadParameter: @@ -581,3 +584,10 @@ def convert( f"Failed to convert param: {param if param else 'NA'}, value: {value} to type: {self._python_type}." f" Reason {e}" ) from e + + +def is_optional(_type): + """ + Checks if the given type is Optional Type + """ + return typing.get_origin(_type) is typing.Union and type(None) in typing.get_args(_type)