Skip to content

Commit 006493a

Browse files
committed
Merge branch 'master' of github.com:flyteorg/flytekit into 6277-map-container
Signed-off-by: machichima <nary12321@gmail.com>
2 parents 4827e55 + 764b36f commit 006493a

File tree

12 files changed

+298
-65
lines changed

12 files changed

+298
-65
lines changed

flytekit/clients/auth/token_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def poll_token_endpoint(
159159
token_endpoint: str,
160160
client_id: str,
161161
audience: typing.Optional[str] = None,
162-
scopes: typing.Optional[str] = None,
162+
scopes: typing.Optional[typing.List[str]] = None,
163163
http_proxy_url: typing.Optional[str] = None,
164164
verify: typing.Optional[typing.Union[bool, str]] = None,
165165
) -> typing.Tuple[str, str, int]:

flytekit/clis/sdk_in_container/run.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -472,10 +472,8 @@ def to_click_option(
472472
python_type=python_type,
473473
is_remote=run_level_params.is_remote,
474474
)
475-
476475
if literal_converter.is_bool() and not default_val:
477476
default_val = False
478-
479477
description_extra = ""
480478
if literal_var.type.simple == SimpleType.STRUCT:
481479
if default_val and not isinstance(default_val, ArtifactQuery):
@@ -517,6 +515,7 @@ def to_click_option(
517515
required=required,
518516
help=literal_var.description + description_extra,
519517
callback=literal_converter.convert,
518+
is_flag=literal_converter.is_bool(),
520519
)
521520

522521

@@ -870,12 +869,17 @@ def invoke(self, ctx: click.Context) -> typing.Any:
870869
run_level_params: RunLevelParams = ctx.obj
871870
r = run_level_params.remote_instance()
872871
entity = self._fetch_entity(ctx)
872+
873+
param_defaults = {param.name: param.default for param in ctx.command.params}
874+
875+
filtered_inputs = {k: param_defaults.get(k, v) if v is None else v for k, v in ctx.params.items()}
876+
873877
run_remote(
874878
r,
875879
entity,
876880
run_level_params.project,
877881
run_level_params.domain,
878-
ctx.params,
882+
filtered_inputs,
879883
run_level_params,
880884
type_hints=entity.python_interface.inputs if entity.python_interface else None,
881885
)

flytekit/core/type_engine.py

Lines changed: 71 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import typing
1717
from abc import ABC, abstractmethod
1818
from collections import OrderedDict
19-
from functools import lru_cache
19+
from functools import lru_cache, reduce
2020
from types import GenericAlias
2121
from typing import Any, Dict, List, NamedTuple, Optional, Type, cast
2222

@@ -1089,56 +1089,78 @@ def assert_type(self, t: Type[enum.Enum], v: T):
10891089
raise TypeTransformerFailedError(f"Value {v} is not in Enum {t}")
10901090

10911091

