Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ __pycache__
.mypy_cache
.pytest_cache
*.egg-info
.coverage
.coverage
.DS_Store
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ authors = [
]
requires-python = ">=3.13"
dependencies = [
"beartype>=0.22.9",
"more-itertools>=10.8.0",
"trycast>=1.2.0",
]

Expand Down
27 changes: 21 additions & 6 deletions src/luxtools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
from .functional.compose import chain
from .functional.overload import overload
from .functional.partial import partial
from .scientific.error_propagation import get_error
from .scientific.printing import NumericResult
from beartype.claw import (
    beartype_this_package,
)

# Install beartype's import hook so functions in this package get runtime
# type-checking when the package is imported.  Must run before the package's
# own submodule imports below so they are covered by the hook.
beartype_this_package()

__all__ = ["chain", "partial", "overload", "get_error", "NumericResult"]
from .functional import (
    cascade_filter,
    cascade_filter_safe,
    chain,
    chain_eager,
    first,
    first_safe,
    fnot,
    head,
    identity,
    isempty,
    overload,
    partial,
)
from .scientific.error_propagation import get_error
from .scientific.printing import NumericResult

# Explicit public API of the package: everything imported above is
# re-exported here.  (Restores the `__all__` declaration this module
# previously had, extended to the newly exported names.)
__all__ = [
    "cascade_filter",
    "cascade_filter_safe",
    "chain",
    "chain_eager",
    "first",
    "first_safe",
    "fnot",
    "get_error",
    "head",
    "identity",
    "isempty",
    "NumericResult",
    "overload",
    "partial",
]
22 changes: 19 additions & 3 deletions src/luxtools/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
from .compose import chain
from .overload import overload
from .basics import first, first_safe, fnot, head, identity, isempty
from .cascade_filter import cascade_filter, cascade_filter_safe
from .compose import chain, chain_eager
from .currying import overload
from .partial import partial

all = ["chain", "partial", "overload"]
# Public API of luxtools.functional.  Must be spelled `__all__` (dunder):
# a plain `all = [...]` shadows the builtin and has no effect on
# `from luxtools.functional import *`.
__all__ = [
    "first",
    "first_safe",
    "fnot",
    "head",
    "identity",
    "isempty",
    "cascade_filter",
    "cascade_filter_safe",
    "chain",
    "chain_eager",
    "overload",
    "partial",
]
50 changes: 50 additions & 0 deletions src/luxtools/functional/basics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from collections.abc import Iterable
from typing import Iterator, TypeVar

from more_itertools import peekable

from .currying import overload

T = TypeVar("T")


def identity(x: T) -> T:
    """Return *x* unchanged (the identity function)."""
    return x


def fnot(x: bool) -> bool:
    """Return the logical negation of *x* (the truthiness complement)."""
    return not x


def first(items: Iterable[T]) -> T:
    """
    Return the first element of *items*.

    Raises StopIteration when *items* is empty; callers in this package
    (e.g. cascade_filter via takefirst) catch that exception.
    """
    # next(iter(...)) is the idiomatic spelling of __iter__().__next__().
    return next(iter(items))


def first_safe(items: Iterable[T]):
try:
return items.__iter__().__next__()
except StopIteration:
return None


head = first


@overload()
def isempty(items: Iterator):
    """Return True if *items* yields no elements, False otherwise."""
    # Wrap in peekable so we can look at the head without an explicit
    # next() call.  NOTE(review): peeking still pulls the head element out
    # of the underlying iterator into the peekable's buffer, so the
    # caller's original iterator no longer contains it -- confirm callers
    # do not reuse `items` afterwards.
    it = peekable(items)
    try:
        it.peek()
        return False
    except StopIteration:
        return True


@overload()
def isempty(items: Iterable):
    """Return True if *items* yields no elements, False otherwise."""
    # Identical body to the Iterator overload above, registered separately
    # for the project's `overload` dispatcher (see .currying).
    # NOTE(review): Iterator is a subtype of Iterable, so one registration
    # may suffice -- confirm against the dispatcher's matching rules.
    it = peekable(items)
    try:
        it.peek()
        return False
    except StopIteration:
        return True
72 changes: 72 additions & 0 deletions src/luxtools/functional/better_builtins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from collections.abc import Callable, Iterable, Iterator
from itertools import tee
from typing import TypeVar

from .basics import first, fnot
from .compose import chain as fchain

T = TypeVar("T")


def recursive_tee(iterable, n=2):
"""
A deep, recursive version of itertools.tee.

It splits 'iterable' into 'n' independent iterators.
Crucially, if it encounters an item that is itself an Iterator,
it recursively 'tees' that item into 'n' independent copies
before yielding it.
"""
# 1. We must work with an iterator to use tee
it = iter(iterable)

# 2. Define a generator that processes items as they pass through.
# This is the "Lazy Mapper" that handles the splitting logic.
def splitter():
for item in it:
if isinstance(item, Iterator):
# RECURSION: The item is an iterator, so we must split it
# into n copies (one for each branch we are creating)
yield recursive_tee(item, n)
else:
# BASE CASE: The item is simple (int, string, etc).
# Just return n references to the same object.
yield (item,) * n

# 3. Create a master stream that yields tuples of split items:
# e.g. ( (item1_copyA, item1_copyB), (item2_copyA, item2_copyB), ... )
# We tee this master stream so every output branch can access the tuples.
streams_of_tuples = tee(splitter(), n)

# 4. Unzip: Create n generators.
# The i-th generator picks the i-th element from the stream of tuples.
return tuple((t[i] for t in stream) for i, stream in enumerate(streams_of_tuples))


