Skip to content

Commit 3772dbd

Browse files
authored
[FRONTEND] generalize a bit the check for faketensors (#7613)
1 parent ddaa11f commit 3772dbd

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

python/triton/tools/tensor_descriptor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ def __post_init__(self):
1717
assert rank > 0, "rank must not be zero"
1818
assert rank <= 5, "rank cannot be more than 5"
1919
ty = type(self.base)
20-
type_name = f"{ty.__module__}.{ty.__name__}"
21-
if type_name not in ("torch.FakeTensor", "torch.FunctionalTensor"):
20+
if ty.__name__ not in ("FakeTensor", "FunctionalTensor"):
2221
assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned"
2322
validate_block_shape(self.block_shape)
2423
elem_bytes = self.base.dtype.itemsize

0 commit comments

Comments
 (0)