Skip to content

Commit 57c7c7e

Browse files
Eager rate limit (#3133)
Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
1 parent 1dadf0b commit 57c7c7e

File tree

3 files changed

+114
-2
lines changed

3 files changed

+114
-2
lines changed

flytekit/core/worker_queue.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from flytekit.loggers import developer_logger, logger
2424
from flytekit.models.common import Labels
2525
from flytekit.models.core.execution import WorkflowExecutionPhase
26+
from flytekit.utils.rate_limiter import RateLimiter
2627

2728
if typing.TYPE_CHECKING:
2829
from flytekit.remote.remote_callable import RemoteEntity
@@ -185,6 +186,7 @@ def __init__(self, remote: FlyteRemote, ss: SerializationSettings, tag: str, roo
185186
)
186187
self.__runner_thread.start()
187188
atexit.register(self._close, stopping_condition=self.stopping_condition, runner=self.__runner_thread)
189+
self.rate_limiter = RateLimiter(rpm=60)
188190

189191
# Executions should be tracked in the following way:
190192
# a) you should be able to list by label, all executions generated by the current eager task,
@@ -219,7 +221,7 @@ def reconcile_one(self, update: Update):
219221
try:
220222
item = update.work_item
221223
if item.wf_exec is None:
222-
logger.warning(f"reconcile should launch for {id(item)} entity name: {item.entity.name}")
224+
logger.info(f"reconcile should launch for {id(item)} entity name: {item.entity.name}")
223225
wf_exec = self.launch_execution(update.work_item, update.idx)
224226
update.wf_exec = wf_exec
225227
# Set this to running even if the launched execution was a re-run and already succeeded.
@@ -355,7 +357,7 @@ def get_execution_name(self, entity: RunnableEntity, idx: int, input_kwargs: dic
355357

356358
def launch_execution(self, wi: WorkItem, idx: int) -> FlyteWorkflowExecution:
357359
"""This function launches executions."""
358-
logger.warning(f"Launching execution for {wi.entity.name} {idx=} with {wi.input_kwargs}")
360+
logger.info(f"Launching execution for {wi.entity.name} {idx=} with {wi.input_kwargs}")
359361
if wi.result is None and wi.error is None:
360362
l = self.get_labels()
361363
e = self.get_env()
@@ -370,6 +372,7 @@ def launch_execution(self, wi: WorkItem, idx: int) -> FlyteWorkflowExecution:
370372
assert self.ss.version
371373
version = self.ss.version
372374

375+
self.rate_limiter.sync_acquire()
373376
# todo: if the execution already exists, remote.execute will return that execution. in the future
374377
# we can add input checking to make sure the inputs are indeed a match.
375378
wf_exec = self.remote.execute(

flytekit/utils/rate_limiter.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import asyncio
2+
from collections import deque
3+
from datetime import datetime, timedelta
4+
5+
from flytekit.loggers import developer_logger
6+
from flytekit.utils.asyn import run_sync
7+
8+
9+
class RateLimiter:
10+
"""Rate limiter that allows up to a certain number of requests per minute."""
11+
12+
def __init__(self, rpm: int):
13+
if not isinstance(rpm, int) or rpm <= 0 or rpm > 100:
14+
raise ValueError("Rate must be a positive integer between 1 and 100")
15+
self.rpm = rpm
16+
self.queue = deque()
17+
self.sem = asyncio.Semaphore(rpm)
18+
self.delay = timedelta(seconds=60) # always 60 seconds since this we're using a per-minute rate limiter
19+
20+
def sync_acquire(self):
21+
run_sync(self.acquire)
22+
23+
async def acquire(self):
24+
async with self.sem:
25+
now = datetime.now()
26+
# Start by clearing out old data
27+
while self.queue and (now - self.queue[0]) > self.delay:
28+
self.queue.popleft()
29+
30+
# Now that the queue only has valid entries, we'll need to wait if the queue is full.
31+
if len(self.queue) >= self.rpm:
32+
# Compute necessary delay and sleep that amount
33+
# First pop one off, so another coroutine won't try to base its wait time off the same timestamp. But
34+
# if you pop it off, the next time this code runs it'll think there's enough spots... so add the
35+
# expected time back onto the queue before awaiting. Once you await, you lose the 'thread' and other
36+
# coroutines can run.
37+
# Basically the invariant is: this block of code leaves the number of items in the queue unchanged:
38+
# it'll pop off a timestamp but immediately add one back.
39+
# Because of the semaphore, we don't have to worry about the one we add to the end being referenced
40+
# because there will never be more than RPM-1 other coroutines running at the same time.
41+
earliest = self.queue.popleft()
42+
delay: timedelta = (earliest + self.delay) - now
43+
if delay.total_seconds() > 0:
44+
next_time = earliest + self.delay
45+
self.queue.append(next_time)
46+
developer_logger.debug(
47+
f"Capacity reached - removed time {earliest} and added back {next_time}, sleeping for {delay.total_seconds()}"
48+
)
49+
await asyncio.sleep(delay.total_seconds())
50+
else:
51+
developer_logger.debug(f"No more need to wait, {earliest=} vs {now=}")
52+
self.queue.append(now)
53+
else:
54+
self.queue.append(now)
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import pytest
2+
import sys
3+
import timeit
4+
import asyncio
5+
6+
from datetime import timedelta
7+
from flytekit.utils.rate_limiter import RateLimiter
8+
9+
10+
async def launch_requests(rate_limiter: RateLimiter, total: int):
11+
tasks = [asyncio.create_task(rate_limiter.acquire()) for _ in range(total)]
12+
await asyncio.gather(*tasks)
13+
14+
15+
async def helper_for_async(rpm: int, total: int):
16+
rate_limiter = RateLimiter(rpm=rpm)
17+
rate_limiter.delay = timedelta(seconds=1)
18+
await launch_requests(rate_limiter, total)
19+
20+
21+
def runner_for_async(rpm: int, total: int):
22+
loop = asyncio.get_event_loop()
23+
return loop.run_until_complete(helper_for_async(rpm, total))
24+
25+
26+
@pytest.mark.asyncio
27+
def test_rate_limiter():
28+
elapsed_time = timeit.timeit(lambda: runner_for_async(2, 2), number=1)
29+
elapsed_time_more = timeit.timeit(lambda: runner_for_async(2, 6), number=1)
30+
assert elapsed_time < 0.25
31+
assert round(elapsed_time_more) == 2
32+
33+
34+
async def sync_wrapper(rate_limiter: RateLimiter):
35+
rate_limiter.sync_acquire()
36+
37+
38+
async def helper_for_sync(rpm: int, total: int):
39+
rate_limiter = RateLimiter(rpm=rpm)
40+
rate_limiter.delay = timedelta(seconds=1)
41+
tasks = [asyncio.create_task(sync_wrapper(rate_limiter)) for _ in range(total)]
42+
await asyncio.gather(*tasks)
43+
44+
45+
def runner_for_sync(rpm: int, total: int):
46+
loop = asyncio.get_event_loop()
47+
return loop.run_until_complete(helper_for_sync(rpm, total))
48+
49+
50+
@pytest.mark.asyncio
51+
def test_rate_limiter_s():
52+
elapsed_time = timeit.timeit(lambda: runner_for_sync(2, 2), number=1)
53+
elapsed_time_more = timeit.timeit(lambda: runner_for_sync(2, 6), number=1)
54+
assert elapsed_time < 0.25
55+
assert round(elapsed_time_more) == 2

0 commit comments

Comments
 (0)