Skip to content

Commit ade63ff

Browse files
authored
LambdaExpression processor implemetation (#136)
* LambdaExpression processor implemetation * Add processor to docs structure Signed-off-by: Sasha Meister <[email protected]> * Remove build folder from repository and add to .gitignore * Fixed SSLCertVerificationError during docs building Signed-off-by: Sasha Meister <[email protected]> * Removed SubRegex changes Signed-off-by: Sasha Meister <[email protected]> * Fix docs formatting Signed-off-by: Sasha Meister <[email protected]> * Added optional class exception to docs build Signed-off-by: Sasha Meister <[email protected]> * Changes addressing the reviewer’s comments Signed-off-by: Sasha Meister <[email protected]> * fix pyarrow test issue Signed-off-by: Sasha Meister <[email protected]> --------- Signed-off-by: Sasha Meister <[email protected]> Signed-off-by: Sasha Meister <[email protected]>
1 parent 7af35c1 commit ade63ff

File tree

8 files changed

+379
-21
lines changed

8 files changed

+379
-21
lines changed

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,7 @@ test-venv
2424
__pycache__
2525

2626
# egg-info
27-
sdp.egg-info
27+
sdp.egg-info
28+
29+
# build
30+
build

docs/src/conf.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,12 +184,14 @@ def setup(app):
184184

185185
nitpick_ignore = [
186186
('py:class', 'abc.ABC'),
187-
('py:class', 'sdp.processors.base_processor.DataEntry'),
187+
('py:class', 'optional'),
188+
('py:mod', 'sdp.utils.apply_operators'),
188189
]
189190
# nitpick_ignore_regex = [('py:class', '*')]
190191

191192
#adding this especially for coraal, temporary
192193
linkcheck_ignore = [
193194
r'https://lingtools\.uoregon\.edu/coraal/coraal_download_list\.txt',
195+
r'https://ieeexplore\.ieee\.org/document/1326009'
194196
]
195197
# https://lingtools.uoregon.edu/coraal/coraal_download_list.txt

docs/src/sdp/api.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,9 @@ Data modifications
219219
.. autodata:: sdp.processors.InverseNormalizeText
220220
:annotation:
221221

222+
.. autodata:: sdp.processors.LambdaExpression
223+
:annotation:
224+
222225
Data filtering
223226
''''''''''''''
224227

sdp/processors/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@
106106
SubIfASRSubstitution,
107107
SubMakeLowercase,
108108
SubRegex,
109+
LambdaExpression,
109110
)
110111
from sdp.processors.modify_manifest.data_to_dropbool import (
111112
DropASRError,

sdp/processors/datasets/voxpopuli/normalize_from_non_pc_text.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def restore_pc(orig_words, norm_words):
7777
# separately in normalized form, so just removing the comma here
7878
add_punct = ""
7979
if orig_text[idx_orig][0].isdigit() and not orig_text[idx_orig].isdigit():
80-
number, word = re.split('(\d+)', orig_text[idx_orig])[1:]
80+
number, word = re.split(r'(\d+)', orig_text[idx_orig])[1:]
8181
orig_text[idx_orig] = number
8282
if word in string.punctuation:
8383
add_punct = word
@@ -87,7 +87,7 @@ def restore_pc(orig_words, norm_words):
8787
# another annoying case is if typo ends with number like here "dell'11"
8888
# same logic, but need to go back to the first check, so doing "continue" below
8989
if orig_text[idx_orig][-1].isdigit() and not orig_text[idx_orig].isdigit():
90-
word, number = re.split('(\d+)', orig_text[idx_orig])[:-1]
90+
word, number = re.split(r'(\d+)', orig_text[idx_orig])[:-1]
9191
orig_text[idx_orig] = word
9292
orig_text.insert(idx_orig + 1, number)
9393
continue

sdp/processors/modify_manifest/data_to_data.py

Lines changed: 98 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,8 @@
3333
from sdp.utils.common import ffmpeg_convert
3434
from sdp.utils.edit_spaces import add_start_end_spaces, remove_extra_spaces
3535
from sdp.utils.get_diff import get_diff_with_subs_grouped
36-
from sdp.utils.metrics_computation import (
37-
get_cer,
38-
get_charrate,
39-
get_wer,
40-
get_wmr,
41-
get_wordrate,
42-
)
36+
from sdp.utils.metrics_computation import get_wer
37+
from sdp.utils.apply_operators import evaluate_expression
4338

4439

4540
class GetAudioDuration(BaseParallelProcessor):
@@ -1127,3 +1122,99 @@ def process(self):
11271122
if self.failed_files:
11281123
logger.warning(f"Failed to process {len(self.failed_files)} files.")
11291124
logger.debug(f"Failed files: {self.failed_files}")
1125+
1126+
1127+
class LambdaExpression(BaseParallelProcessor):
1128+
"""
1129+
A dataset processor that evaluates a Python expression on each data entry and either stores
1130+
the result in a new field or uses it as a filtering condition.
1131+
1132+
This processor is useful for dynamic field computation or conditional filtering of entries based
1133+
on configurable expressions. It leverages ``evaluate_expression``, which safely evaluates expressions
1134+
using the abstract syntax tree (AST).
1135+
1136+
Filtering behavior:
1137+
If ``filter=True``, the expression is evaluated for each entry. Only entries for which the expression evaluates to ``True`` are kept; all others are filtered out (removed from the output).
1138+
If ``filter=False``, the result of the expression is stored in the field specified by ``new_field`` for each entry (no filtering occurs).
1139+
1140+
Examples::
1141+
1142+
# Example 1: Filtering entries where the duration is greater than 5.0 seconds
1143+
LambdaExpression(
1144+
new_field="keep", # This field is ignored when filter=True
1145+
expression="entry['duration'] > 5.0",
1146+
lambda_param_name="entry",
1147+
filter=True
1148+
)
1149+
# Only entries with duration > 5.0 will be kept in the output manifest.
1150+
1151+
# Example 2: Adding a new field with the number of words in the text
1152+
LambdaExpression(
1153+
new_field="num_words",
1154+
expression="len(entry['text'].split())",
1155+
lambda_param_name="entry",
1156+
filter=False
1157+
)
1158+
# Each entry will have a new field 'num_words' with the word count of the 'text' field.
1159+
1160+
Supported operations:
1161+
1162+
The expression supports a safe subset of Python operations, including:
1163+
1164+
- Arithmetic: ``+``, ``-``, ``*``, ``/``, ``//``, ``%``, ``**``
1165+
- Comparisons: ``==``, ``!=``, ``<``, ``<=``, ``>``, ``>=``, ``is``, ``is not``
1166+
- Logical: ``and``, ``or``, ``not``
1167+
- Bitwise: ``|``, ``&``, ``^``, ``~``, ``<<``, ``>>``
1168+
- Indexing and slicing: ``entry['key']``, ``entry[0]``, ``entry[1:3]``
1169+
- Conditional (ternary) expressions: ``a if cond else b``
1170+
- List and dict literals: ``[a, b]``, ``{k: v}``
1171+
- Attribute access: ``entry.attr``
1172+
- Function calls (limited): ``max``, ``min``, ``len``, ``sum``, ``abs``, ``sorted``
1173+
1174+
For the full list, see the ``OPERATORS`` and ``SAFE_FUNCTIONS`` in :mod:`sdp.utils.apply_operators`.
1175+
See also: https://docs.python.org/3/library/operator.html
1176+
1177+
Args:
1178+
new_field (str): The name of the field to store the result of the expression (ignored if filter=True).
1179+
expression (str): A Python expression to evaluate. It can reference fields of the data entry
1180+
using the name specified by ``lambda_param_name`` (default: 'entry').
1181+
lambda_param_name (str, optional): The name to refer to the current data entry in the expression.
1182+
Default is "entry".
1183+
filter (bool, optional): If True, the expression result is treated as a condition.
1184+
The entry is kept only if the result is ``True``. Default is ``False``.
1185+
**kwargs: Additional keyword arguments passed to the ``BaseParallelProcessor`` class.
1186+
1187+
Returns:
1188+
str: A line-delimited JSON manifest, where each line is a processed entry.
1189+
The result may contain fewer entries than the input if ``filter=True``.
1190+
"""
1191+
def __init__(
1192+
self,
1193+
new_field: str,
1194+
expression: str,
1195+
lambda_param_name: str = "entry",
1196+
filter: bool = False,
1197+
**kwargs,
1198+
):
1199+
super().__init__(**kwargs)
1200+
self.new_field = new_field
1201+
self.expression = expression
1202+
self.lambda_param_name = lambda_param_name
1203+
self.filter = filter
1204+
1205+
def process_dataset_entry(self, data_entry) -> List[DataEntry]:
1206+
"""
1207+
Process a single data entry by evaluating the expression.
1208+
1209+
If `filter` is True, the entry is only retained if the expression evaluates to True.
1210+
Otherwise, the result is stored in `new_field`.
1211+
"""
1212+
value = evaluate_expression(self.expression, data_entry, self.lambda_param_name)
1213+
if self.filter:
1214+
if value is not True:
1215+
return []
1216+
data_entry[self.new_field] = value
1217+
return [DataEntry(data=data_entry)]
1218+
1219+
def finalize(self, metrics):
1220+
super().finalize(metrics)

sdp/utils/apply_operators.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import operator
16+
import ast
17+
import re
18+
from typing import Any, Dict
19+
20+
"""
21+
This module provides a safe evaluator for simple Python expressions using the abstract syntax tree (AST).
22+
It restricts execution to a subset of safe operations (arithmetic, logical, comparisons, indexing, etc.)
23+
and selected built-in functions (e.g., max, min, len), while preventing arbitrary code execution.
24+
25+
Useful in cases where dynamic expressions need to be evaluated using a provided variable context,
26+
such as configuration systems, data transformation pipelines, or manifest filtering.
27+
28+
Functions:
29+
- evaluate_expression: Safely evaluates a Python expression string using restricted AST operations.
30+
"""
31+
32+
OPERATORS = {
33+
ast.Add: operator.add,
34+
ast.Sub: operator.sub,
35+
ast.Mult: operator.mul,
36+
ast.Div: operator.truediv,
37+
ast.FloorDiv: operator.floordiv,
38+
ast.Mod: operator.mod,
39+
ast.Pow: operator.pow,
40+
ast.BitOr: operator.or_,
41+
ast.BitAnd: operator.and_,
42+
ast.BitXor: operator.xor,
43+
ast.LShift: operator.lshift,
44+
ast.RShift: operator.rshift,
45+
ast.Invert: operator.invert,
46+
ast.USub: operator.neg,
47+
ast.UAdd: operator.pos,
48+
ast.Eq: operator.eq,
49+
ast.NotEq: operator.ne,
50+
ast.Lt: operator.lt,
51+
ast.LtE: operator.le,
52+
ast.Gt: operator.gt,
53+
ast.GtE: operator.ge,
54+
ast.Is: operator.is_,
55+
ast.IsNot: operator.is_not,
56+
ast.And: operator.and_,
57+
ast.Or: operator.or_,
58+
ast.Not: operator.not_,
59+
}
60+
61+
SAFE_FUNCTIONS = {
62+
'max': max,
63+
'min': min,
64+
'len': len,
65+
'sum': sum,
66+
'abs': abs,
67+
'sorted': sorted,
68+
}
69+
70+
71+
def evaluate_expression(expression: str, variables: Dict[str, Any] = None, var_prefix: str = None) -> Any:
72+
"""
73+
Safely evaluates a Python expression string using a restricted set of AST nodes and operators.
74+
75+
Args:
76+
expression (str): The expression to evaluate.
77+
variables (Dict[str, Any], optional): A dictionary of variable names and values to use in evaluation.
78+
var_prefix (str, optional): If specified, this prefix will be removed from variable names
79+
in the expression before evaluation.
80+
81+
Returns:
82+
any: The result of evaluating the expression.
83+
84+
Raises:
85+
ValueError: If the expression contains unsupported operations or names.
86+
"""
87+
if variables is None:
88+
variables = {}
89+
90+
def _eval(node):
91+
match node:
92+
case ast.Expression():
93+
return _eval(node.body)
94+
95+
case ast.BinOp():
96+
left = _eval(node.left)
97+
right = _eval(node.right)
98+
return OPERATORS[type(node.op)](left, right)
99+
100+
case ast.UnaryOp():
101+
operand = _eval(node.operand)
102+
return OPERATORS[type(node.op)](operand)
103+
104+
case ast.Subscript():
105+
value = _eval(node.value)
106+
match node.slice:
107+
case ast.Slice():
108+
start = _eval(node.slice.lower) if node.slice.lower else None
109+
stop = _eval(node.slice.upper) if node.slice.upper else None
110+
step = _eval(node.slice.step) if node.slice.step else None
111+
return value[start:stop:step]
112+
case _:
113+
key = _eval(node.slice)
114+
return value[key]
115+
116+
case ast.Compare():
117+
left = _eval(node.left)
118+
right = _eval(node.comparators[0])
119+
return OPERATORS[type(node.ops[0])](left, right)
120+
121+
case ast.BoolOp():
122+
values = [_eval(v) for v in node.values]
123+
match node.op:
124+
case ast.And():
125+
return all(values)
126+
case ast.Or():
127+
return any(values)
128+
129+
case ast.IfExp():
130+
test = _eval(node.test)
131+
return _eval(node.body) if test else _eval(node.orelse)
132+
133+
case ast.Constant():
134+
return node.value
135+
136+
case ast.Name():
137+
var_name = node.id
138+
if var_name in variables:
139+
return variables[var_name]
140+
elif var_name in {"True", "False"}:
141+
return eval(var_name)
142+
raise ValueError(f"Unsupported name: {var_name}")
143+
144+
case ast.Call():
145+
func_name = node.func.id if isinstance(node.func, ast.Name) else None
146+
if func_name in SAFE_FUNCTIONS:
147+
func = SAFE_FUNCTIONS[func_name]
148+
args = [_eval(arg) for arg in node.args]
149+
return func(*args)
150+
else:
151+
raise ValueError(f"Function {func_name} is not allowed")
152+
153+
case ast.List():
154+
return [_eval(elt) for elt in node.elts]
155+
156+
case ast.Dict():
157+
return {_eval(k): _eval(v) for k, v in zip(node.keys, node.values)}
158+
159+
case ast.Attribute():
160+
value = _eval(node.value)
161+
return getattr(value, node.attr)
162+
163+
case _:
164+
raise ValueError(f"Unsupported node type: {type(node)}")
165+
166+
if var_prefix:
167+
var_prefix += '.'
168+
expression = re.sub(rf'{re.escape(var_prefix)}(\w+)', r'\1', expression)
169+
170+
tree = ast.parse(expression, mode='eval')
171+
return _eval(tree.body)

0 commit comments

Comments
 (0)