1313# limitations under the License.
1414
1515import unittest
16+ from collections .abc import MutableMapping
17+ from itertools import chain
1618from typing import Any , Optional
1719
1820import celpy
4547]
4648
4749
48- def read_textproto ( ) -> simple_pb2 .SimpleTestFile :
50+ def load_test_data ( file_name : str ) -> simple_pb2 .SimpleTestFile :
4951 msg = simple_pb2 .SimpleTestFile ()
50- with open (f"tests/testdata/string_ext_ { CEL_SPEC_VERSION } .textproto" ) as file :
52+ with open (file_name ) as file :
5153 text_data = file .read ()
5254 text_format .Parse (text_data , msg )
5355 return msg
5456
5557
56- def build_binding (bindings : dict [str , eval_pb2 .ExprValue ]) -> dict [Any , Any ]:
58+ def build_variables (bindings : MutableMapping [str , eval_pb2 .ExprValue ]) -> dict [Any , Any ]:
5759 binder = {}
5860 for key , value in bindings .items ():
5961 if value .HasField ("value" ):
@@ -82,25 +84,33 @@ def get_eval_error_message(test: simple_pb2.SimpleTest) -> Optional[str]:
8284class TestFormat (unittest .TestCase ):
8385 @classmethod
8486 def setUpClass (cls ):
85- test_data = read_textproto ()
86- cls ._format_test_section = next ((x for x in test_data .section if x .name == "format" ), None )
87- cls ._format_error_test_section = next ((x for x in test_data .section if x .name == "format_errors" ), None )
87+ # The test data from the cel-spec conformance tests
88+ cel_test_data = load_test_data (f"tests/testdata/string_ext_{ CEL_SPEC_VERSION } .textproto" )
89+ # Our supplemental tests of functionality not in the cel conformance file, but defined in the spec.
90+ supplemental_test_data = load_test_data ("tests/testdata/string_ext_supplemental.textproto" )
91+
92+ # Combine the test data from both files into one
93+ sections = cel_test_data .section
94+ sections .extend (supplemental_test_data .section )
95+
96+ # Find the format tests which test successful formatting
97+ cls ._format_tests = chain .from_iterable (x .test for x in sections if x .name == "format" )
98+ # Find the format error tests which test errors during formatting
99+ cls ._format_error_tests = chain .from_iterable (x .test for x in sections if x .name == "format_errors" )
100+
88101 cls ._env = celpy .Environment (runner_class = InterpretedRunner )
89102
90103 def test_format_successes (self ):
91104 """
92105 Tests success scenarios for string.format
93106 """
94- section = self ._format_test_section
95- if section is None :
96- return
97- for test in section .test :
107+ for test in self ._format_tests :
98108 if test .name in skipped_tests :
99109 continue
100110 ast = self ._env .compile (test .expr )
101111 prog = self ._env .program (ast , functions = extra_func .EXTRA_FUNCS )
102112
103- bindings = build_binding (test .bindings )
113+ bindings = build_variables (test .bindings )
104114 # Ideally we should use pytest parametrize instead of subtests, but
105115 # that would require refactoring other tests also.
106116 with self .subTest (test .name ):
@@ -118,16 +128,13 @@ def test_format_errors(self):
118128 """
119129 Tests error scenarios for string.format
120130 """
121- section = self ._format_error_test_section
122- if section is None :
123- return
124- for test in section .test :
131+ for test in self ._format_error_tests :
125132 if test .name in skipped_error_tests :
126133 continue
127134 ast = self ._env .compile (test .expr )
128135 prog = self ._env .program (ast , functions = extra_func .EXTRA_FUNCS )
129136
130- bindings = build_binding (test .bindings )
137+ bindings = build_variables (test .bindings )
131138 # Ideally we should use pytest parametrize instead of subtests, but
132139 # that would require refactoring other tests also.
133140 with self .subTest (test .name ):
0 commit comments