def dropwhile_safe(predicate, iterable):
    """
    A safe, deep-aware implementation of dropwhile.

    Like itertools.dropwhile, yields the elements of *iterable* starting at
    the first one for which *predicate* is falsy.  Unlike the stdlib
    version, each candidate element is deep-copied with recursive_tee
    before the predicate sees it, so a predicate that (partially) consumes
    a nested iterator does not corrupt the value eventually yielded.
    """
    # We iterate over the main sequence
    outer_iter = iter(iterable)

    for item in outer_iter:
        # 1. RECURSIVE SPLIT:
        # Use recursive_tee instead of standard tee.
        # If 'item' is a nested iterator tree, both copies are fully independent.
        check_copy, keep_copy = recursive_tee(item, 2)

        # 2. CHECK:
        # Predicate can consume 'check_copy' arbitrarily (even deeply).
        if predicate(check_copy):
            continue  # Drop and move to next

        # 3. YIELD & RESUME:
        # Yield the pristine 'keep_copy' and the rest of the outer stream.
        # NOTE(review): elements after the first match are yielded as-is
        # (no tee), matching dropwhile semantics.
        yield keep_copy
        yield from outer_iter
        return


def takefirst(predicate: Callable[[T], bool], items: Iterable[T]):
    """
    Return the first element of *items* for which *predicate* is truthy.

    Propagates StopIteration (from `first`) when no element matches.
    """
    # dropwhile_safe drops elements while the composed `fnot`/`predicate`
    # check holds, i.e. while the predicate is falsy, so the surviving
    # stream starts at the first match.  NOTE(review): assumes `fchain`
    # applies `predicate` before `fnot` -- confirm against .compose.chain.
    remaining = dropwhile_safe(fchain([fnot, predicate]), iter(items))
    return first(iter(remaining))
91 changes: 91 additions & 0 deletions src/luxtools/functional/cascade_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from collections.abc import Callable, Iterable
from typing import TypeVar

from .basics import fnot, isempty
from .better_builtins import takefirst
from .compose import chain as fchain

T = TypeVar("T")


class NoItemsMatchedError(ValueError):
    """Raised by cascade_filter when no filtered result satisfies the cascade predicate."""

    pass


def cascade_filter(
    filter_predicates: Iterable[Callable[[T], bool]],
    items: Iterable[T],
    cascade_predicate: Callable[[Iterable[T]], bool] = isempty,
) -> Iterable[T]:
    """
    Apply a cascade of filters to an iterable of items.

    Each predicate in `filter_predicates` is applied to `items`, producing the
    matching items for that predicate. The first result that satisfies
    `cascade_predicate` (default: any items matched) is returned.

    Parameters
    ----------
    filter_predicates : Iterable[Callable[[T], bool]]
        An iterable of predicate functions. Each function takes an item of type T
        and returns True if the item should be kept.
    items : Iterable[T]
        The iterable of items to filter. May be a one-shot iterator; it is
        materialized once so every predicate sees the full sequence.
    cascade_predicate : Callable[[Iterable[T]], bool], optional
        A predicate applied to each filtered result. The first result that makes
        this predicate return True is returned. Defaults to returning the first
        non-empty filtered result.

    Returns
    -------
    Iterable[T]
        The first filtered result that satisfies `cascade_predicate`.

    Raises
    ------
    NoItemsMatchedError
        If none of the filtered results satisfy `cascade_predicate`.
    """
    # Materialize both inputs exactly once.  Without this, a one-shot
    # `items` iterator would be shared -- and mutually exhausted -- by all
    # the lazy `filter` objects below, and a generator of predicates could
    # not be counted for the error message after being consumed.
    predicates = list(filter_predicates)
    pool = list(items)

    filtered = map(lambda fc: filter(fc, pool), predicates)

    try:
        # Return the first filtered stream whose contents satisfy
        # `cascade_predicate` (double negation via fnot: takefirst keeps the
        # first stream for which fnot(cascade_predicate(...)) is truthy).
        return takefirst(fchain([fnot, cascade_predicate]), iter(filtered))
    except StopIteration:
        # `first` inside `takefirst` signals exhaustion with StopIteration;
        # translate it into the package's domain error.
        raise NoItemsMatchedError(
            f"No items satisfy the filters. Tried {len(predicates)} filter(s)."
        ) from None


def cascade_filter_safe(
    filter_predicates: Iterable[Callable[[T], bool]],
    items: Iterable[T],
    cascade_predicate: Callable[[Iterable[T]], bool] = isempty,
) -> Iterable[T]:
    """
    Exception-free variant of `cascade_filter`.

    Behaves exactly like `cascade_filter`, except that when no filtered
    result satisfies `cascade_predicate` it returns an empty list instead
    of raising NoItemsMatchedError.  See `cascade_filter` for details.
    """
    try:
        return cascade_filter(filter_predicates, items, cascade_predicate)
    except NoItemsMatchedError:
        return []


if __name__ == "__main__":
    # Quick demo: the first predicate matches nothing, so the cascade falls
    # through to the first predicate that does produce results.
    demo_items = range(10)

    demo_filters = [
        lambda x: x > 20,  # will fail
        lambda x: x > 5,  # will work
        lambda x: x > 1,  # will work
    ]

    demo_filters_bad = [
        lambda x: x > 20,  # will fail
        lambda x: x > 25,
        lambda x: x > 100,
    ]

    print(list(cascade_filter_safe(demo_filters, demo_items)))
Loading