11import ast
22import copy
3+ import inspect
34import logging
5+ from inspect import isclass
46from typing import Any , Callable , Iterable , Optional , Tuple , Type , TypeVar , cast
57
68import pytest
1416 remap_by_types ,
1517 remap_from_lambda ,
1618)
19+ from func_adl .util_types import is_iterable , unwrap_iterable
1720
1821
1922class Track :
@@ -504,6 +507,23 @@ def test_collection_Select(caplog):
504507
505508
506509def test_dictionary ():
510+ "Make sure that dictionaries turn into named types"
511+
512+ s = ast_lambda ("{'jets': e.Jets()}" )
513+ objs = ObjectStream [Event ](ast .Name (id = "e" , ctx = ast .Load ()))
514+
515+ new_objs , new_s , expr_type = remap_by_types (objs , "e" , Event , s )
516+
517+ # Fix to look for the named class with the correct types.
518+ assert isclass (expr_type )
519+ sig = inspect .signature (expr_type .__init__ )
520+ assert len (sig .parameters ) == 2
521+ assert "jets" in sig .parameters
522+ j_info = sig .parameters ["jets" ]
523+ assert str (j_info .annotation ) == "typing.Iterable[tests.test_type_based_replacement.Jet]"
524+
525+
526+ def test_dictionary_sequence ():
507527 "Check that we can type-follow through dictionaries"
508528
509529 s = ast_lambda ("{'jets': e.Jets()}.jets.Select(lambda j: j.pt())" )
@@ -526,8 +546,40 @@ def test_dictionary_bad_key():
526546 assert "jetsss" in str (e )
527547
528548
549+ def test_dictionary_through_Select ():
550+ """Make sure the Select statement carries the typing all the way through"""
551+
552+ s = ast_lambda ("e.Jets().Select(lambda j: {'pt': j.pt(), 'eta': j.eta()})" )
553+ objs = ObjectStream [Event ](ast .Name (id = "e" , ctx = ast .Load ()))
554+
555+ _ , _ , expr_type = remap_by_types (objs , "e" , Event , s )
556+
557+ assert is_iterable (expr_type )
558+ obj_itr = unwrap_iterable (expr_type )
559+ assert isclass (obj_itr )
560+ sig = inspect .signature (obj_itr .__init__ )
561+ assert len (sig .parameters ) == 3
562+ assert "pt" in sig .parameters
563+ j_info = sig .parameters ["pt" ]
564+ assert j_info .annotation == float
565+
566+
567+ def test_dictionary_through_Select_reference ():
568+ """Make sure the Select statement carries the typing all the way through,
569+ including a later reference"""
570+
571+ s = ast_lambda (
572+ "e.Jets().Select(lambda j: {'pt': j.pt(), 'eta': j.eta()}).Select(lambda info: info.pt)"
573+ )
574+ objs = ObjectStream [Event ](ast .Name (id = "e" , ctx = ast .Load ()))
575+
576+ _ , _ , expr_type = remap_by_types (objs , "e" , Event , s )
577+
578+ assert expr_type == Iterable [float ]
579+
580+
529581def test_indexed_tuple ():
530- "Check that we can type-follow through dictionaries "
582+ "Check that we can type-follow through tuples "
531583
532584 s = ast_lambda ("(e.Jets(),)[0].Select(lambda j: j.pt())" )
533585 objs = ObjectStream [Event ](ast .Name (id = "e" , ctx = ast .Load ()))
0 commit comments