Skip to content

Commit a551ba5

Browse files
committed
add more tools
1 parent 4aa1217 commit a551ba5

File tree

12 files changed

+555
-134
lines changed

12 files changed

+555
-134
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ authors = [
88
]
99
requires-python = ">=3.13"
1010
dependencies = [
11+
"beartype>=0.22.9",
12+
"more-itertools>=10.8.0",
1113
"trycast>=1.2.0",
1214
]
1315

src/luxtools/__init__.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,22 @@
1-
from .functional.compose import chain
2-
from .functional.overload import overload
3-
from .functional.partial import partial
4-
from .scientific.error_propagation import get_error
5-
from .scientific.printing import NumericResult
1+
from beartype.claw import (
2+
beartype_this_package,
3+
)
4+
5+
beartype_this_package()
66

7-
__all__ = ["chain", "partial", "overload", "get_error", "NumericResult"]
7+
from .functional import (
8+
cascade_filter,
9+
cascade_filter_safe,
10+
chain,
11+
chain_eager,
12+
first,
13+
first_safe,
14+
fnot,
15+
head,
16+
identity,
17+
isempty,
18+
overload,
19+
partial,
20+
)
21+
from .scientific.error_propagation import get_error
22+
from .scientific.printing import NumericResult
Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,21 @@
1-
from .compose import chain
2-
from .overload import overload
1+
from .basics import first, first_safe, fnot, head, identity, isempty
2+
from .cascade_filter import cascade_filter, cascade_filter_safe
3+
from .compose import chain, chain_eager
4+
from .currying import overload
35
from .partial import partial
46

5-
all = ["chain", "partial", "overload"]
7+
# Public API of luxtools.functional, consumed by `from luxtools.functional import *`.
# NOTE: must be spelled `__all__` — a plain `all` would shadow the builtin and
# would not be honored by star-imports.
__all__ = [
    "first",
    "first_safe",
    "fnot",
    "head",
    "identity",
    "isempty",
    "cascade_filter",
    "cascade_filter_safe",
    "chain",
    "chain_eager",
    "overload",
    "partial",
]

src/luxtools/functional/basics.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from collections.abc import Iterable
2+
from typing import Iterator, TypeVar
3+
4+
from more_itertools import peekable
5+
6+
from .currying import overload
7+
8+
T = TypeVar("T")
9+
10+
11+
def identity(x: T) -> T:
    """Return the argument unchanged (the identity function)."""
    return x
13+
14+
15+
def fnot(x: bool) -> bool:
    """Logical negation as a first-class function (handy for composition)."""
    return not x
17+
18+
19+
def first(items: Iterable[T]) -> T:
    """Return the first element of *items*.

    Raises StopIteration when *items* is empty; see `first_safe` for a
    non-raising variant.
    """
    # next(iter(...)) is the idiomatic spelling of __iter__().__next__().
    return next(iter(items))
21+
22+
23+
def first_safe(items: Iterable[T]) -> T | None:
    """Return the first element of *items*, or None when it is empty.

    Note: an iterable whose first element is literally None is
    indistinguishable from an empty one.
    """
    # next() with a default replaces the try/except StopIteration dance.
    return next(iter(items), None)
28+
29+
30+
head = first  # alias: `head` is the Lisp/Haskell-style name for `first`
31+
32+
33+
@overload()
def isempty(items: Iterator):
    """Return True iff *items* yields no elements.

    NOTE(review): the peek is buffered inside a throwaway `peekable`
    wrapper, so for a plain Iterator the peeked element is consumed from
    the caller's iterator and lost — confirm callers never reuse the
    iterator after this check.
    """
    it = peekable(items)
    try:
        it.peek()
        return False
    except StopIteration:
        return True
41+
42+
43+
@overload()
def isempty(items: Iterable):
    """Return True iff *items* yields no elements (Iterable overload).

    NOTE(review): body is identical to the Iterator overload; since every
    Iterator is also an Iterable, presumably both registrations are needed
    only for the project's `overload` dispatch — confirm and deduplicate
    if the dispatcher allows it.
    """
    it = peekable(items)
    try:
        it.peek()
        return False
    except StopIteration:
        return True
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from collections.abc import Callable, Iterable, Iterator
2+
from itertools import tee
3+
from typing import TypeVar
4+
5+
from .basics import first, fnot
6+
from .compose import chain as fchain
7+
8+
T = TypeVar("T")
9+
10+
11+
def recursive_tee(iterable, n=2):
12+
"""
13+
A deep, recursive version of itertools.tee.
14+
15+
It splits 'iterable' into 'n' independent iterators.
16+
Crucially, if it encounters an item that is itself an Iterator,
17+
it recursively 'tees' that item into 'n' independent copies
18+
before yielding it.
19+
"""
20+
# 1. We must work with an iterator to use tee
21+
it = iter(iterable)
22+
23+
# 2. Define a generator that processes items as they pass through.
24+
# This is the "Lazy Mapper" that handles the splitting logic.
25+
def splitter():
26+
for item in it:
27+
if isinstance(item, Iterator):
28+
# RECURSION: The item is an iterator, so we must split it
29+
# into n copies (one for each branch we are creating)
30+
yield recursive_tee(item, n)
31+
else:
32+
# BASE CASE: The item is simple (int, string, etc).
33+
# Just return n references to the same object.
34+
yield (item,) * n
35+
36+
# 3. Create a master stream that yields tuples of split items:
37+
# e.g. ( (item1_copyA, item1_copyB), (item2_copyA, item2_copyB), ... )
38+
# We tee this master stream so every output branch can access the tuples.
39+
streams_of_tuples = tee(splitter(), n)
40+
41+
# 4. Unzip: Create n generators.
42+
# The i-th generator picks the i-th element from the stream of tuples.
43+
return tuple((t[i] for t in stream) for i, stream in enumerate(streams_of_tuples))
44+
45+
46+
def dropwhile_safe(predicate, iterable):
    """
    Drop leading elements of *iterable* that satisfy *predicate*, safely.

    Unlike itertools.dropwhile, every candidate element is deep-copied with
    `recursive_tee` before the predicate sees it, so the predicate may freely
    consume nested iterators without corrupting the element that is finally
    yielded.
    """
    stream = iter(iterable)

    for candidate in stream:
        # Deep split: `probe` is handed to the predicate, `pristine` is the
        # untouched copy we may yield.
        probe, pristine = recursive_tee(candidate, 2)

        if predicate(probe):
            # Still inside the leading run — drop it and keep scanning.
            continue

        # First survivor: emit its pristine copy, then pass the remainder of
        # the outer stream through unchecked (standard dropwhile semantics).
        yield pristine
        yield from stream
        return
