Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 26 additions & 20 deletions Orange/widgets/data/owfeatureconstructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def data(self, index, role=Qt.DisplayRole):
return super().data(index, role)


def freevars(exp, env):
def freevars(exp: ast.AST, env: List[str]):
"""
Return names of all free variables in a parsed (expression) AST.

Expand Down Expand Up @@ -402,11 +402,14 @@ def freevars(exp, env):
elif etype == ast.Lambda:
args = exp.args
assert isinstance(args, ast.arguments)
argnames = [a.arg for a in args.args]
argnames += [args.vararg.arg] if args.vararg else []
argnames += [a.arg for a in args.kwonlyargs] if args.kwonlyargs else []
argnames += [args.kwarg] if args.kwarg else []
return freevars(exp.body, env + argnames)
arg_names = [a.arg for a in chain(args.posonlyargs, args.args)]
arg_names += [args.vararg.arg] if args.vararg else []
arg_names += [a.arg for a in args.kwonlyargs] if args.kwonlyargs else []
arg_names += [args.kwarg.arg] if args.kwarg else []
vars_ = chain.from_iterable(
freevars(e, env) for e in chain(args.defaults, args.kw_defaults)
)
return list(vars_) + freevars(exp.body, env + arg_names)
elif etype == ast.IfExp:
return (freevars(exp.test, env) + freevars(exp.body, env) +
freevars(exp.orelse, env))
Expand All @@ -420,7 +423,7 @@ def freevars(exp, env):
vars_ = []
for gen in exp.generators:
target_names = freevars(gen.target, []) # assigned names
vars_iter = freevars(gen.iter, env)
vars_iter = freevars(gen.iter, env + env_ext)
env_ext += target_names
vars_ifs = list(chain(*(freevars(ifexp, env + target_names)
for ifexp in gen.ifs or [])))
Expand Down Expand Up @@ -500,7 +503,7 @@ def is_valid_item(self, setting, item, attrs, metas):
for var in metas:
available[sanitized_name(var)] = None

if freevars(exp_ast, available):
if freevars(exp_ast, list(available)):
return False
return True

Expand Down Expand Up @@ -944,16 +947,12 @@ def validate_exp(exp):
"""
Validate an `ast.AST` expression.

Only expressions with no list,set,dict,generator comprehensions
are accepted.

Parameters
----------
exp : ast.AST
A parsed abstract syntax tree

