@@ -611,7 +611,8 @@ def _check_inputs_shape(
611
611
elif isinstance (input1 , dict ):
612
612
if input1 .keys () != input2 .keys ():
613
613
return False
614
- for (ka , va ), vb in zip (input1 .items (), input2 .values ()):
614
+ for ka , va in input1 .items ():
615
+ vb = input2 [ka ]
615
616
if type (va ) != type (vb ):
616
617
return False
617
618
if isinstance (va , bool ) and va != vb :
@@ -638,9 +639,9 @@ def _check_inputs_shape(
638
639
639
640
@staticmethod
640
641
def _check_tensor_shapes_with_dynamic_shapes (
641
- t1 : torch .tensor , t2 : torch .tensor , dynamic_shape : dict [int , Any ]
642
+ input_1 : torch .tensor , input_2 : torch .tensor , dynamic_shape : dict [int , Any ]
642
643
) -> bool :
643
- for (i , axis_0 ), axis_1 in zip (enumerate (t1 .shape ), t2 .shape ):
644
+ for (i , axis_0 ), axis_1 in zip (enumerate (input_1 .shape ), input_2 .shape ):
644
645
if axis_0 != axis_1 :
645
646
if i not in dynamic_shape :
646
647
logger .warning (
@@ -650,7 +651,7 @@ def _check_tensor_shapes_with_dynamic_shapes(
650
651
dyn = dynamic_shape [i ]
651
652
if axis_1 > dyn .max or axis_1 < dyn .min :
652
653
raise DynamicShapeOutOfRangeException (
653
- f"The input size ( { axis_1 } ) of dimension ( { i } ) is not in dynamic shape range [{ dyn .max } , { dyn .max } ]! "
654
+ f"Dimension ( { i } ) of new input tensor is not the range of supported shapes (saw: ( { axis_1 } ), expected: [{ dyn .min } , { dyn .max } ]) "
654
655
)
655
656
656
657
return True
0 commit comments