Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@
from flytekit.image_spec import ImageSpec
from flytekit.loggers import LOGGING_RICH_FMT_ENV_VAR, logger
from flytekit.models.common import Annotations, AuthRole, Labels
from flytekit.models.concurrency import ConcurrencyLimitBehavior, ConcurrencyPolicy
from flytekit.models.core.execution import WorkflowExecutionPhase
from flytekit.models.core.types import BlobType
from flytekit.models.documentation import Description, Documentation, SourceCode
Expand Down
16 changes: 14 additions & 2 deletions flytekit/core/launch_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from flytekit.models import literals as _literal_models
from flytekit.models import schedule as _schedule_model
from flytekit.models import security
from flytekit.models.concurrency import ConcurrencyPolicy
from flytekit.models.core import workflow as _workflow_model


Expand Down Expand Up @@ -114,6 +115,7 @@ def create(
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
auto_activate: bool = False,
concurrency: Optional[ConcurrencyPolicy] = None,
) -> LaunchPlan:
ctx = FlyteContextManager.current_context()
default_inputs = default_inputs or {}
Expand Down Expand Up @@ -167,6 +169,7 @@ def create(
trigger=trigger,
overwrite_cache=overwrite_cache,
auto_activate=auto_activate,
concurrency=concurrency,
)