69+
70+
71+
def takefirst(predicate: Callable[[T], bool], items: Iterable[T]):
    """Return the first item of *items* for which *predicate* holds.

    Propagates StopIteration when no item matches.
    """
    # Drop every item the predicate rejects, then take the first survivor.
    survivors = dropwhile_safe(fchain([fnot, predicate]), iter(items))
    return first(iter(survivors))
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from collections.abc import Callable, Iterable
2+
from typing import TypeVar
3+
4+
from .basics import fnot, isempty
5+
from .better_builtins import takefirst
6+
from .compose import chain as fchain
7+
8+
T = TypeVar("T")
9+
10+
11+
class NoItemsMatchedError(ValueError):
    """Raised by `cascade_filter` when no filtered result passes the cascade check."""

    pass
13+
14+
15+
def cascade_filter(
    filter_predicates: Iterable[Callable[[T], bool]],
    items: Iterable[T],
    cascade_predicate: Callable[[Iterable[T]], bool] = isempty,
) -> Iterable[T]:
    """
    Apply a cascade of filters to an iterable of items.

    Each predicate in `filter_predicates` is applied to `items`, producing a
    lazy stream of matching items for that predicate. The first stream that
    is NOT rejected by `cascade_predicate` (default: `isempty`, i.e. the
    first non-empty result) is returned.

    Parameters
    ----------
    filter_predicates : Iterable[Callable[[T], bool]]
        An iterable of predicate functions. Each function takes an item of
        type T and returns True if the item should be kept.
    items : Iterable[T]
        The iterable of items to filter. Materialized once on entry so that
        every predicate sees the full input, even for one-shot iterators.
    cascade_predicate : Callable[[Iterable[T]], bool], optional
        A rejection test applied (negated via `fnot`) to each filtered
        stream; the first stream for which it returns False is returned.
        Defaults to `isempty`.

    Returns
    -------
    Iterable[T]
        The first filtered stream not rejected by `cascade_predicate`.

    Raises
    ------
    NoItemsMatchedError
        If `cascade_predicate` rejects every filtered stream.
    """
    # Materialize both arguments once: `filter_predicates` is needed again
    # for the error message after the lazy cascade has consumed it, and
    # `items` is re-iterated once per predicate — both would silently
    # misbehave for one-shot iterators.
    predicates = list(filter_predicates)
    pool = list(items)

    filtered = map(lambda fc: filter(fc, pool), predicates)

    try:
        # Take the first candidate stream for which `fnot(cascade_predicate(...))`
        # holds, i.e. the first stream the cascade does not reject.
        return takefirst(fchain([fnot, cascade_predicate]), iter(filtered))
    except StopIteration:
        raise NoItemsMatchedError(
            f"No items satisfy the filters. Tried {len(predicates)} filter(s)."
        )
58+
59+
60+
def cascade_filter_safe(
    filter_predicates: Iterable[Callable[[T], bool]],
    items: Iterable[T],
    cascade_predicate: Callable[[Iterable[T]], bool] = isempty,
) -> Iterable[T]:
    """
    Exception-free variant of `cascade_filter`.

    Identical to `cascade_filter`, except that an empty list is returned
    when nothing matches instead of raising NoItemsMatchedError. See
    `cascade_filter` for full documentation.
    """
    try:
        result = cascade_filter(filter_predicates, items, cascade_predicate)
    except NoItemsMatchedError:
        result = []
    return result
72+
73+
74+
if __name__ == "__main__":
    # Ad-hoc demo of the cascade: first matching filter wins.
    items = range(10)

    filters = [
        lambda x: x > 20,  # will fail
        lambda x: x > 5,  # will work
        lambda x: x > 1,  # will work
    ]

    # NOTE(review): defined but never used below — presumably intended to
    # demo the no-match / empty-list path of cascade_filter_safe; confirm
    # or remove.
    filters_bad = [
        lambda x: x > 20,  # will fail
        lambda x: x > 25,
        lambda x: x > 100,
    ]

    # Expected: the second filter is the first non-empty one -> [6, 7, 8, 9]
    print(list(cascade_filter_safe(filters, items)))

    # print(first(items))

0 commit comments

Comments
 (0)