"""
# pylint: disable=too-many-branches
# pylint: disable=too-many-branches,too-many-return-statements
if not isinstance(exp, ast.AST):
raise TypeError("exp is not a 'ast.AST' instance")

Expand All @@ -966,20 +965,28 @@ def validate_exp(exp):
return all(map(validate_exp, [exp.left, exp.right]))
elif etype == ast.UnaryOp:
return validate_exp(exp.operand)
elif etype == ast.Lambda:
return all(validate_exp(e) for e in exp.args.defaults) and \
all(validate_exp(e) for e in exp.args.kw_defaults) and \
validate_exp(exp.body)
elif etype == ast.IfExp:
return all(map(validate_exp, [exp.test, exp.body, exp.orelse]))
elif etype == ast.Dict:
return all(map(validate_exp, chain(exp.keys, exp.values)))
elif etype == ast.Set:
return all(map(validate_exp, exp.elts))
elif etype in (ast.SetComp, ast.ListComp, ast.GeneratorExp):
return validate_exp(exp.elt) and all(map(validate_exp, exp.generators))
elif etype == ast.DictComp:
return validate_exp(exp.key) and validate_exp(exp.value) and \
all(map(validate_exp, exp.generators))
elif etype == ast.Compare:
return all(map(validate_exp, [exp.left] + exp.comparators))
elif etype == ast.Call:
subexp = chain([exp.func], exp.args or [],
[k.value for k in exp.keywords or []])
return all(map(validate_exp, subexp))
elif etype == ast.Starred:
assert isinstance(exp.ctx, ast.Load)
return validate_exp(exp.value)
elif etype in [ast.Num, ast.Str, ast.Bytes, ast.Ellipsis, ast.NameConstant]:
return True
Expand All @@ -990,7 +997,6 @@ def validate_exp(exp):
elif etype == ast.Subscript:
return all(map(validate_exp, [exp.value, exp.slice]))
elif etype in {ast.List, ast.Tuple}:
assert isinstance(exp.ctx, ast.Load)
return all(map(validate_exp, exp.elts))
elif etype == ast.Name:
return True
Expand All @@ -1003,6 +1009,9 @@ def validate_exp(exp):
return validate_exp(exp.value)
elif etype == ast.keyword:
return validate_exp(exp.value)
elif etype == ast.comprehension and not exp.is_async:
return validate_exp(exp.target) and validate_exp(exp.iter) and \
all(map(validate_exp, exp.ifs))
else:
raise ValueError(exp)

Expand Down Expand Up @@ -1173,9 +1182,9 @@ def make_lambda(expression, args, env=None):
"bin", "bool", "bytearray", "bytes", "chr", "complex", "dict",
"divmod", "enumerate", "filter", "float", "format", "frozenset",
"getattr", "hasattr", "hash", "hex", "id", "int", "iter", "len",
"list", "map", "memoryview", "next", "object",
"list", "map", "max", "memoryview", "min", "next", "object",
"oct", "ord", "pow", "range", "repr", "reversed", "round",
"set", "slice", "sorted", "str", "tuple", "type",
"set", "slice", "sorted", "str", "sum", "tuple", "type",
"zip"
]

Expand Down Expand Up @@ -1209,9 +1218,6 @@ def make_lambda(expression, args, env=None):
"nanargmin": lambda *args: np.nanargmin(args),
"nanvar": lambda *args: np.nanvar(args),
"mean": lambda *args: np.mean(args),
"min": lambda *args: np.min(args),
"max": lambda *args: np.max(args),
"sum": lambda *args: np.sum(args),
"std": lambda *args: np.std(args),
"median": lambda *args: np.median(args),
"cumsum": lambda *args: np.cumsum(args),
Expand Down
28 changes: 18 additions & 10 deletions Orange/widgets/data/tests/test_owfeatureconstructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,20 +226,34 @@ def freevars_(source, env=None):
self.assertEqual(freevars_("{a, b}"), ["a", "b"])
self.assertEqual(freevars_("0 if abs(a) < 0.1 else b", ["abs"]),
["a", "b"])
self.assertEqual(freevars_("lambda: a", []), ["a"])
self.assertEqual(freevars_("lambda: a", ["a"]), [])
self.assertEqual(freevars_("lambda a: b + 1"), ["b"])
self.assertEqual(freevars_("lambda a: b + 1", ["b"]), [])
self.assertEqual(freevars_("lambda a: a + 1"), [])
self.assertEqual(freevars_("(lambda a: a + 1)(a)"), ["a"])
self.assertEqual(freevars_("lambda a, *arg: arg + (a,)"), [])
self.assertEqual(freevars_("lambda a, *arg, **kwargs: arg + (a,)"), [])

self.assertEqual(freevars_("lambda a: a + c", []), ["c"])
self.assertEqual(freevars_("lambda a: a + c", ["c"]), [])
self.assertEqual(freevars_("lambda a, b=k: a + c", []), ["k", "c"])
self.assertEqual(freevars_("lambda *a, b=k: a + c", []), ["k", "c"])
self.assertEqual(freevars_("lambda a,/, b=k: a + c", []), ["k", "c"])
self.assertEqual(freevars_("lambda a,/, b=k, **kwg: a + c and kwg", []),
["k", "c"])
self.assertEqual(freevars_("[a for a in b]"), ["b"])
self.assertEqual(freevars_("[a for a, k in b]"), ["b"])
self.assertEqual(freevars_("[(a, j) for a in b]"), ["j", "b"])
self.assertEqual(freevars_("[a for k in b for a in k]"), ["b"])
self.assertEqual(freevars_("[a for k in b if k for a in k if a]"),
["b"])
self.assertEqual(freevars_("[a for k in b if kk for a in k if aa]"),
["b", "kk", "aa"])
self.assertEqual(freevars_("[1 + a for c in b if c]"), ["a", "b"])
self.assertEqual(freevars_("{a for _ in [] if b}"), ["a", "b"])
self.assertEqual(freevars_("{a for _ in [] if b}", ["a", "b"]), [])

def test_validate_exp(self):

stmt = ast.parse("1", mode="single")
with self.assertRaises(ValueError):
validate_exp(stmt)
Expand Down Expand Up @@ -272,16 +286,10 @@ def validate_(source):
self.assertTrue(validate_("[]"))

with self.assertRaises(ValueError):
validate_("[a for a in s]")

with self.assertRaises(ValueError):
validate_("(a for a in s)")

with self.assertRaises(ValueError):
validate_("{a for a in s}")
validate_("[i async for i in s]")

with self.assertRaises(ValueError):
validate_("{a:1 for a in s}")
validate_("(i async for i in s)")


class FeatureFuncTest(unittest.TestCase):
Expand Down