
Commit 7eedc68

Backport changes for 14.1 (#3006)
Signed-off-by: taieeuu <[email protected]>
Signed-off-by: Thomas Newton <[email protected]>
Signed-off-by: Samhita Alla <[email protected]>
Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: Eduardo Apolinario <[email protected]>
1 parent 03cec6b commit 7eedc68

File tree: 9 files changed (+121 / -41 lines)


.github/workflows/pythonbuild.yml

Lines changed: 3 additions & 0 deletions
@@ -299,6 +299,9 @@ jobs:
           FLYTEKIT_IMAGE: localhost:30000/flytekit:dev
           FLYTEKIT_CI: 1
           PYTEST_OPTS: -n2
+          AWS_ENDPOINT_URL: 'http://localhost:30002'
+          AWS_ACCESS_KEY_ID: minio
+          AWS_SECRET_ACCESS_KEY: miniostorage
         run: |
           make ${{ matrix.makefile-cmd }}
       - name: Codecov
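Note: the three new environment variables appear to point the test run at the MinIO object store exposed by the local sandbox cluster instead of real AWS. A minimal sketch of what that enables, assuming that endpoint and those credentials (the bucket name here is illustrative); flytekit's data layer uses fsspec/s3fs, which accepts the same settings explicitly:

```python
import fsspec

# Same endpoint and credentials the CI job exports above; "my-s3-bucket" is made up.
fs = fsspec.filesystem(
    "s3",
    key="minio",
    secret="miniostorage",
    client_kwargs={"endpoint_url": "http://localhost:30002"},
)
print(fs.ls("my-s3-bucket"))  # lists objects in the local MinIO bucket, if it exists
```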

flytekit/core/data_persistence.py

Lines changed: 13 additions & 26 deletions
@@ -423,47 +423,34 @@ async def async_put_raw_data(
             r = await self._put(from_path, to_path, **kwargs)
             return r or to_path
 
+        # See https://github.com/fsspec/s3fs/issues/871 for more background and pending work on the fsspec side to
+        # support effectively async open(). For now these use-cases below will revert to sync calls.
         # raw bytes
         if isinstance(lpath, bytes):
-            fs = await self.get_async_filesystem_for_path(to_path)
-            if isinstance(fs, AsyncFileSystem):
-                async with fs.open_async(to_path, "wb", **kwargs) as s:
-                    s.write(lpath)
-            else:
-                with fs.open(to_path, "wb", **kwargs) as s:
-                    s.write(lpath)
-
+            fs = self.get_filesystem_for_path(to_path)
+            with fs.open(to_path, "wb", **kwargs) as s:
+                s.write(lpath)
             return to_path
 
         # If lpath is a buffered reader of some kind
         if isinstance(lpath, io.BufferedReader) or isinstance(lpath, io.BytesIO):
             if not lpath.readable():
                 raise FlyteAssertion("Buffered reader must be readable")
-            fs = await self.get_async_filesystem_for_path(to_path)
+            fs = self.get_filesystem_for_path(to_path)
             lpath.seek(0)
-            if isinstance(fs, AsyncFileSystem):
-                async with fs.open_async(to_path, "wb", **kwargs) as s:
-                    while data := lpath.read(read_chunk_size_bytes):
-                        s.write(data)
-            else:
-                with fs.open(to_path, "wb", **kwargs) as s:
-                    while data := lpath.read(read_chunk_size_bytes):
-                        s.write(data)
+            with fs.open(to_path, "wb", **kwargs) as s:
+                while data := lpath.read(read_chunk_size_bytes):
+                    s.write(data)
             return to_path
 
         if isinstance(lpath, io.StringIO):
             if not lpath.readable():
                 raise FlyteAssertion("Buffered reader must be readable")
-            fs = await self.get_async_filesystem_for_path(to_path)
+            fs = self.get_filesystem_for_path(to_path)
            lpath.seek(0)
-            if isinstance(fs, AsyncFileSystem):
-                async with fs.open_async(to_path, "wb", **kwargs) as s:
-                    while data_str := lpath.read(read_chunk_size_bytes):
-                        s.write(data_str.encode(encoding))
-            else:
-                with fs.open(to_path, "wb", **kwargs) as s:
-                    while data_str := lpath.read(read_chunk_size_bytes):
-                        s.write(data_str.encode(encoding))
+            with fs.open(to_path, "wb", **kwargs) as s:
+                while data_str := lpath.read(read_chunk_size_bytes):
+                    s.write(data_str.encode(encoding))
             return to_path
 
         raise FlyteAssertion(f"Unsupported lpath type {type(lpath)}")

flytekit/remote/remote.py

Lines changed: 14 additions & 6 deletions
@@ -524,7 +524,9 @@ def get_launch_plan_from_then_node(
 
             if node.branch_node:
                 get_launch_plan_from_branch(node.branch_node, node_launch_plans)
-        return FlyteWorkflow.promote_from_closure(compiled_wf, node_launch_plans)
+        flyte_workflow = FlyteWorkflow.promote_from_closure(compiled_wf, node_launch_plans)
+        flyte_workflow.template._id = workflow_id
+        return flyte_workflow
 
     def _upgrade_launchplan(self, lp: launch_plan_models.LaunchPlan) -> FlyteLaunchPlan:
         """
@@ -863,13 +865,17 @@ async def _serialize_and_register(
         cp_task_entity_map = OrderedDict(filter(lambda x: isinstance(x[1], task_models.TaskSpec), m.items()))
         tasks = []
         loop = asyncio.get_running_loop()
-        for entity, cp_entity in cp_task_entity_map.items():
+        for task_entity, cp_entity in cp_task_entity_map.items():
             tasks.append(
                 loop.run_in_executor(
                     None,
-                    functools.partial(self.raw_register, cp_entity, serialization_settings, version, og_entity=entity),
+                    functools.partial(
+                        self.raw_register, cp_entity, serialization_settings, version, og_entity=task_entity
+                    ),
                 )
             )
+            if task_entity == entity:
+                registered_entity = await tasks[-1]
 
         identifiers_or_exceptions = []
         identifiers_or_exceptions.extend(await asyncio.gather(*tasks, return_exceptions=True))
@@ -882,15 +888,17 @@
                 raise ie
         # serial register
         cp_other_entities = OrderedDict(filter(lambda x: not isinstance(x[1], task_models.TaskSpec), m.items()))
-        for entity, cp_entity in cp_other_entities.items():
+        for non_task_entity, cp_entity in cp_other_entities.items():
             try:
                 identifiers_or_exceptions.append(
-                    self.raw_register(cp_entity, serialization_settings, version, og_entity=entity)
+                    self.raw_register(cp_entity, serialization_settings, version, og_entity=non_task_entity)
                 )
             except RegistrationSkipped as e:
                 logger.info(f"Skipping registration... {e}")
                 continue
-        return identifiers_or_exceptions[-1]
+            if non_task_entity == entity:
+                registered_entity = identifiers_or_exceptions[-1]
+        return registered_entity
 
     def register_task(
         self,
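Note: the motivation for the bookkeeping above is that `_serialize_and_register` previously returned whichever identifier happened to be registered last, so `register_task` on an entity with `node_dependency_hints` could hand back the wrong object. A hedged sketch of the caller-side pattern this fixes; the task, sandbox config, image, and version string are all illustrative, and actually running it requires a reachable Flyte sandbox:

```python
from flytekit import task
from flytekit.configuration import Config, ImageConfig, SerializationSettings
from flytekit.remote import FlyteRemote, FlyteTask


@task
def my_task() -> int:  # illustrative entity; any @task works
    return 1


remote = FlyteRemote(
    Config.for_sandbox(),
    default_project="flytesnacks",
    default_domain="development",
)
ss = SerializationSettings(
    image_config=ImageConfig.from_images("docker.io/abc:latest"),
    version="dummy_version",
)

# With the change above, the returned object corresponds to my_task itself,
# not simply the last entity registered during the call.
registered = remote.register_task(my_task, ss)
assert isinstance(registered, FlyteTask)
print(registered.id)
```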

plugins/flytekit-dbt/setup.py

Lines changed: 1 addition & 4 deletions
@@ -4,10 +4,7 @@
 
 microlib_name = f"flytekitplugins-{PLUGIN_NAME}"
 
-plugin_requires = [
-    "flytekit>=1.3.0b2",
-    "dbt-core<1.8.0",
-]
+plugin_requires = ["flytekit>=1.3.0b2", "dbt-core>=1.6.0,<1.8.0", "networkx>=2.5"]
 
 __version__ = "0.0.0+develop"
 
plugins/flytekit-inference/flytekitplugins/inference/ollama/serve.py

Lines changed: 1 addition & 1 deletion
@@ -164,7 +164,7 @@ def __init__(self, *args, **kwargs):
                 name=container_name,
                 image="python:3.11-slim",
                 command=["/bin/sh", "-c"],
-                args=[f"pip install requests && pip install ollama && {command}"],
+                args=[f"pip install requests && pip install ollama==0.3.3 && {command}"],
                 resources=V1ResourceRequirements(
                     requests={
                         "cpu": self._model_cpu,
plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py

Lines changed: 2 additions & 1 deletion
@@ -69,9 +69,10 @@ def encode(
         df.to_parquet(output_bytes)
 
         if structured_dataset.uri is not None:
+            output_bytes.seek(0)
             fs = ctx.file_access.get_filesystem_for_path(path=structured_dataset.uri)
             with fs.open(structured_dataset.uri, "wb") as s:
-                s.write(output_bytes)
+                s.write(output_bytes.read())
             output_uri = structured_dataset.uri
         else:
             remote_fn = "00000"  # 00000 is our default unnamed parquet filename

plugins/flytekit-polars/tests/test_polars_plugin_sd.py

Lines changed: 26 additions & 1 deletion
@@ -5,7 +5,7 @@
 import pytest
 from flytekitplugins.polars.sd_transformers import PolarsDataFrameRenderer
 from typing_extensions import Annotated
-from packaging import version
+import numpy as np
 from polars.testing import assert_frame_equal
 
 from flytekit import kwtypes, task, workflow
@@ -134,3 +134,28 @@ def consume_sd_return_sd(sd: StructuredDataset) -> StructuredDataset:
     opened_sd = opened_sd.collect()
 
     assert_frame_equal(opened_sd, polars_df)
+
+
+def test_with_uri():
+    temp_file = tempfile.mktemp()
+
+    @task
+    def random_dataframe(num_rows: int) -> StructuredDataset:
+        feature_1_list = np.random.randint(low=100, high=999, size=(num_rows,))
+        feature_2_list = np.random.normal(loc=0, scale=1, size=(num_rows, ))
+        pl_df = pl.DataFrame({'protein_length': feature_1_list,
+                              'protein_feature': feature_2_list})
+        sd = StructuredDataset(dataframe=pl_df, uri=temp_file)
+        return sd
+
+    @task
+    def consume(df: pd.DataFrame):
+        print(df.head(5))
+        print(df.describe())
+
+    @workflow
+    def my_wf(num_rows: int):
+        pl = random_dataframe(num_rows=num_rows)
+        consume(pl)
+
+    my_wf(num_rows=100)

tests/flytekit/unit/core/test_data_persistence.py

Lines changed: 17 additions & 1 deletion
@@ -1,16 +1,17 @@
 import io
 import os
-import fsspec
 import pathlib
 import random
 import string
 import sys
 import tempfile
 
+import fsspec
 import mock
 import pytest
 from azure.identity import ClientSecretCredential, DefaultAzureCredential
 
+from flytekit.configuration import Config
 from flytekit.core.data_persistence import FileAccessProvider
 from flytekit.core.local_fsspec import FlyteLocalFileSystem
 
@@ -207,3 +208,18 @@ def __init__(self, *args, **kwargs):
 
     fp = FileAccessProvider("/tmp", "s3://my-bucket")
     fp.get_filesystem("testgetfs", test_arg="test_arg")
+
+
+@pytest.mark.sandbox_test
+def test_put_raw_data_bytes():
+    dc = Config.for_sandbox().data_config
+    raw_output = f"s3://my-s3-bucket/"
+    provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc)
+    prefix = provider.get_random_string()
+    provider.put_raw_data(lpath=b"hello", upload_prefix=prefix, file_name="hello_bytes")
+    provider.put_raw_data(lpath=io.BytesIO(b"hello"), upload_prefix=prefix, file_name="hello_bytes_io")
+    provider.put_raw_data(lpath=io.StringIO("hello"), upload_prefix=prefix, file_name="hello_string_io")
+
+    fs = provider.get_filesystem("s3")
+    listing = fs.ls(f"{raw_output}{prefix}/")
+    assert len(listing) == 3

tests/flytekit/unit/remote/test_remote.py

Lines changed: 44 additions & 1 deletion
@@ -35,7 +35,7 @@
 from flytekit.models.core.identifier import Identifier, ResourceType, WorkflowExecutionIdentifier
 from flytekit.models.execution import Execution
 from flytekit.models.task import Task
-from flytekit.remote import FlyteTask
+from flytekit.remote import FlyteTask, FlyteWorkflow
 from flytekit.remote.lazy_entity import LazyEntity
 from flytekit.remote.remote import FlyteRemote, _get_git_repo_url, _get_pickled_target_dict
 from flytekit.tools.translator import Options, get_serializable, get_serializable_launch_plan
@@ -811,3 +811,46 @@ def wf() -> int:
     # the second one should
     rr.register_launch_plan(lp2, version="1", serialization_settings=ss)
     mock_client.update_launch_plan.assert_called()
+
+
+@mock.patch("flytekit.remote.remote.FlyteRemote.client")
+def test_register_task_with_node_dependency_hints(mock_client):
+    @task
+    def task0():
+        return None
+
+    @workflow
+    def workflow0():
+        return task0()
+
+    @dynamic(node_dependency_hints=[workflow0])
+    def dynamic0():
+        return workflow0()
+
+    @workflow
+    def workflow1():
+        return dynamic0()
+
+    rr = FlyteRemote(
+        Config.for_sandbox(),
+        default_project="flytesnacks",
+        default_domain="development",
+    )
+
+    ss = SerializationSettings(
+        image_config=ImageConfig.from_images("docker.io/abc:latest"),
+        version="dummy_version",
+    )
+
+    registered_task = rr.register_task(dynamic0, ss)
+    assert isinstance(registered_task, FlyteTask)
+    assert registered_task.id.resource_type == ResourceType.TASK
+    assert registered_task.id.project == "flytesnacks"
+    assert registered_task.id.domain == "development"
+    # When running via `make unit_test` there is a `__-channelexec__` prefix added to the name.
+    assert registered_task.id.name.endswith("tests.flytekit.unit.remote.test_remote.dynamic0")
+    assert registered_task.id.version == "dummy_version"
+
+    registered_workflow = rr.register_workflow(workflow1, ss)
+    assert isinstance(registered_workflow, FlyteWorkflow)
+    assert registered_workflow.id == Identifier(ResourceType.WORKFLOW, "flytesnacks", "development", "tests.flytekit.unit.remote.test_remote.workflow1", "dummy_version")
