Skip to content

Commit 1ea713d

Browse files
committed
owfeatureconstructor: Enable comprehensions and lambdas in expressions
1 parent 9c02c20 commit 1ea713d

File tree

2 files changed

+41
-24
lines changed

2 files changed

+41
-24
lines changed

Orange/widgets/data/owfeatureconstructor.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def data(self, index, role=Qt.DisplayRole):
369369
return super().data(index, role)
370370

371371

372-
def freevars(exp, env):
372+
def freevars(exp: ast.AST, env: List[str]):
373373
"""
374374
Return names of all free variables in a parsed (expression) AST.
375375
@@ -402,11 +402,14 @@ def freevars(exp, env):
402402
elif etype == ast.Lambda:
403403
args = exp.args
404404
assert isinstance(args, ast.arguments)
405-
argnames = [a.arg for a in args.args]
406-
argnames += [args.vararg.arg] if args.vararg else []
407-
argnames += [a.arg for a in args.kwonlyargs] if args.kwonlyargs else []
408-
argnames += [args.kwarg] if args.kwarg else []
409-
return freevars(exp.body, env + argnames)
405+
arg_names = [a.arg for a in chain(args.posonlyargs, args.args)]
406+
arg_names += [args.vararg.arg] if args.vararg else []
407+
arg_names += [a.arg for a in args.kwonlyargs] if args.kwonlyargs else []
408+
arg_names += [args.kwarg.arg] if args.kwarg else []
409+
vars_ = chain.from_iterable(
410+
freevars(e, env) for e in chain(args.defaults, args.kw_defaults)
411+
)
412+
return list(vars_) + freevars(exp.body, env + arg_names)
410413
elif etype == ast.IfExp:
411414
return (freevars(exp.test, env) + freevars(exp.body, env) +
412415
freevars(exp.orelse, env))
@@ -420,7 +423,7 @@ def freevars(exp, env):
420423
vars_ = []
421424
for gen in exp.generators:
422425
target_names = freevars(gen.target, []) # assigned names
423-
vars_iter = freevars(gen.iter, env)
426+
vars_iter = freevars(gen.iter, env + env_ext)
424427
env_ext += target_names
425428
vars_ifs = list(chain(*(freevars(ifexp, env + target_names)
426429
for ifexp in gen.ifs or [])))
@@ -500,7 +503,7 @@ def is_valid_item(self, setting, item, attrs, metas):
500503
for var in metas:
501504
available[sanitized_name(var)] = None
502505

503-
if freevars(exp_ast, available):
506+
if freevars(exp_ast, list(available)):
504507
return False
505508
return True
506509

@@ -944,14 +947,10 @@ def validate_exp(exp):
944947
"""
945948
Validate an `ast.AST` expression.
946949
947-
Only expressions with no list,set,dict,generator comprehensions
948-
are accepted.
949-
950950
Parameters
951951
----------
952952
exp : ast.AST
953953
A parsed abstract syntax tree
954-
955954
"""
956955
# pylint: disable=too-many-branches
957956
if not isinstance(exp, ast.AST):
@@ -966,20 +965,28 @@ def validate_exp(exp):
966965
return all(map(validate_exp, [exp.left, exp.right]))
967966
elif etype == ast.UnaryOp:
968967
return validate_exp(exp.operand)
968+
elif etype == ast.Lambda:
969+
return all(validate_exp(e) for e in exp.args.defaults) and \
970+
all(validate_exp(e) for e in exp.args.kw_defaults) and \
971+
validate_exp(exp.body)
969972
elif etype == ast.IfExp:
970973
return all(map(validate_exp, [exp.test, exp.body, exp.orelse]))
971974
elif etype == ast.Dict:
972975
return all(map(validate_exp, chain(exp.keys, exp.values)))
973976
elif etype == ast.Set:
974977
return all(map(validate_exp, exp.elts))
978+
elif etype in (ast.SetComp, ast.ListComp, ast.GeneratorExp):
979+
return validate_exp(exp.elt) and all(map(validate_exp, exp.generators))
980+
elif etype == ast.DictComp:
981+
return validate_exp(exp.key) and validate_exp(exp.value) and \
982+
all(map(validate_exp, exp.generators))
975983
elif etype == ast.Compare:
976984
return all(map(validate_exp, [exp.left] + exp.comparators))
977985
elif etype == ast.Call:
978986
subexp = chain([exp.func], exp.args or [],
979987
[k.value for k in exp.keywords or []])
980988
return all(map(validate_exp, subexp))
981989
elif etype == ast.Starred:
982-
assert isinstance(exp.ctx, ast.Load)
983990
return validate_exp(exp.value)
984991
elif etype in [ast.Num, ast.Str, ast.Bytes, ast.Ellipsis, ast.NameConstant]:
985992
return True
@@ -990,7 +997,6 @@ def validate_exp(exp):
990997
elif etype == ast.Subscript:
991998
return all(map(validate_exp, [exp.value, exp.slice]))
992999
elif etype in {ast.List, ast.Tuple}:
993-
assert isinstance(exp.ctx, ast.Load)
9941000
return all(map(validate_exp, exp.elts))
9951001
elif etype == ast.Name:
9961002
return True
@@ -1003,6 +1009,9 @@ def validate_exp(exp):
10031009
return validate_exp(exp.value)
10041010
elif etype == ast.keyword:
10051011
return validate_exp(exp.value)
1012+
elif etype == ast.comprehension and not exp.is_async:
1013+
return validate_exp(exp.target) and validate_exp(exp.iter) and \
1014+
all(map(validate_exp, exp.ifs))
10061015
else:
10071016
raise ValueError(exp)
10081017

Orange/widgets/data/tests/test_owfeatureconstructor.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -226,20 +226,34 @@ def freevars_(source, env=None):
226226
self.assertEqual(freevars_("{a, b}"), ["a", "b"])
227227
self.assertEqual(freevars_("0 if abs(a) < 0.1 else b", ["abs"]),
228228
["a", "b"])
229+
self.assertEqual(freevars_("lambda: a", []), ["a"])
230+
self.assertEqual(freevars_("lambda: a", ["a"]), [])
229231
self.assertEqual(freevars_("lambda a: b + 1"), ["b"])
230232
self.assertEqual(freevars_("lambda a: b + 1", ["b"]), [])
231233
self.assertEqual(freevars_("lambda a: a + 1"), [])
232234
self.assertEqual(freevars_("(lambda a: a + 1)(a)"), ["a"])
233235
self.assertEqual(freevars_("lambda a, *arg: arg + (a,)"), [])
234236
self.assertEqual(freevars_("lambda a, *arg, **kwargs: arg + (a,)"), [])
235-
237+
self.assertEqual(freevars_("lambda a: a + c", []), ["c"])
238+
self.assertEqual(freevars_("lambda a: a + c", ["c"]), [])
239+
self.assertEqual(freevars_("lambda a, b=k: a + c", []), ["k", "c"])
240+
self.assertEqual(freevars_("lambda *a, b=k: a + c", []), ["k", "c"])
241+
self.assertEqual(freevars_("lambda a,/, b=k: a + c", []), ["k", "c"])
242+
self.assertEqual(freevars_("lambda a,/, b=k, **kwg: a + c and kwg", []),
243+
["k", "c"])
236244
self.assertEqual(freevars_("[a for a in b]"), ["b"])
245+
self.assertEqual(freevars_("[a for a, k in b]"), ["b"])
246+
self.assertEqual(freevars_("[(a, j) for a in b]"), ["j", "b"])
247+
self.assertEqual(freevars_("[a for k in b for a in k]"), ["b"])
248+
self.assertEqual(freevars_("[a for k in b if k for a in k if a]"),
249+
["b"])
250+
self.assertEqual(freevars_("[a for k in b if kk for a in k if aa]"),
251+
["b", "kk", "aa"])
237252
self.assertEqual(freevars_("[1 + a for c in b if c]"), ["a", "b"])
238253
self.assertEqual(freevars_("{a for _ in [] if b}"), ["a", "b"])
239254
self.assertEqual(freevars_("{a for _ in [] if b}", ["a", "b"]), [])
240255

241256
def test_validate_exp(self):
242-
243257
stmt = ast.parse("1", mode="single")
244258
with self.assertRaises(ValueError):
245259
validate_exp(stmt)
@@ -272,16 +286,10 @@ def validate_(source):
272286
self.assertTrue(validate_("[]"))
273287

274288
with self.assertRaises(ValueError):
275-
validate_("[a for a in s]")
276-
277-
with self.assertRaises(ValueError):
278-
validate_("(a for a in s)")
279-
280-
with self.assertRaises(ValueError):
281-
validate_("{a for a in s}")
289+
validate_("[i async for i in s]")
282290

283291
with self.assertRaises(ValueError):
284-
validate_("{a:1 for a in s}")
292+
validate_("(i async for i in s)")
285293

286294

287295
class FeatureFuncTest(unittest.TestCase):

0 commit comments

Comments
 (0)