Skip to content

Commit aa36e9a

Browse files
fix issue#3269: unwrap tensor shape without opt val (#3279)
1 parent 8e2c82d commit aa36e9a

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

py/torch_tensorrt/dynamo/partitioning/common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ def construct_dynamic_input(
3131
if isinstance(dim, torch.SymInt):
3232
min_max_opt = extract_var_range_info(dim)
3333
min_shape.append(min_max_opt["min"])
34-
opt_shape.append(min_max_opt["opt"])
34+
# opt might not exist
35+
opt_shape.append(min_max_opt.get("opt"))
3536
max_shape.append(min_max_opt["max"])
3637
else:
3738
min_shape.append(dim)

py/torch_tensorrt/dynamo/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
77

88
import numpy as np
9+
import sympy
910
import tensorrt as trt
1011
import torch
1112
from torch._subclasses.fake_tensor import FakeTensor
@@ -342,14 +343,14 @@ def extract_var_range_info(symbolic_integer: torch.SymInt) -> Dict[str, int]:
342343
shape_env.var_to_val
343344
)
344345
assert var_range, var_val
345-
min_val, max_val, opt_val = int(var_range.lower), int(var_range.upper), int(var_val)
346+
min_val, max_val = int(var_range.lower), int(var_range.upper)
346347
# Torchdynamo 0/1 specialization outlier
347348
min_val = 1 if min_val == 2 else min_val
348349
min_max_opt = {}
349350
min_max_opt["min"] = min_val
350351
min_max_opt["max"] = max_val
351-
min_max_opt["opt"] = opt_val
352-
352+
if isinstance(var_val, sympy.core.numbers.Integer):
353+
min_max_opt["opt"] = int(var_val)
353354
return min_max_opt
354355

355356

0 commit comments

Comments
 (0)