File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change 11from dataclasses import dataclass
22from typing import List , Any
33from triton ._utils import validate_block_shape
4- from torch ._subclasses .fake_tensor import FakeTensor
5- from torch ._subclasses .functional_tensor import FunctionalTensor
64
75
86@dataclass
@@ -18,7 +16,9 @@ def __post_init__(self):
1816 assert len (self .block_shape ) == rank , f"rank mismatch: { self } "
1917 assert rank > 0 , "rank must not be zero"
2018 assert rank <= 5 , "rank cannot be more than 5"
21- if not isinstance (self .base , (FakeTensor , FunctionalTensor )):
19+ ty = type (self .base )
20+ type_name = f"{ ty .__module__ } .{ ty .__name__ } "
21+ if type_name not in ("torch.FakeTensor" , "torch.FunctionalTensor" ):
2222 assert self .base .data_ptr () % 16 == 0 , "base must be 16-byte aligned"
2323 validate_block_shape (self .block_shape )
2424 elem_bytes = self .base .dtype .itemsize
You can’t perform that action at this time.
0 commit comments