# This is just a convenience - we'll need the fixed inputs LiteralMap for when serializing the Launch Plan out
Expand Down Expand Up @@ -198,6 +201,7 @@ def get_or_create(
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
auto_activate: bool = False,
concurrency: Optional[ConcurrencyPolicy] = None,
) -> LaunchPlan:
"""
This function offers a friendlier interface for creating launch plans. If the name for the launch plan is not
Expand Down Expand Up @@ -225,8 +229,8 @@ def get_or_create(
parallelism/concurrency of MapTasks is independent from this.
:param trigger: [alpha] This is a new syntax for specifying schedules.
:param overwrite_cache: If set to True, the execution will always overwrite cache
:param auto_activate: If set to True, the launch plan will be activated automatically on registration.
Default is False.
:param auto_activate: If set to True, the launch plan will be activated automatically on registration. Default is False.
:param concurrency: Defines execution concurrency limits and policy when limit is reached
"""
if name is None and (
default_inputs is not None
Expand Down Expand Up @@ -279,6 +283,7 @@ def get_or_create(
("security_context", security_context, cached_outputs["_security_context"]),
("overwrite_cache", overwrite_cache, cached_outputs["_overwrite_cache"]),
("auto_activate", auto_activate, cached_outputs["_auto_activate"]),
("concurrency", concurrency, cached_outputs["_concurrency"]),
]:
if new != cached:
raise AssertionError(
Expand Down Expand Up @@ -311,6 +316,7 @@ def get_or_create(
trigger=trigger,
overwrite_cache=overwrite_cache,
auto_activate=auto_activate,
concurrency=concurrency,
)
LaunchPlan.CACHE[name or workflow.name] = lp
return lp
Expand All @@ -331,6 +337,7 @@ def __init__(
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
auto_activate: bool = False,
concurrency: Optional[ConcurrencyPolicy] = None,
):
self._name = name
self._workflow = workflow
Expand All @@ -351,6 +358,7 @@ def __init__(
self._trigger = trigger
self._overwrite_cache = overwrite_cache
self._auto_activate = auto_activate
self._concurrency = concurrency

FlyteEntities.entities.append(self)

Expand Down Expand Up @@ -455,6 +463,10 @@ def security_context(self) -> Optional[security.SecurityContext]:
def trigger(self) -> Optional[LaunchPlanTriggerBase]:
return self._trigger

@property
def concurrency(self) -> Optional[ConcurrencyPolicy]:
return self._concurrency

@property
def should_auto_activate(self) -> bool:
return self._auto_activate
Expand Down
62 changes: 62 additions & 0 deletions flytekit/models/concurrency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import typing

from flyteidl.admin import launch_plan_pb2 as _launch_plan_idl

from flytekit.models import common as _common


class ConcurrencyLimitBehavior(object):
    """
    Wrapper holding the concurrency-limit-behavior enum values defined in flyteidl.
    """

    SKIP = _launch_plan_idl.CONCURRENCY_LIMIT_BEHAVIOR_SKIP

    @classmethod
    def enum_to_string(cls, val):
        """
        Return the symbolic name for a concurrency-limit-behavior enum value.

        :param int val: raw enum value from the IDL
        :rtype: Text
        """
        # Unrecognized values fall back to a placeholder rather than raising.
        return "SKIP" if val == cls.SKIP else "<UNKNOWN>"


class ConcurrencyPolicy(_common.FlyteIdlEntity):
    """
    Defines the concurrency policy for a launch plan.

    A policy consists of the maximum number of executions of the launch plan that
    may run concurrently, plus the behavior applied to new executions once that
    limit has been reached.
    """

    def __init__(self, max_concurrency: int, behavior: typing.Optional[int] = None):
        """
        :param max_concurrency: Maximum number of concurrent workflow executions allowed.
        :param behavior: One of the ``ConcurrencyLimitBehavior`` constants (these are raw
            IDL enum values, i.e. ints). Defaults to ``ConcurrencyLimitBehavior.SKIP``
            when not provided.
        """
        self._max_concurrency = max_concurrency
        # Default to SKIP so a policy is always fully specified when serialized.
        self._behavior = behavior if behavior is not None else ConcurrencyLimitBehavior.SKIP

    @property
    def max_concurrency(self) -> int:
        """
        Maximum number of concurrent workflows allowed.
        """
        return self._max_concurrency

    @property
    def behavior(self) -> int:
        """
        Policy behavior (a ``ConcurrencyLimitBehavior`` constant) applied when the
        concurrency limit is reached.
        """
        return self._behavior

    def to_flyte_idl(self) -> _launch_plan_idl.ConcurrencyPolicy:
        """
        :rtype: flyteidl.admin.launch_plan_pb2.ConcurrencyPolicy
        """
        return _launch_plan_idl.ConcurrencyPolicy(
            max=self.max_concurrency,
            behavior=self.behavior,
        )

    @classmethod
    def from_flyte_idl(cls, pb2_object: _launch_plan_idl.ConcurrencyPolicy) -> "ConcurrencyPolicy":
        """
        :param flyteidl.admin.launch_plan_pb2.ConcurrencyPolicy pb2_object:
        :rtype: ConcurrencyPolicy
        """
        return cls(
            max_concurrency=pb2_object.max,
            behavior=pb2_object.behavior,
        )
17 changes: 17 additions & 0 deletions flytekit/models/launch_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from flytekit.models import literals as _literals
from flytekit.models import schedule as _schedule
from flytekit.models import security
from flytekit.models.concurrency import ConcurrencyPolicy
from flytekit.models.core import identifier as _identifier


Expand Down Expand Up @@ -138,6 +139,7 @@ def __init__(
max_parallelism: typing.Optional[int] = None,
security_context: typing.Optional[security.SecurityContext] = None,
overwrite_cache: typing.Optional[bool] = None,
concurrency_policy: typing.Optional[ConcurrencyPolicy] = None,
):
"""
The spec for a Launch Plan.
Expand All @@ -158,6 +160,8 @@ def __init__(
parallelism/concurrency of MapTasks is independent from this.
:param security_context: This can be used to add security information to a LaunchPlan, which will be used by
every execution
:param flytekit.models.concurrency.ConcurrencyPolicy concurrency_policy:
Concurrency settings to control the number of concurrent workflows in a given LaunchPlan
"""
self._workflow_id = workflow_id
self._entity_metadata = entity_metadata
Expand All @@ -170,6 +174,7 @@ def __init__(
self._max_parallelism = max_parallelism
self._security_context = security_context
self._overwrite_cache = overwrite_cache
self._concurrency_policy = concurrency_policy

@property
def workflow_id(self):
Expand Down Expand Up @@ -246,6 +251,14 @@ def security_context(self) -> typing.Optional[security.SecurityContext]:
def overwrite_cache(self) -> typing.Optional[bool]:
return self._overwrite_cache

@property
def concurrency_policy(self) -> typing.Optional[ConcurrencyPolicy]:
"""
Concurrency settings for the launch plan.
:rtype: flytekit.models.concurrency.ConcurrencyPolicy
"""
return self._concurrency_policy

def to_flyte_idl(self):
"""
:rtype: flyteidl.admin.launch_plan_pb2.LaunchPlanSpec
Expand All @@ -262,6 +275,7 @@ def to_flyte_idl(self):
max_parallelism=self.max_parallelism,
security_context=self.security_context.to_flyte_idl() if self.security_context else None,
overwrite_cache=self.overwrite_cache if self.overwrite_cache else None,
concurrency_policy=self.concurrency_policy.to_flyte_idl() if self.concurrency_policy else None,
)

@classmethod
Expand Down Expand Up @@ -295,6 +309,9 @@ def from_flyte_idl(cls, pb2):
if pb2.security_context
else None,
overwrite_cache=pb2.overwrite_cache if pb2.overwrite_cache else None,
concurrency_policy=ConcurrencyPolicy.from_flyte_idl(pb2.concurrency_policy)
if pb2.HasField("concurrency_policy")
else None,
)


Expand Down
8 changes: 8 additions & 0 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from flytekit.models import launch_plan as _launch_plan_models
from flytekit.models.admin import workflow as admin_workflow_models
from flytekit.models.admin.workflow import WorkflowSpec
from flytekit.models.concurrency import ConcurrencyPolicy
from flytekit.models.core import identifier as _identifier_model
from flytekit.models.core import workflow as _core_wf
from flytekit.models.core import workflow as workflow_model
Expand Down Expand Up @@ -358,6 +359,12 @@ def get_serializable_launch_plan(
else:
lc = None

concurrency_policy = None
if entity.concurrency is not None:
concurrency_policy = ConcurrencyPolicy(
max_concurrency=entity.concurrency.max_concurrency, behavior=entity.concurrency.behavior
)

lps = _launch_plan_models.LaunchPlanSpec(
workflow_id=wf_id,
entity_metadata=_launch_plan_models.LaunchPlanMetadata(
Expand All @@ -374,6 +381,7 @@ def get_serializable_launch_plan(
max_parallelism=options.max_parallelism or entity.max_parallelism,
security_context=options.security_context or entity.security_context,
overwrite_cache=options.overwrite_cache or entity.overwrite_cache,
concurrency_policy=concurrency_policy,
)

lp_id = _identifier_model.Identifier(
Expand Down
52 changes: 52 additions & 0 deletions tests/flytekit/unit/core/test_launch_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from flytekit.models.core import execution as _execution_model
from flytekit.models.core import identifier as identifier_models
from flytekit.tools.translator import get_serializable
from flytekit.models.concurrency import ConcurrencyPolicy, ConcurrencyLimitBehavior

default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = flytekit.configuration.SerializationSettings(
Expand Down Expand Up @@ -472,3 +473,54 @@ def wf_with_default_options(a: int) -> int:
assert lp.labels.values["label"] == "foo"
assert len(lp.annotations.values) == 1
assert lp.annotations.values["anno"] == "bar"


def test_lp_with_concurrency():
    @task
    def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
        incremented = a + 2
        return incremented, "world-" + str(incremented)

    @workflow
    def wf(a: int, c: str) -> (int, str):
        out_int, out_str = t1(a=a)
        return out_int, out_str

    # A launch plan created with a concurrency policy should expose it verbatim.
    policy = ConcurrencyPolicy(max_concurrency=1, behavior=ConcurrencyLimitBehavior.SKIP)

    lp = launch_plan.LaunchPlan.get_or_create(
        workflow=wf,
        name="concurrency_test_lp",
        default_inputs={"a": 3},
        fixed_inputs={"c": "4"},
        concurrency=policy,
    )

    assert lp.concurrency is not None
    assert lp.concurrency.max_concurrency == 1
    assert lp.concurrency.behavior == ConcurrencyLimitBehavior.SKIP

    # Re-creating with identical arguments must hit the cache and hand back the
    # very same launch plan object.
    lp2 = launch_plan.LaunchPlan.get_or_create(
        workflow=wf,
        name="concurrency_test_lp",
        default_inputs={"a": 3},
        fixed_inputs={"c": "4"},
        concurrency=policy,
    )
    assert lp2 is lp

    # Reusing the cached name with a conflicting concurrency policy is rejected.
    conflicting_policy = ConcurrencyPolicy(max_concurrency=2, behavior=ConcurrencyLimitBehavior.SKIP)
    with pytest.raises(AssertionError):
        launch_plan.LaunchPlan.get_or_create(
            workflow=wf,
            name="concurrency_test_lp",
            default_inputs={"a": 3},
            fixed_inputs={"c": "4"},
            concurrency=conflicting_policy,
        )
41 changes: 41 additions & 0 deletions tests/flytekit/unit/models/test_concurrency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from flyteidl.admin import launch_plan_pb2 as _launch_plan_idl

from flytekit.models.concurrency import ConcurrencyLimitBehavior, ConcurrencyPolicy


def test_concurrency_limit_behavior():
    # The wrapper constant must mirror the raw IDL enum value exactly.
    assert ConcurrencyLimitBehavior.SKIP == _launch_plan_idl.CONCURRENCY_LIMIT_BEHAVIOR_SKIP

    # Known values map to their symbolic names; anything else yields a placeholder.
    assert "SKIP" == ConcurrencyLimitBehavior.enum_to_string(ConcurrencyLimitBehavior.SKIP)
    assert "<UNKNOWN>" == ConcurrencyLimitBehavior.enum_to_string(999)


def test_concurrency_policy_serialization():
    policy = ConcurrencyPolicy(max_concurrency=1, behavior=ConcurrencyLimitBehavior.SKIP)
    assert policy.max_concurrency == 1
    assert policy.behavior == ConcurrencyLimitBehavior.SKIP

    # Serialize to protobuf and check every field landed where expected.
    pb = policy.to_flyte_idl()
    assert isinstance(pb, _launch_plan_idl.ConcurrencyPolicy)
    assert pb.max == 1
    assert pb.behavior == _launch_plan_idl.CONCURRENCY_LIMIT_BEHAVIOR_SKIP

    # Deserializing the same message must reconstruct an equivalent policy.
    restored = ConcurrencyPolicy.from_flyte_idl(pb)
    assert restored.max_concurrency == 1
    assert restored.behavior == ConcurrencyLimitBehavior.SKIP


def test_concurrency_policy_with_different_max():
    # A non-default concurrency limit must survive the IDL round trip unchanged.
    original = ConcurrencyPolicy(max_concurrency=5, behavior=ConcurrencyLimitBehavior.SKIP)
    assert original.max_concurrency == 5

    serialized = original.to_flyte_idl()
    assert serialized.max == 5

    roundtripped = ConcurrencyPolicy.from_flyte_idl(serialized)
    assert roundtripped.max_concurrency == 5
Loading
Loading