Skip to content

Commit 1cf8daf

Browse files
authored
[FRONTEND] Remove dependency on torch (#7519)
A dependency on torch was added by mistake.
1 parent df39911 commit 1cf8daf

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

python/triton/tools/tensor_descriptor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from dataclasses import dataclass
22
from typing import List, Any
33
from triton._utils import validate_block_shape
4-
from torch._subclasses.fake_tensor import FakeTensor
5-
from torch._subclasses.functional_tensor import FunctionalTensor
64

75

86
@dataclass
@@ -18,7 +16,9 @@ def __post_init__(self):
1816
assert len(self.block_shape) == rank, f"rank mismatch: {self}"
1917
assert rank > 0, "rank must not be zero"
2018
assert rank <= 5, "rank cannot be more than 5"
21-
if not isinstance(self.base, (FakeTensor, FunctionalTensor)):
19+
ty = type(self.base)
20+
type_name = f"{ty.__module__}.{ty.__name__}"
21+
if type_name not in ("torch.FakeTensor", "torch.FunctionalTensor"):
2222
assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned"
2323
validate_block_shape(self.block_shape)
2424
elem_bytes = self.base.dtype.itemsize

0 commit comments

Comments
 (0)