@@ -60,7 +60,7 @@ def test_model(
6060 )
6161
6262
63- def _validate_input_shape (shape : Tuple [int , ...], shape_spec ) -> bool :
63+ def check_input_shape (shape : Tuple [int , ...], shape_spec ) -> bool :
6464 if isinstance (shape_spec , list ):
6565 if shape != tuple (shape_spec ):
6666 return False
@@ -81,7 +81,7 @@ def _validate_input_shape(shape: Tuple[int, ...], shape_spec) -> bool:
8181 return True
8282
8383
84- def _validate_output_shape (shape : Tuple [int , ...], shape_spec , input_shapes ) -> bool :
84+ def check_output_shape (shape : Tuple [int , ...], shape_spec , input_shapes ) -> bool :
8585 if isinstance (shape_spec , list ):
8686 return shape == tuple (shape_spec )
8787 elif isinstance (shape_spec , ImplicitOutputShape ):
@@ -129,7 +129,7 @@ def test_resource(
129129 assert len (inputs ) == len (model .inputs ) # should be checked by validation
130130 input_shapes = {}
131131 for idx , (ipt , ipt_spec ) in enumerate (zip (inputs , model .inputs )):
132- if not _validate_input_shape (tuple (ipt .shape ), ipt_spec .shape ):
132+ if not check_input_shape (tuple (ipt .shape ), ipt_spec .shape ):
133133 raise ValidationError (
134134 f"Shape { tuple (ipt .shape )} of test input { idx } '{ ipt_spec .name } ' does not match "
135135 f"input shape description: { ipt_spec .shape } ."
@@ -138,7 +138,7 @@ def test_resource(
138138
139139 assert len (expected ) == len (model .outputs ) # should be checked by validation
140140 for idx , (out , out_spec ) in enumerate (zip (expected , model .outputs )):
141- if not _validate_output_shape (tuple (out .shape ), out_spec .shape , input_shapes ):
141+ if not check_output_shape (tuple (out .shape ), out_spec .shape , input_shapes ):
142142 error = (error or "" ) + (
143143 f"Shape { tuple (out .shape )} of test output { idx } '{ out_spec .name } ' does not match "
144144 f"output shape description: { out_spec .shape } ."
0 commit comments