Skip to content

Commit 8612740

Browse files
authored
Type propagation through dict and tuples (#140)
* Initial go at the tests for this * Deal with dictionaries and tuples for type following * Fix up flake8 and type errors * Update github action references * Deal with python 3.8
1 parent 2891ca1 commit 8612740

File tree

4 files changed

+113
-14
lines changed

4 files changed

+113
-14
lines changed

.github/workflows/ci.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ jobs:
1212
flake8:
1313
runs-on: ubuntu-latest
1414
steps:
15-
- uses: actions/checkout@v3
15+
- uses: actions/checkout@v4
1616
- name: Set up Python 3.8
17-
uses: actions/setup-python@v4
17+
uses: actions/setup-python@v5
1818
with:
1919
python-version: 3.8
2020
- name: Install dependencies
@@ -41,9 +41,9 @@ jobs:
4141
python-version: [3.8, 3.9, "3.10", 3.11, 3.12]
4242

4343
steps:
44-
- uses: actions/checkout@v3
44+
- uses: actions/checkout@v4
4545
- name: Set up Python ${{ matrix.python-version }}
46-
uses: actions/setup-python@v4
46+
uses: actions/setup-python@v5
4747
with:
4848
python-version: ${{ matrix.python-version }}
4949
- name: Install dependencies

func_adl/type_based_replacement.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -496,7 +496,16 @@ def stream(self) -> ObjectStream[S]:
496496

497497
def lookup_type(self, name: Union[str, object]) -> Type:
498498
"Return the type for a node, Any if we do not know about it"
499-
return self._found_types.get(name, Any) # type: ignore
499+
t = self._found_types.get(name, None)
500+
if t is not None:
501+
return t # type: ignore
502+
if not isinstance(name, ast.AST):
503+
return Any # type: ignore
504+
505+
# It could be we can determine the type from this ast.
506+
return Any # type: ignore
507+
508+
return t # type: ignore
500509

501510
def process_method_call_on_stream_obj(
502511
self,
@@ -883,9 +892,25 @@ def visit_IfExp(self, node: ast.IfExp) -> Any:
883892
def visit_Subscript(self, node: ast.Subscript) -> Any:
884893
t_node = self.generic_visit(node)
885894
assert isinstance(t_node, ast.Subscript)
886-
inner_type = unwrap_iterable(self.lookup_type(t_node.value))
887-
self._found_types[node] = inner_type
888-
self._found_types[t_node] = inner_type
895+
if isinstance(t_node.value, ast.Tuple):
896+
slice = t_node.slice
897+
# This if statement can be removed when we no longer support 3.8.
898+
if isinstance(slice, ast.Index):
899+
slice = slice.value # type: ignore
900+
if not isinstance(slice, ast.Constant):
901+
raise ValueError(
902+
f"Slices must be indexable constants only - {ast.dump(slice)} is not "
903+
"valid."
904+
)
905+
index = slice.value
906+
if len(t_node.value.elts) <= index:
907+
raise ValueError(f"Index {index} out of range for {ast.dump(node.value)}")
908+
self._found_types[node] = self.lookup_type(t_node.value.elts[index])
909+
self._found_types[t_node] = self.lookup_type(t_node.value.elts[index])
910+
else:
911+
inner_type = unwrap_iterable(self.lookup_type(t_node.value))
912+
self._found_types[node] = inner_type
913+
self._found_types[t_node] = inner_type
889914
return t_node
890915

891916
def visit_Name(self, node: ast.Name) -> ast.Name:
@@ -919,6 +944,22 @@ def visit_NameConstant(self, node: ast.NameConstant) -> Any: # pragma: no cover
919944
self._found_types[node] = bool
920945
return node
921946

947+
def visit_Attribute(self, node: ast.Attribute) -> Any:
948+
t_node = self.generic_visit(node)
949+
assert isinstance(t_node, ast.Attribute)
950+
# If this is a dict reference, then figure out what the
951+
# type is for that value of the dict.
952+
if isinstance(t_node.value, ast.Dict):
953+
key = t_node.attr
954+
key_index = [
955+
e for e, k in enumerate(t_node.value.keys) if k.value == key # type: ignore
956+
]
957+
if len(key_index) == 0:
958+
raise ValueError(f"Key {key} not found in dict expression!!")
959+
value = t_node.value.values[key_index[0]]
960+
self._found_types[node] = self.lookup_type(value)
961+
return t_node
962+
922963
tt = type_transformer(o_stream)
923964
r_a = tt.visit(a)
924965

func_adl/util_types.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def get_inherited(t: Type) -> Type:
4646
elif hasattr(t, "__origin__") and hasattr(t.__origin__, "__orig_bases__"):
4747
base_classes = t.__origin__.__orig_bases__
4848
else:
49-
return Any
49+
return Any # type: ignore
5050

5151
r = base_classes[0] # type: ignore
5252

@@ -77,7 +77,7 @@ def unwrap_iterable(t: Type) -> Type:
7777
t = get_inherited(t)
7878

7979
if t == Any:
80-
return Any
80+
return Any # type: ignore
8181

8282
a = get_args(t)
8383
assert len(a) == 1, f"Coding error - expected iterable type with a parameter, got {t}"
@@ -168,7 +168,7 @@ def resolve_type_vars(
168168
s = build_type_dict_from_type(context_type, at_class)
169169
except TypeError:
170170
s = {}
171-
return _resolve_type(parameterized_type, s)
171+
return _resolve_type(parameterized_type, s) # type: ignore
172172

173173

174174
def get_class_name(t: Type) -> str:
@@ -206,7 +206,7 @@ class object where the method is defined, and the method object. If there is no
206206
# Check for templated classes
207207
# TODO: Use inspect.getmro
208208
if not hasattr(class_object, "__mro__"):
209-
class_object = get_origin(class_object)
209+
class_object = get_origin(class_object) # type: ignore
210210

211211
# Walk the resolution hierarchy to find the method
212212
found_obj = None
@@ -223,6 +223,6 @@ class object where the method is defined, and the method object. If there is no
223223
if found_method == m:
224224
found_obj = c
225225
else:
226-
return (found_obj, found_method)
226+
return (found_obj, found_method) # type: ignore
227227

228-
return (found_obj, found_method)
228+
return (found_obj, found_method) # type: ignore

tests/test_type_based_replacement.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,64 @@ def test_collection_Select(caplog):
503503
assert len(caplog.text) == 0
504504

505505

506+
def test_dictionary():
507+
"Check that we can type-follow through dictionaries"
508+
509+
s = ast_lambda("{'jets': e.Jets()}.jets.Select(lambda j: j.pt())")
510+
objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))
511+
512+
new_objs, new_s, expr_type = remap_by_types(objs, "e", Event, s)
513+
514+
assert expr_type == Iterable[float]
515+
516+
517+
def test_dictionary_bad_key():
518+
"Check that we can type-follow through dictionaries"
519+
520+
with pytest.raises(ValueError) as e:
521+
s = ast_lambda("{'jets': e.Jets()}.jetsss.Select(lambda j: j.pt())")
522+
objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))
523+
524+
new_objs, new_s, expr_type = remap_by_types(objs, "e", Event, s)
525+
526+
assert "jetsss" in str(e)
527+
528+
529+
def test_indexed_tuple():
530+
"Check that we can type-follow through dictionaries"
531+
532+
s = ast_lambda("(e.Jets(),)[0].Select(lambda j: j.pt())")
533+
objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))
534+
535+
new_objs, new_s, expr_type = remap_by_types(objs, "e", Event, s)
536+
537+
assert expr_type == Iterable[float]
538+
539+
540+
def test_indexed_tuple_out_of_bounds():
541+
"Check that we can type-follow through dictionaries"
542+
543+
with pytest.raises(ValueError) as e:
544+
s = ast_lambda("(e.Jets(),)[3].Select(lambda j: j.pt())")
545+
objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))
546+
547+
new_objs, new_s, expr_type = remap_by_types(objs, "e", Event, s)
548+
549+
assert "3" in str(e)
550+
551+
552+
def test_indexed_tuple_bad_slice():
553+
"Check that we can type-follow through dictionaries"
554+
555+
with pytest.raises(ValueError) as e:
556+
s = ast_lambda("(e.Jets(),)[0:1].Select(lambda j: j.pt())")
557+
objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))
558+
559+
new_objs, new_s, expr_type = remap_by_types(objs, "e", Event, s)
560+
561+
assert "is not valid" in str(e)
562+
563+
506564
def test_collection_Select_meta(caplog):
507565
"A simple collection"
508566
caplog.set_level(logging.WARNING)

0 commit comments

Comments
 (0)