diff --git a/jmespath/visitor.py b/jmespath/visitor.py index b3e846b7..7db2b3b0 100644 --- a/jmespath/visitor.py +++ b/jmespath/visitor.py @@ -56,6 +56,10 @@ def _is_actual_number(x): return isinstance(x, Number) +def _is_sequence(x): + return isinstance(x, (list, tuple)) + + class Options(object): """Options to control how a JMESPath function is evaluated.""" def __init__(self, dict_cls=None, custom_functions=None): @@ -172,7 +176,7 @@ def visit_function_expression(self, node, value): def visit_filter_projection(self, node, value): base = self.visit(node['children'][0], value) - if not isinstance(base, list): + if not _is_sequence(base): return None comparator_node = node['children'][2] collected = [] @@ -181,20 +185,20 @@ def visit_filter_projection(self, node, value): current = self.visit(node['children'][1], element) if current is not None: collected.append(current) - return collected + return type(base)(collected) def visit_flatten(self, node, value): base = self.visit(node['children'][0], value) - if not isinstance(base, list): - # Can't flatten the object if it's not a list. + if not _is_sequence(base): + # Can't flatten the object if it's not a supported sequence type. return None merged_list = [] for element in base: - if isinstance(element, list): + if _is_sequence(element): merged_list.extend(element) else: merged_list.append(element) - return merged_list + return type(base)(merged_list) def visit_identity(self, node, value): return value @@ -202,7 +206,7 @@ def visit_identity(self, node, value): def visit_index(self, node, value): # Even though we can index strings, we don't # want to support that. - if not isinstance(value, list): + if not _is_sequence(value): return None try: return value[node['value']] @@ -216,7 +220,7 @@ def visit_index_expression(self, node, value): return result def visit_slice(self, node, value): - if not isinstance(value, list): + if not _is_sequence(value): return None s = slice(*node['children']) return value[s] @@ -271,14 +275,14 @@ def visit_pipe(self, node, value): def visit_projection(self, node, value): base = self.visit(node['children'][0], value) - if not isinstance(base, list): + if not _is_sequence(base): return None collected = [] for element in base: current = self.visit(node['children'][1], element) if current is not None: collected.append(current) - return collected + return type(base)(collected) def visit_value_projection(self, node, value): base = self.visit(node['children'][0], value) diff --git a/tests/test_parser.py b/tests/test_parser.py index 5af5ce7f..620cf755 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -40,7 +40,14 @@ def test_index(self): parsed = self.parser.parse('foo[1]') self.assertEqual( parsed.search({'foo': ['zero', 'one', 'two']}), - 'one') + 'one', + "Fail: Index on lists" + ) + self.assertEqual( + parsed.search({'foo': ('zero', 'one', 'two')}), + 'one', + "Fail: Index on tuples" + ) def test_quoted_subexpression(self): self.assert_parsed_ast('"foo"."bar"', @@ -52,13 +59,27 @@ def test_wildcard(self): parsed = self.parser.parse('foo[*]') self.assertEqual( parsed.search({'foo': ['zero', 'one', 'two']}), - ['zero', 'one', 'two']) + ['zero', 'one', 'two'], + "Fail: Wildcard on lists" + ) + self.assertEqual( + parsed.search({'foo': ('zero', 'one', 'two')}), + ('zero', 'one', 'two'), + "Fail: Wildcard on tuples" + ) def test_wildcard_with_children(self): parsed = self.parser.parse('foo[*].bar') self.assertEqual( parsed.search({'foo': [{'bar': 'one'}, {'bar': 'two'}]}), - ['one', 'two']) + ['one', 'two'], + "Fail: Wildcard with children on lists" + ) + self.assertEqual( + parsed.search({'foo': ({'bar': 'one'}, {'bar': 'two'})}), + ('one', 'two'), + "Fail: Wildcard with children on tuples" + ) def test_or_expression(self): parsed = self.parser.parse('foo || bar') @@ -176,36 +197,77 @@ def test_bad_unicode_string(self): class TestParserWildcards(unittest.TestCase): def setUp(self): self.parser = parser.Parser() - self.data = { + self.data_with_lists = { 'foo': [ {'bar': [{'baz': 'one'}, {'baz': 'two'}]}, {'bar': [{'baz': 'three'}, {'baz': 'four'}, {'baz': 'five'}]}, ] } + self.data_with_tuples = { + 'foo': ( + {'bar': ({'baz': 'one'}, {'baz': 'two'})}, + {'bar': ({'baz': 'three'}, {'baz': 'four'}, {'baz': 'five'})}, + ) + } + self.data_with_lists_and_tuples = { + 'foo': ( + {'bar': [{'baz': 'one'}, {'baz': 'two'}]}, + {'bar': ({'baz': 'three'}, {'baz': 'four'}, {'baz': 'five'})}, + ) + } def test_multiple_index_wildcards(self): parsed = self.parser.parse('foo[*].bar[*].baz') - self.assertEqual(parsed.search(self.data), - [['one', 'two'], ['three', 'four', 'five']]) + self.assertEqual(parsed.search(self.data_with_lists), + [['one', 'two'], ['three', 'four', 'five']], + "Fail: Multiple index wildcards on lists") + self.assertEqual(parsed.search(self.data_with_tuples), + (('one', 'two'), ('three', 'four', 'five')), + "Fail: Multiple index wildcards on tuples") + self.assertEqual(parsed.search(self.data_with_lists_and_tuples), + (['one', 'two'], ('three', 'four', 'five')), + "Fail: Multiple index wildcards on lists and tuples") def test_wildcard_mix_with_indices(self): parsed = self.parser.parse('foo[*].bar[0].baz') - self.assertEqual(parsed.search(self.data), - ['one', 'three']) + self.assertEqual(parsed.search(self.data_with_lists), + ['one', 'three'], + "Fail: Wildcard mix with indices on lists") + self.assertEqual(parsed.search(self.data_with_tuples), + ('one', 'three'), + "Fail: Wildcard mix with indices on tuples") + self.assertEqual(parsed.search(self.data_with_lists_and_tuples), + ('one', 'three'), + "Fail: Wildcard mix with indices on lists and tuples") def test_wildcard_mix_last(self): parsed = self.parser.parse('foo[0].bar[*].baz') - self.assertEqual(parsed.search(self.data), - ['one', 'two']) + self.assertEqual(parsed.search(self.data_with_lists), + ['one', 'two'], + "Fail: Wildcard mix last on lists") + self.assertEqual(parsed.search(self.data_with_tuples), + ('one', 'two'), + "Fail: Wildcard mix last on tuples") + self.assertEqual(parsed.search(self.data_with_lists_and_tuples), + ['one', 'two'], + "Fail: Wildcard mix last on lists and tuples") def test_indices_out_of_bounds(self): parsed = self.parser.parse('foo[*].bar[2].baz') - self.assertEqual(parsed.search(self.data), - ['five']) + self.assertEqual(parsed.search(self.data_with_lists), + ['five'], + "Fail: Indices out of bounds on lists") + self.assertEqual(parsed.search(self.data_with_tuples), + ('five',), + "Fail: Indices out of bounds on tuples") + self.assertEqual(parsed.search(self.data_with_lists_and_tuples), + ('five',), + "Fail: Indices out of bounds on lists and tuples") def test_root_indices(self): parsed = self.parser.parse('[0]') - self.assertEqual(parsed.search(['one', 'two']), 'one') + self.assertEqual(parsed.search(['one', 'two']), 'one', "Fail: Root indices on lists") + self.assertEqual(parsed.search(('one', 'two')), 'one', "Fail: Root indices on tuples") def test_root_wildcard(self): parsed = self.parser.parse('*.foo') @@ -270,28 +332,54 @@ def test_wildcard_with_multiselect(self): class TestMergedLists(unittest.TestCase): def setUp(self): self.parser = parser.Parser() - self.data = { + self.data_with_lists = { "foo": [ [["one", "two"], ["three", "four"]], [["five", "six"], ["seven", "eight"]], [["nine"], ["ten"]] ] } + self.data_with_tuples = { + "foo": ( + (("one", "two"), ("three", "four")), + (("five", "six"), ("seven", "eight")), + (("nine",), ("ten",)) + ) + } + self.data_with_lists_and_tuples = { + "foo": [ + (("one", "two"), ("three", "four")), + (("five", "six"), ["seven", "eight"]), + [("nine",), ("ten",)] + ] + } def test_merge_with_indices(self): parsed = self.parser.parse('foo[][0]') - match = parsed.search(self.data) - self.assertEqual(match, ["one", "three", "five", "seven", - "nine", "ten"]) + self.assertEqual(parsed.search(self.data_with_lists), + ["one", "three", "five", "seven", "nine", "ten"], + "Fail: Merge with indices on lists") + self.assertEqual(parsed.search(self.data_with_tuples), + ("one", "three", "five", "seven", "nine", "ten"), + "Fail: Merge with indices on tuples") + self.assertEqual(parsed.search(self.data_with_lists_and_tuples), + ["one", "three", "five", "seven", "nine", "ten"], + "Fail: Merge with indices on lists and tuples") def test_trailing_merged_operator(self): parsed = self.parser.parse('foo[]') - match = parsed.search(self.data) - self.assertEqual( - match, - [["one", "two"], ["three", "four"], - ["five", "six"], ["seven", "eight"], - ["nine"], ["ten"]]) + self.assertEqual(parsed.search(self.data_with_lists), + [["one", "two"], ["three", "four"], ["five", "six"], + ["seven", "eight"], ["nine"], ["ten"]], + "Fail: Trailing merged operator on lists") + self.assertEqual(parsed.search(self.data_with_tuples), + (("one", "two"), ("three", "four"), ("five", "six"), + ("seven", "eight"), ("nine",), ("ten",)), + "Fail: Trailing merged operator on lists") + self.assertEqual(parsed.search(self.data_with_lists_and_tuples), + [("one", "two"), ("three", "four"), ("five", "six"), + ["seven", "eight"], ("nine",), ("ten",)], + "Fail: Trailing merged operator on lists and tuples") class TestParserCaching(unittest.TestCase):