@@ -49,17 +49,18 @@ def _convert_pytorch_tensor_to_tvm(tensor_value: torch.Tensor) -> tvm.runtime.Te
4949 The converted TVM tensor.
5050 """
5151 # Fix for Issue #18407: Handle FakeTensor and lifted tensors (from torch.export)
52- # Check if this is a FakeTensor or tensor subclass that doesn't support .numpy()
52+ # FakeTensor is an internal PyTorch API that may not be available in all versions
5353 try :
54- # Check if it's a FakeTensor
55- if hasattr ( torch , '_subclasses' ) and hasattr ( torch ._subclasses , ' fake_tensor' ):
56- if isinstance (tensor_value , torch . _subclasses . fake_tensor . FakeTensor ):
57- # Create a real tensor with the same shape and dtype
58- real_tensor = torch .zeros (tensor_value .shape , dtype = tensor_value .dtype )
59- return tvm .runtime .tensor (real_tensor .numpy ())
54+ # Check if it's a FakeTensor from torch._subclasses, which is an internal API
55+ FakeTensor = torch ._subclasses . fake_tensor . FakeTensor
56+ if isinstance (tensor_value , FakeTensor ):
57+ # Create a real tensor with the same shape and dtype as a placeholder
58+ real_tensor = torch .zeros (tensor_value .shape , dtype = tensor_value .dtype )
59+ return tvm .runtime .tensor (real_tensor .numpy ())
6060 except (AttributeError , ImportError ):
61+ # FakeTensor class might not exist in this torch version, or other import issue
6162 pass
62-
63+
6364 # PyTorch sparse tensors (layout != torch.strided) must be converted to dense.
6465 if tensor_value .layout != torch .strided :
6566 tensor_to_convert = tensor_value .to_dense ()
@@ -78,8 +79,12 @@ def _convert_pytorch_tensor_to_tvm(tensor_value: torch.Tensor) -> tvm.runtime.Te
7879 return tvm .runtime .tensor (tensor_cpu .numpy ())
7980 except RuntimeError as e :
8081 # Fix for Issue #18407: Handle tensor subclasses that don't support .numpy()
81- # This can happen with lifted tensors from torch.export
82- if "tensor subclasses" in str (e ) or "FakeTensor" in str (e ):
82+ # This can happen with lifted tensors from torch.export that slip through
83+ # the FakeTensor check above (e.g., other tensor subclasses)
84+ # String matching is fragile but necessary as PyTorch doesn't provide
85+ # a specific exception type for this case
86+ error_msg = str (e )
87+ if "tensor subclasses" in error_msg or "Cannot access data pointer" in error_msg :
8388 # Create a dummy tensor with the same shape and dtype
8489 dummy_tensor = torch .zeros (tensor_value .shape , dtype = tensor_value .dtype )
8590 return tvm .runtime .tensor (dummy_tensor .numpy ())
0 commit comments