Skip to content

Commit a67b780

Browse files
committed
[Relax][Torch] Fix from_exported_program crash with FakeTensor and lifted tensors (#18407)
1 parent c75b5ac commit a67b780

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,18 @@ def _convert_pytorch_tensor_to_tvm(tensor_value: torch.Tensor) -> tvm.runtime.Te
4949
The converted TVM tensor.
5050
"""
5151
# Fix for Issue #18407: Handle FakeTensor and lifted tensors (from torch.export)
52-
# Check if this is a FakeTensor or tensor subclass that doesn't support .numpy()
52+
# FakeTensor is an internal PyTorch API that may not be available in all versions
5353
try:
54-
# Check if it's a FakeTensor
55-
if hasattr(torch, '_subclasses') and hasattr(torch._subclasses, 'fake_tensor'):
56-
if isinstance(tensor_value, torch._subclasses.fake_tensor.FakeTensor):
57-
# Create a real tensor with the same shape and dtype
58-
real_tensor = torch.zeros(tensor_value.shape, dtype=tensor_value.dtype)
59-
return tvm.runtime.tensor(real_tensor.numpy())
54+
# Check if it's a FakeTensor from torch._subclasses, which is an internal API
55+
FakeTensor = torch._subclasses.fake_tensor.FakeTensor
56+
if isinstance(tensor_value, FakeTensor):
57+
# Create a real tensor with the same shape and dtype as a placeholder
58+
real_tensor = torch.zeros(tensor_value.shape, dtype=tensor_value.dtype)
59+
return tvm.runtime.tensor(real_tensor.numpy())
6060
except (AttributeError, ImportError):
61+
# FakeTensor class might not exist in this torch version, or other import issue
6162
pass
62-
63+
6364
# PyTorch sparse tensors (layout != torch.strided) must be converted to dense.
6465
if tensor_value.layout != torch.strided:
6566
tensor_to_convert = tensor_value.to_dense()
@@ -78,8 +79,12 @@ def _convert_pytorch_tensor_to_tvm(tensor_value: torch.Tensor) -> tvm.runtime.Te
7879
return tvm.runtime.tensor(tensor_cpu.numpy())
7980
except RuntimeError as e:
8081
# Fix for Issue #18407: Handle tensor subclasses that don't support .numpy()
81-
# This can happen with lifted tensors from torch.export
82-
if "tensor subclasses" in str(e) or "FakeTensor" in str(e):
82+
# This can happen with lifted tensors from torch.export that slip through
83+
# the FakeTensor check above (e.g., other tensor subclasses)
84+
# String matching is fragile but necessary as PyTorch doesn't provide
85+
# a specific exception type for this case
86+
error_msg = str(e)
87+
if "tensor subclasses" in error_msg or "Cannot access data pointer" in error_msg:
8388
# Create a dummy tensor with the same shape and dtype
8489
dummy_tensor = torch.zeros(tensor_value.shape, dtype=tensor_value.dtype)
8590
return tvm.runtime.tensor(dummy_tensor.numpy())

0 commit comments

Comments
 (0)