Skip to content

Commit 9a0afd6

Browse files
authored
fix typo (#241)
1 parent 0ba31bf commit 9a0afd6

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,10 @@ def from_compressed_data(
103103
:param bitmask: 2d bitmask of non-zero values
104104
:return: instantiated Sparse24BitMaskTensor
105105
"""
106-
if isinstance(shape, Tensor):
107-
shape = shape.tolist()
106+
if isinstance(shape, list):
107+
shape = torch.tensor(shape)
108+
if isinstance(shape, torch.Tensor):
109+
shape = shape.flatten().tolist()
108110
return Sparse24BitMaskTensor(
109111
shape=shape, compressed=compressed, bitmask=bitmask
110112
)

src/compressed_tensors/utils/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def pack_bitmasks(bytemasks: torch.Tensor) -> torch.Tensor:
301301

302302

303303
def unpack_bitmasks(
304-
packed_bitmasks: torch.Tensor, original_shape: torch.Size
304+
packed_bitmasks: torch.Tensor, original_shape: List[int]
305305
) -> torch.Tensor:
306306
"""
307307
Converts a bitmask tensor back to a bytemask tensor for use during decompression

0 commit comments

Comments
 (0)