Skip to content

Commit 5f23cfb

Browse files
authored
Deep-Type Dictionaries (#142)
* Dictionaries are now named tuples for typing * Fix up flake8 error * Add dummy test to make sure we don't forget. * Fix up missing return * Add test to assure Select typing * Do proper lookup when working with a nameclass
1 parent 8612740 commit 5f23cfb

File tree

2 files changed

+72
-2
lines changed

2 files changed

+72
-2
lines changed

func_adl/type_based_replacement.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import inspect
66
import logging
77
import sys
8-
from dataclasses import dataclass
8+
from dataclasses import dataclass, is_dataclass, make_dataclass
99
from typing import (
1010
Any,
1111
Callable,
@@ -923,6 +923,19 @@ def visit_Name(self, node: ast.Name) -> ast.Name:
923923
self._found_types[node] = Any
924924
return node
925925

926+
def visit_Dict(self, node: ast.Dict) -> Any:
927+
t_node = self.generic_visit(node)
928+
assert isinstance(t_node, ast.Dict)
929+
930+
fields: List[Tuple[str, type]] = [
931+
(ast.literal_eval(f), self.lookup_type(v)) # type: ignore
932+
for f, v in zip(t_node.keys, t_node.values)
933+
]
934+
dict_dataclass = make_dataclass("dict_dataclass", fields)
935+
936+
self._found_types[t_node] = dict_dataclass
937+
return t_node
938+
926939
def visit_Constant(self, node: ast.Constant) -> Any:
927940
self._found_types[node] = type(node.value)
928941
return node
@@ -958,6 +971,11 @@ def visit_Attribute(self, node: ast.Attribute) -> Any:
958971
raise ValueError(f"Key {key} not found in dict expression!!")
959972
value = t_node.value.values[key_index[0]]
960973
self._found_types[node] = self.lookup_type(value)
974+
elif ((dc := self.lookup_type(t_node.value)) is not None) and is_dataclass(dc):
975+
dc_types = get_type_hints(dc)
976+
if node.attr not in dc_types:
977+
raise ValueError(f"Key {node.attr} not found in dataclass/dictionary {dc}")
978+
self._found_types[node] = dc_types[node.attr]
961979
return t_node
962980

963981
tt = type_transformer(o_stream)

tests/test_type_based_replacement.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import ast
22
import copy
3+
import inspect
34
import logging
5+
from inspect import isclass
46
from typing import Any, Callable, Iterable, Optional, Tuple, Type, TypeVar, cast
57

68
import pytest
@@ -14,6 +16,7 @@
1416
remap_by_types,
1517
remap_from_lambda,
1618
)
19+
from func_adl.util_types import is_iterable, unwrap_iterable
1720

1821

1922
class Track:
@@ -504,6 +507,23 @@ def test_collection_Select(caplog):
504507

505508

506509
def test_dictionary():
510+
"Make sure that dictionaries turn into named types"
511+
512+
s = ast_lambda("{'jets': e.Jets()}")
513+
objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))
514+
515+
new_objs, new_s, expr_type = remap_by_types(objs, "e", Event, s)
516+
517+
# Fix to look for the named class with the correct types.
518+
assert isclass(expr_type)
519+
sig = inspect.signature(expr_type.__init__)
520+
assert len(sig.parameters) == 2
521+
assert "jets" in sig.parameters
522+
j_info = sig.parameters["jets"]
523+
assert str(j_info.annotation) == "typing.Iterable[tests.test_type_based_replacement.Jet]"
524+
525+
526+
def test_dictionary_sequence():
507527
"Check that we can type-follow through dictionaries"
508528

509529
s = ast_lambda("{'jets': e.Jets()}.jets.Select(lambda j: j.pt())")
@@ -526,8 +546,40 @@ def test_dictionary_bad_key():
526546
assert "jetsss" in str(e)
527547

528548

549+
def test_dictionary_through_Select():
550+
"""Make sure the Select statement carries the typing all the way through"""
551+
552+
s = ast_lambda("e.Jets().Select(lambda j: {'pt': j.pt(), 'eta': j.eta()})")
553+
objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))
554+
555+
_, _, expr_type = remap_by_types(objs, "e", Event, s)
556+
557+
assert is_iterable(expr_type)
558+
obj_itr = unwrap_iterable(expr_type)
559+
assert isclass(obj_itr)
560+
sig = inspect.signature(obj_itr.__init__)
561+
assert len(sig.parameters) == 3
562+
assert "pt" in sig.parameters
563+
j_info = sig.parameters["pt"]
564+
assert j_info.annotation == float
565+
566+
567+
def test_dictionary_through_Select_reference():
568+
"""Make sure the Select statement carries the typing all the way through,
569+
including a later reference"""
570+
571+
s = ast_lambda(
572+
"e.Jets().Select(lambda j: {'pt': j.pt(), 'eta': j.eta()}).Select(lambda info: info.pt)"
573+
)
574+
objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))
575+
576+
_, _, expr_type = remap_by_types(objs, "e", Event, s)
577+
578+
assert expr_type == Iterable[float]
579+
580+
529581
def test_indexed_tuple():
530-
"Check that we can type-follow through dictionaries"
582+
"Check that we can type-follow through tuples"
531583

532584
s = ast_lambda("(e.Jets(),)[0].Select(lambda j: j.pt())")
533585
objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load()))

0 commit comments

Comments
 (0)