Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from mashumaro.codecs.json import JSONEncoder
from rich.progress import Progress, TextColumn, TimeElapsedColumn
from typing_extensions import get_origin
from typing_inspect import is_optional_type

from flytekit import Annotations, FlyteContext, FlyteContextManager, Labels, LaunchPlan, Literal, WorkflowExecutionPhase
from flytekit.clis.sdk_in_container.helpers import (
Expand Down Expand Up @@ -696,9 +697,16 @@ def _run(*args, **kwargs):
inputs = {}
for input_name, v in entity.python_interface.inputs_with_defaults.items():
processed_click_value = kwargs.get(input_name)
skip_default_value_selection = False
if (
is_optional_type(v[0])
and isinstance(processed_click_value, str)
and processed_click_value.lower() == "none"
):
processed_click_value = None
skip_default_value_selection = True
optional_v = False

skip_default_value_selection = False
if processed_click_value is None and isinstance(v, typing.Tuple):
if entity_type == "workflow" and hasattr(v[0], "__args__"):
origin_base_type = get_origin(v[0])
Expand Down Expand Up @@ -730,6 +738,8 @@ def _run(*args, **kwargs):
inputs[input_name] = processed_click_value
if processed_click_value is None and v[0] == bool:
inputs[input_name] = False
if processed_click_value is None and is_optional(v[0]):
inputs[input_name] = None

if not run_level_params.is_remote:
with FlyteContextManager.with_context(_update_flyte_context(run_level_params)):
Expand Down
10 changes: 10 additions & 0 deletions flytekit/interaction/click_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dataclasses_json import DataClassJsonMixin, dataclass_json
from packaging.version import Version
from pytimeparse import parse
from typing_inspect import is_optional_type

from flytekit import BlobType, FlyteContext, Literal, LiteralType, StructuredDataset
from flytekit.core.artifact import ArtifactQuery
Expand Down Expand Up @@ -572,6 +573,8 @@ def convert(
if not self._is_remote:
return value

if is_optional_type(self._python_type) and isinstance(value, str) and value.lower() == "none":
value = None
lit = TypeEngine.to_literal(self._flyte_ctx, value, self._python_type, self._literal_type)
return lit
except click.BadParameter:
Expand All @@ -581,3 +584,10 @@ def convert(
f"Failed to convert param: {param if param else 'NA'}, value: {value} to type: {self._python_type}."
f" Reason {e}"
) from e


def is_optional(_type):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use is_optional_type from typing_inspect.

from typing_inspect import is_optional_type

"""
Checks if the given type is Optional Type
"""
return typing.get_origin(_type) is typing.Union and type(None) in typing.get_args(_type)
Loading