1092-
def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: typing.Any):
1092+
def _handle_json_schema_property(
1093+
property_key: str,
1094+
property_val: dict,
1095+
) -> typing.Tuple[str, typing.Any]:
1096+
"""
1097+
A helper to handle the properties of a JSON schema and returns their equivalent Flyte attribute name and type.
1098+
"""
1099+
1100+
# Handle Optional[T] or Union[T1, T2, ...] at the top level for proper recursion
1101+
if property_val.get("anyOf"):
1102+
# Sanity check 'anyOf' is not empty
1103+
assert len(property_val["anyOf"]) > 0
1104+
# Check that there are no nested Optional or Union types - no need to support that pattern
1105+
# as it would just add complexity without much benefit
1106+
# A few examples: Optional[Optional[T]] or Union[T1, T2, Union[T3, T4], etc...]
1107+
if any(item.get("anyOf") for item in property_val["anyOf"]):
1108+
raise ValueError(
1109+
f"The property with name {property_key} has a nested Optional or Union type, this is not allowed for dataclass JSON deserialization."
1110+
)
1111+
attr_types = []
1112+
for item in property_val["anyOf"]:
1113+
_, attr_type = _handle_json_schema_property(property_key, item)
1114+
attr_types.append(attr_type)
1115+
1116+
# Gather all the types and return a Union[T1, T2, ...]
1117+
attr_union_type = reduce(lambda x, y: typing.Union[x, y], attr_types)
1118+
return (property_key, attr_union_type) # type: ignore
1119+
1120+
# Handle enum
1121+
if property_val.get("enum"):
1122+
property_type = "enum"
1123+
else:
1124+
property_type = property_val["type"]
1125+
1126+
# Handle list
1127+
if property_type == "array":
1128+
return (property_key, typing.List[_get_element_type(property_val["items"])]) # type: ignore
1129+
# Handle null types (i.e. None)
1130+
elif property_type == "null":
1131+
return (property_key, type(None)) # type: ignore
1132+
# Handle dataclass and dict
1133+
elif property_type == "object":
1134+
# NOTE: No need to handle optional dataclasses here (i.e. checking for property_val.get("anyOf"))
1135+
# those are handled in the top level of the function with recursion.
1136+
if property_val.get("additionalProperties"):
1137+
# For typing.Dict type
1138+
elem_type = _get_element_type(property_val["additionalProperties"])
1139+
return (property_key, typing.Dict[str, elem_type]) # type: ignore
1140+
elif property_val.get("title"):
1141+
# For nested dataclass
1142+
sub_schema_name = property_val["title"]
1143+
return (
1144+
property_key,
1145+
typing.cast(GenericAlias, convert_mashumaro_json_schema_to_python_class(property_val, sub_schema_name)),
1146+
)
1147+
else:
1148+
# For untyped dict
1149+
return (property_key, dict) # type: ignore
1150+
elif property_type == "enum":
1151+
return (property_key, str) # type: ignore
1152+
# Handle None, int, float, bool or str
1153+
else:
1154+
return (property_key, _get_element_type(property_val)) # type: ignore
1155+
1156+
1157+
def generate_attribute_list_from_dataclass_json_mixin(
1158+
schema: dict,
1159+
schema_name: typing.Any,
1160+
):
10931161
attribute_list: typing.List[typing.Tuple[Any, Any]] = []
10941162
for property_key, property_val in schema["properties"].items():
1095-
property_type = ""
1096-
if property_val.get("anyOf"):
1097-
property_type = property_val["anyOf"][0]["type"]
1098-
elif property_val.get("enum"):
1099-
property_type = "enum"
1100-
else:
1101-
property_type = property_val["type"]
1102-
# Handle list
1103-
if property_type == "array":
1104-
attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])])) # type: ignore
1105-
# Handle dataclass and dict
1106-
elif property_type == "object":
1107-
if property_val.get("anyOf"):
1108-
# For optional with dataclass
1109-
sub_schemea = property_val["anyOf"][0]
1110-
sub_schemea_name = sub_schemea["title"]
1111-
attribute_list.append(
1112-
(
1113-
property_key,
1114-
typing.cast(
1115-
GenericAlias, convert_mashumaro_json_schema_to_python_class(sub_schemea, sub_schemea_name)
1116-
),
1117-
)
1118-
)
1119-
elif property_val.get("additionalProperties"):
1120-
# For typing.Dict type
1121-
elem_type = _get_element_type(property_val["additionalProperties"])
1122-
attribute_list.append((property_key, typing.Dict[str, elem_type])) # type: ignore
1123-
elif property_val.get("title"):
1124-
# For nested dataclass
1125-
sub_schemea_name = property_val["title"]
1126-
attribute_list.append(
1127-
(
1128-
property_key,
1129-
typing.cast(
1130-
GenericAlias, convert_mashumaro_json_schema_to_python_class(property_val, sub_schemea_name)
1131-
),
1132-
)
1133-
)
1134-
else:
1135-
# For untyped dict
1136-
attribute_list.append((property_key, dict)) # type: ignore
1137-
elif property_type == "enum":
1138-
attribute_list.append([property_key, str]) # type: ignore
1139-
# Handle int, float, bool or str
1140-
else:
1141-
attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore
1163+
attribute_list.append(_handle_json_schema_property(property_key, property_val))
11421164
return attribute_list
11431165

11441166

flytekit/types/directory/types.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import json
45
import os
56
import pathlib
@@ -22,15 +23,24 @@
2223

