File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change 1
1
from dataclasses import dataclass
2
2
from typing import List , Any
3
3
from triton ._utils import validate_block_shape
4
- from torch ._subclasses .fake_tensor import FakeTensor
5
- from torch ._subclasses .functional_tensor import FunctionalTensor
6
4
7
5
8
6
@dataclass
@@ -18,7 +16,9 @@ def __post_init__(self):
18
16
assert len (self .block_shape ) == rank , f"rank mismatch: { self } "
19
17
assert rank > 0 , "rank must not be zero"
20
18
assert rank <= 5 , "rank cannot be more than 5"
21
- if not isinstance (self .base , (FakeTensor , FunctionalTensor )):
19
+ ty = type (self .base )
20
+ type_name = f"{ ty .__module__ } .{ ty .__name__ } "
21
+ if type_name not in ("torch.FakeTensor" , "torch.FunctionalTensor" ):
22
22
assert self .base .data_ptr () % 16 == 0 , "base must be 16-byte aligned"
23
23
validate_block_shape (self .block_shape )
24
24
elem_bytes = self .base .dtype .itemsize
You can’t perform that action at this time.
0 commit comments