Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ Changes are grouped as follows
- `Fixed` for any bug fixes.
- `Security` in case of vulnerabilities.

## 7.11.6

### Added
* In the `unstable` package: Add TaskThrottle helper class for limiting concurrent task execution with decorator and context manager support

## 7.11.5

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion cognite/extractorutils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
Cognite extractor utils is a Python package that simplifies the development of new extractors.
"""

__version__ = "7.11.5"
__version__ = "7.11.6"
from .base import Extractor

__all__ = ["Extractor"]
72 changes: 72 additions & 0 deletions cognite/extractorutils/unstable/core/throttle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
Module containing the helper class for Throttling.
"""

from collections.abc import Callable, Generator
from contextlib import contextmanager
from functools import wraps
from threading import Semaphore
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
T = TypeVar("T")


class TaskThrottle:
"""
A throttle to limit the number of concurrent tasks using semaphores.

Usage:
As a decorator:
>>> throttle = TaskThrottle(max_concurrent=5)
>>> @throttle.limit
... def my_task(data):
... # Process data
... pass

As a context manager:
>>> throttle = TaskThrottle(max_concurrent=5)
>>> with throttle.lease():
... # Protected code block
... pass
"""

def __init__(self, max_concurrent: int) -> None:
"""
Create a throttle with specified concurrency limit.

Args:
max_concurrent: Maximum number of tasks that can run concurrently
"""
if max_concurrent < 1:
raise ValueError("max_concurrent must be at least 1")
self._semaphore: Semaphore = Semaphore(max_concurrent)
self._max_concurrent: int = max_concurrent

def limit(self, func: Callable[P, T]) -> Callable[P, T]:
"""
Decorator to throttle a task function.
"""

@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
with self.lease():
return func(*args, **kwargs)

Comment on lines +46 to +55
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In light of the previous comment, I would suggest you only implement this limit functionality:

def limit_concurrency(max_concurrent):
    semaphore = Semaphore(max_concurrent)

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            with semaphore:
                return func(*args, **kwargs)
        return wrapper

    return decorator

Which would allow both single-use:

@limit_concurrency(3)
def foo(data):
    ...

...and shared-pool limit:

throttle = limit_concurrency(5)

@throttle
def foo(user_id):
    ...

@throttle
def bar(entry):
    ...

return wrapper

@contextmanager
def lease(self) -> Generator[None, None, None]:
"""
Context manager that acquires/releases a throttle slot.
"""
self._semaphore.acquire()
try:
yield
finally:
self._semaphore.release()
Comment on lines +58 to +67
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By using the @contextmanager decorator, you can get the effect of a context manager without writing a full class. You, however, have written a full class, then I see no point in not implementing enter and exit dunder methods:

def __enter__(self):
    self._semaphore.acquire()

def __exit__(self, exc_type, exc_val, exc_tb):
    self._semaphore.release()

However, taking a step back, this is exactly the interface the semaphore already provides you, leading me to question why you need this in the first place?

throttle = TaskThrottle(5)
with throttle.lease():
    ...

throttle = Semaphore(5)
with throttle:
    ...


@property
def max_concurrent(self) -> int:
"""Get the configured concurrency limit."""
return self._max_concurrent
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "cognite-extractor-utils"
version = "7.11.5"
version = "7.11.6"
description = "Utilities for easier development of extractors for CDF"
authors = [
{name = "Mathias Lohne", email = "mathias.lohne@cognite.com"}
Expand Down
95 changes: 95 additions & 0 deletions tests/test_unstable/test_throttle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock

import pytest

from cognite.extractorutils.unstable.core.throttle import TaskThrottle


def test_throttle_initialization() -> None:
    """Throttle accepts valid limits and rejects non-positive ones."""

    assert TaskThrottle(max_concurrent=5).max_concurrent == 5

    for bad_limit in (0, -1):
        with pytest.raises(ValueError, match="max_concurrent must be at least 1"):
            TaskThrottle(max_concurrent=bad_limit)


def test_throttle_concurrency_limits() -> None:
    """No more than max_concurrent tasks ever run the throttled section at once."""
    limit = 3
    throttle = TaskThrottle(max_concurrent=limit)

    state_lock = Lock()
    active = 0
    peak = 0

    def task(task_id: int) -> int:
        nonlocal active, peak

        with throttle.lease():
            with state_lock:
                active += 1
                peak = max(peak, active)

            time.sleep(0.1)

            with state_lock:
                active -= 1

        return task_id

    with ThreadPoolExecutor(max_workers=10) as pool:
        futures = [pool.submit(task, i) for i in range(10)]
        results = [future.result() for future in as_completed(futures)]

    assert len(results) == 10
    assert peak <= limit


def test_throttle_serial_execution() -> None:
    """With max_concurrent=1, each task's start/end markers stay adjacent."""
    order_lock = Lock()
    throttle = TaskThrottle(max_concurrent=1)
    events = []

    def serial_task(task_id: int) -> None:
        with throttle.lease():
            with order_lock:
                events.append(task_id)
            time.sleep(0.05)
            with order_lock:
                events.append(task_id)

    with ThreadPoolExecutor(max_workers=3) as pool:
        for future in as_completed([pool.submit(serial_task, i) for i in range(3)]):
            future.result()

    # Pair up consecutive markers: no interleaving means start == end per pair.
    for start_id, end_id in zip(events[::2], events[1::2]):
        assert start_id == end_id


def test_throttle_high_concurrency() -> None:
    """Every task completes when the limit is high relative to the workload."""
    done_lock = Lock()
    throttle = TaskThrottle(max_concurrent=50)
    finished = []

    def fast_task(task_id: int) -> int:
        with throttle.lease():
            time.sleep(0.01)
            with done_lock:
                finished.append(task_id)
            return task_id

    total = 100
    with ThreadPoolExecutor(max_workers=total) as pool:
        for future in as_completed([pool.submit(fast_task, i) for i in range(total)]):
            future.result()

    assert len(finished) == total
Loading