2324
from flytekit.core.constants import MESSAGEPACK
2425
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
25-
from flytekit.core.type_engine import AsyncTypeTransformer, TypeEngine, TypeTransformerFailedError, get_batch_size
26+
from flytekit.core.type_engine import (
27+
AsyncTypeTransformer,
28+
TypeEngine,
29+
TypeTransformerFailedError,
30+
get_batch_size,
31+
)
2632
from flytekit.exceptions.user import FlyteAssertion
27-
from flytekit.extras.pydantic_transformer.decorator import model_serializer, model_validator
33+
from flytekit.extras.pydantic_transformer.decorator import (
34+
model_serializer,
35+
model_validator,
36+
)
2837
from flytekit.models import types as _type_models
2938
from flytekit.models.core import types as _core_types
3039
from flytekit.models.core.types import BlobType
3140
from flytekit.models.literals import Binary, Blob, BlobMetadata, Literal, Scalar
3241
from flytekit.models.types import LiteralType
3342
from flytekit.types.file import FileExt, FlyteFile
43+
from flytekit.utils.asyn import loop_manager
3444

3545
T = typing.TypeVar("T")
3646
PathType = typing.Union[str, os.PathLike]
@@ -193,7 +203,19 @@ def __fspath__(self):
193203
This function should be called by os.listdir as well.
194204
"""
195205
if not self._downloaded:
196-
self._downloader()
206+
if isinstance(self._downloader, partial):
207+
underlying_func = self._downloader.func
208+
else:
209+
underlying_func = self._downloader
210+
211+
if asyncio.iscoroutinefunction(underlying_func):
212+
# If the downloader is a coroutine function, we need to run it in the event loop.
213+
# This is required when using the Elastic task config with the start method set to 'spawn'.
214+
# This is possibly due to pickling and unpickling the task function, which may not properly rebind to the synced version.
215+
loop_manager.synced(self._downloader)()
216+
else:
217+
# If the downloader is not a coroutine, we can just call it directly.
218+
self._downloader()
197219
self._downloaded = True
198220
return self.path
199221

plugins/flytekit-onnx-pytorch/dev-requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ humanfriendly==10.0
2222
# via coloredlogs
2323
idna==3.4
2424
# via requests
25-
jinja2==3.1.5
25+
jinja2==3.1.6
2626
# via torch
2727
markupsafe==2.1.3
2828
# via jinja2

plugins/flytekit-onnx-scikitlearn/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
microlib_name = f"flytekitplugins-{PLUGIN_NAME}"
66

7-
plugin_requires = ["flytekit", "skl2onnx>=1.10.3", "networkx<3.2; python_version<'3.9'"]
7+
plugin_requires = ["flytekit", "skl2onnx>=1.10.3,<1.19.0", "networkx<3.2; python_version<'3.9'", "onnx<1.18.0"]
88

99
__version__ = "0.0.0+develop"
1010

plugins/flytekit-spark/setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
microlib_name = f"flytekitplugins-{PLUGIN_NAME}"
66

7-
plugin_requires = ["flytekit>=1.15.1", "pyspark>=3.4.0", "aiohttp", "flyteidl>=1.11.0b1", "pandas"]
7+
# TODO: Add support spark 4.0.0, https://github.com/flyteorg/flyte/issues/6478
8+
plugin_requires = ["flytekit>=1.15.1", "pyspark>=3.4.0,<4.0.0", "aiohttp", "flyteidl>=1.11.0b1", "pandas"]
89

910
__version__ = "0.0.0+develop"
1011

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ dependencies = [
2222
"docstring-parser>=0.9.0",
2323
"flyteidl>=1.15.1",
2424
"fsspec>=2023.3.0",
25-
"gcsfs>=2023.3.0",
25+
"gcsfs>=2023.3.0,!=2025.5.0,!=2025.5.0post1", # Bug in 2025.5.0, 2025.5.0post1 https://github.com/fsspec/gcsfs/issues/687
2626
"googleapis-common-protos>=1.57",
2727
"grpcio",
2828
"grpcio-status",

tests/flytekit/integration/remote/test_remote.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import uuid
1717
import pytest
1818
from unittest import mock
19-
from dataclasses import dataclass
19+
from dataclasses import asdict, dataclass
2020

2121
from flytekit import LaunchPlan, kwtypes, WorkflowExecutionPhase, task, workflow
2222
from flytekit.configuration import Config, ImageConfig, SerializationSettings
@@ -1148,6 +1148,88 @@ def test_execute_workflow_with_dataclass():
11481148
)
11491149
assert out.outputs["o0"] == ""
11501150

1151+
def test_execute_wf_out_dataclass_with_optional():
1152+
"""Test remote execution of a workflow outputting a dataclass where optional fields are present."""
1153+
from tests.flytekit.integration.remote.workflows.basic.dataclass_with_optional_wf import MyDataClassWithOptional, MyParentDataClass
1154+
1155+
remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
1156+
1157+
# Simple case where dataclasses are not nested
1158+
in_dataclass = MyDataClassWithOptional(foo={"a": 1.0, "b": 2.0}, bar={"c": 3.0, "d": 4.0})
1159+
1160+
execution_id = run(
1161+
"dataclass_with_optional_wf.py",
1162+
"wf",
1163+
"--in_dataclass",
1164+
json.dumps(asdict(in_dataclass)),
1165+
)
1166+
1167+
execution = remote.fetch_execution(name=execution_id, project=PROJECT, domain=DOMAIN)
1168+
execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=10))
1169+
1170+
assert asdict(execution.outputs["o0"]) == {
1171+
"foo": {"a": 1.0, "b": 2.0},
1172+
"bar": {"c": 3.0, "d": 4.0},
1173+
"baz": None,
1174+
"qux": None,
1175+
}
1176+
1177+
# Case where dataclasses are nested (all fields are populated)
1178+
in_dataclass = MyParentDataClass(
1179+
child=MyDataClassWithOptional(foo={"a": 1.0, "b": 2.0}, bar={"c": 3.0, "d": 4.0}),
1180+
a={"a": 1.0, "b": 2.0},
1181+
b=MyDataClassWithOptional(foo={"a": 1.0, "b": 2.0}, bar={"c": 3.0, "d": 4.0}),
1182+
)
1183+
1184+
execution_id = run(
1185+
"dataclass_with_optional_wf.py",
1186+
"wf_nested_dc",
1187+
"--in_dataclass",
1188+
json.dumps(asdict(in_dataclass)),
1189+
)
1190+
1191+
execution = remote.fetch_execution(name=execution_id, project=PROJECT, domain=DOMAIN)
1192+
execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=10))
1193+
assert asdict(execution.outputs["o0"]) == {
1194+
"child": {
1195+
"foo": {"a": 1.0, "b": 2.0},
1196+
"bar": {"c": 3.0, "d": 4.0},
1197+
"baz": None,
1198+
"qux": None,
1199+
},
1200+
"a": {"a": 1.0, "b": 2.0},
1201+
"b": {
1202+
"foo": {"a": 1.0, "b": 2.0},
1203+
"bar": {"c": 3.0, "d": 4.0},
1204+
"baz": None,
1205+
"qux": None,
1206+
}
1207+
}
1208+
1209+
# Case where dataclasses are nested (optionals are left as None)
1210+
in_dataclass = MyParentDataClass(
1211+
child=MyDataClassWithOptional(foo={"a": 1.0, "b": 2.0}, bar={"c": 3.0, "d": 4.0}),
1212+
)
1213+
1214+
execution_id = run(
1215+
"dataclass_with_optional_wf.py",
1216+
"wf_nested_dc",
1217+
"--in_dataclass",
1218+
json.dumps(asdict(in_dataclass)),
1219+
)
1220+
1221+
execution = remote.fetch_execution(name=execution_id, project=PROJECT, domain=DOMAIN)
1222+
execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=10))
1223+
assert asdict(execution.outputs["o0"]) == {
1224+
"child": {
1225+
"foo": {"a": 1.0, "b": 2.0},
1226+
"bar": {"c": 3.0, "d": 4.0},
1227+
"baz": None,
1228+
"qux": None,
1229+
},
1230+
"a": None,
1231+
"b": None,
1232+
}
11511233

11521234
def test_register_wf_twice(register):
11531235
# Register the same workflow again should not raise an error

0 commit comments

Comments
 (0)