Skip to content

Commit 34f1310

Browse files
authored
fix: Fix a bug with dynamic shape validation in MTMM (#3837)
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 2de0ec4 commit 34f1310

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,8 @@ def _check_inputs_shape(
611611
elif isinstance(input1, dict):
612612
if input1.keys() != input2.keys():
613613
return False
614-
for (ka, va), vb in zip(input1.items(), input2.values()):
614+
for ka, va in input1.items():
615+
vb = input2[ka]
615616
if type(va) != type(vb):
616617
return False
617618
if isinstance(va, bool) and va != vb:
@@ -638,9 +639,9 @@ def _check_inputs_shape(
638639

639640
@staticmethod
640641
def _check_tensor_shapes_with_dynamic_shapes(
641-
t1: torch.tensor, t2: torch.tensor, dynamic_shape: dict[int, Any]
642+
input_1: torch.tensor, input_2: torch.tensor, dynamic_shape: dict[int, Any]
642643
) -> bool:
643-
for (i, axis_0), axis_1 in zip(enumerate(t1.shape), t2.shape):
644+
for (i, axis_0), axis_1 in zip(enumerate(input_1.shape), input_2.shape):
644645
if axis_0 != axis_1:
645646
if i not in dynamic_shape:
646647
logger.warning(
@@ -650,7 +651,7 @@ def _check_tensor_shapes_with_dynamic_shapes(
650651
dyn = dynamic_shape[i]
651652
if axis_1 > dyn.max or axis_1 < dyn.min:
652653
raise DynamicShapeOutOfRangeException(
653-
f"The input size ({axis_1}) of dimension ({i}) is not in dynamic shape range [{dyn.max}, {dyn.max}]!"
654+
f"Dimension ({i}) of new input tensor is not the range of supported shapes (saw: ({axis_1}), expected: [{dyn.min}, {dyn.max}])"
654655
)
655656

656657
return True

0 commit comments

Comments
 (0)