Skip to content

Commit 09b80e3

Browse files
author
Luke Shaw
committed
Added []-indexing to string lazyexprs
1 parent 548395d commit 09b80e3

File tree

2 files changed

+53
-3
lines changed

2 files changed

+53
-3
lines changed

src/blosc2/lazyexpr.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,8 @@ def compute_smaller_slice(larger_shape, smaller_shape, larger_slice):
541541

542542
# Define the patterns for validation
543543
validation_patterns = [
544-
r"[\;\[\:]", # Flow control characters
544+
# r"[\;\[\:]", # Flow control characters
545+
r"[\;]", # Flow control characters
545546
r"(^|[^\w])__[\w]+__($|[^\w])", # Dunder methods
546547
r"\.\b(?!real|imag|(\d*[eE]?[+-]?\d+)|(\d*[eE]?[+-]?\d+j)|\d*j\b|(sum|prod|min|max|std|mean|var|any|all|where)"
547548
r"\s*\([^)]*\)|[a-zA-Z_]\w*\s*\([^)]*\))", # Attribute patterns
@@ -592,7 +593,7 @@ def validate_expr(expr: str) -> None:
592593
raise ValueError(f"'{expr}' is not a valid expression.")
593594

594595
# Check for invalid characters not covered by the tokenizer
595-
invalid_chars = re.compile(r"[^\w\s+\-*/%().,=<>!&|~^]")
596+
invalid_chars = re.compile(r"[^\w\s+\-*/%()[].,=<>!&|~^]")
596597
if invalid_chars.search(skip_quotes) is not None:
597598
invalid_chars = invalid_chars.findall(skip_quotes)
598599
raise ValueError(f"Expression {expr} contains invalid characters: {invalid_chars}")
@@ -744,6 +745,38 @@ def visit_Call(self, node):
744745
return newexpression, newoperands
745746

746747

748+
def convert_to_slice(expression):
749+
"""
750+
Assumes all operands are of the form o...
751+
Parameters
752+
----------
753+
expression: str
754+
755+
Returns
756+
-------
757+
new_expr : str
758+
"""
759+
760+
new_expr = ""
761+
skip_to_char = 0
762+
for i, expr_i in enumerate(expression):
763+
if i < skip_to_char:
764+
continue
765+
if expr_i == "[":
766+
k = expression[i:].find("]") # start checking from after [
767+
slice_convert = expression[i : i + k + 1] # include [ and ]
768+
slicer = eval(f"np.s_{slice_convert}")
769+
slicer = (slicer,) if isinstance(slicer, slice) else slicer # standardise to tuple
770+
if any(isinstance(el, str) for el in slicer): # handle fields
771+
raise ValueError("Cannot handle fields for slicing lazy expressions.")
772+
slicer = str(slicer)
773+
new_expr += ".slice(" + slicer + ")"
774+
skip_to_char = i + k + 1
775+
else:
776+
new_expr += expr_i
777+
return new_expr
778+
779+
747780
class TransformNumpyCalls(ast.NodeTransformer):
748781
def __init__(self):
749782
self.replacements = {}
@@ -2780,6 +2813,7 @@ def save(self, urlpath=None, **kwargs):
27802813
def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=None):
27812814
# Validate the expression
27822815
validate_expr(expression)
2816+
expression = convert_to_slice(expression)
27832817
if guess:
27842818
# The expression has been validated, so we can evaluate it
27852819
# in guessing mode to avoid computing reductions

tests/ndarray/test_reductions.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,9 +430,25 @@ def test_reduction_index():
430430
assert arr.shape == newarr.shape
431431

432432

433-
def test_slice_in_lazy():
433+
def test_slice_lazy():
434434
shape = (20, 20)
435435
a = blosc2.linspace(0, 20, num=np.prod(shape), shape=shape)
436436
arr = blosc2.lazyexpr("anarr.slice(slice(10,15)) + 1", {"anarr": a})
437437
newarr = arr.compute()
438438
np.testing.assert_allclose(newarr[:], a.slice(slice(10, 15))[:] + 1)
439+
440+
441+
def test_slicebrackets_lazy():
442+
shape = (20, 20)
443+
a = blosc2.linspace(0, 20, num=np.prod(shape), shape=shape)
444+
arr = blosc2.lazyexpr("anarr[10:15] + 1", {"anarr": a})
445+
newarr = arr.compute()
446+
np.testing.assert_allclose(newarr[:], a[10:15] + 1)
447+
448+
arr = blosc2.lazyexpr("anarr[10:15, 2:9] + 1", {"anarr": a})
449+
newarr = arr.compute()
450+
np.testing.assert_allclose(newarr[:], a[10:15, 2:9] + 1)
451+
452+
arr = blosc2.lazyexpr("anarr[10:15][2:9] + 1", {"anarr": a})
453+
newarr = arr.compute()
454+
np.testing.assert_allclose(newarr[:], a[10:15][2:9] + 1)

0 commit comments

Comments
 (0)