Skip to content

Commit a27db62

Browse files
committed
use hadamards database file
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 4b81ac7 commit a27db62

File tree

3 files changed

+18
-68
lines changed

3 files changed

+18
-68
lines changed

src/compressed_tensors/transform/utils/hadamard.py

Lines changed: 18 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -74,29 +74,26 @@ def random_hadamard_matrix(
7474
return _matmul_hadU(Q) / math.sqrt(size)
7575

7676

77-
def _get_hadK(n: int, transpose: bool = False) -> Tuple[torch.Tensor, int]:
78-
# NOTE: we can easily extend the list of supported shapes/sizes
79-
# by adding to these methods
80-
hadK, K = None, None
81-
if n % 20 == 0:
82-
assert _is_pow2(n // 20)
83-
K = 20
84-
hadK = _get_had20().T if transpose else _get_had20()
85-
elif n % 12 == 0:
86-
assert _is_pow2(n // 12)
87-
K = 12
88-
hadK = _get_had12().T if transpose else _get_had12()
89-
else:
90-
assert _is_pow2(n)
91-
K = 1
92-
93-
return hadK, K
94-
95-
96-
def _matmul_hadU(X, transpose=False) -> torch.Tensor:
77+
def _get_hadK(n: int) -> Tuple[torch.Tensor, int]:
78+
import os
79+
80+
from safetensors import safe_open
81+
82+
file_path = os.path.join(os.path.dirname(__file__), "hadamards.safetensors")
83+
with safe_open(file_path, framework="pt", device="cpu") as file:
84+
for divisor in file.keys():
85+
if n % int(divisor) == 0:
86+
return file.get_tensor(str(divisor)), int(divisor)
87+
88+
else:
89+
assert _is_pow2(n)
90+
return None, 1
91+
92+
93+
def _matmul_hadU(X) -> torch.Tensor:
9794
n = X.shape[-1]
9895
# Check if we have the determined hadamard matrix
99-
hadK, K = _get_hadK(n, transpose)
96+
hadK, K = _get_hadK(n)
10097
# Reshape diag matrix with randomized -1/+1
10198
input = X.clone().view(-1, n, 1)
10299
output = input.clone()
@@ -129,33 +126,3 @@ def _matmul_hadU(X, transpose=False) -> torch.Tensor:
129126

130127
def _is_pow2(n: int) -> bool:
131128
return (n & (n - 1) == 0) and (n > 0)
132-
133-
134-
def _reshape_bits(packed_bits: numpy.ndarray, original_size: int) -> numpy.ndarray:
135-
had_unpacked = numpy.unpackbits(packed_bits)
136-
had_unpacked = [1 if x == 1 else -1 for x in had_unpacked]
137-
had_unpacked = numpy.array(had_unpacked).reshape((original_size, original_size))
138-
return had_unpacked
139-
140-
141-
# http://www.neilsloane.com/hadamard/index.html
142-
def _get_had12() -> torch.Tensor:
143-
# fmt: off
144-
had_12 = numpy.array([128, 13, 29, 232, 235, 71, 218,
145-
62, 209, 246, 139, 180, 157, 168, 237, 199, 106, 59], dtype=numpy.uint8)
146-
# fmt: on
147-
# TODO: just unpack during apply
148-
had_12_unpacked = _reshape_bits(had_12, original_size=12)
149-
return torch.tensor(had_12_unpacked)
150-
151-
152-
def _get_had20() -> torch.Tensor:
153-
# fmt: off
154-
had_20 = numpy.array([128, 0, 13, 133, 121, 236, 43, 203, 97, 94, 155, 10, 252,
155-
216, 87, 230, 194, 191, 54, 21, 249, 176, 171, 205, 133, 222, 108, 42, 243,
156-
97, 215, 155, 10, 188, 216, 149, 230, 200, 175, 54, 133, 121, 188, 43,
157-
205, 225, 94, 107, 10, 243], dtype=numpy.uint8)
158-
# fmt: on
159-
# TODO: just unpack during apply
160-
had_20_unpacked = _reshape_bits(had_20, original_size=20)
161-
return torch.tensor(had_20_unpacked)
Binary file not shown.

tests/test_transform/utils/test_hadamard.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,11 @@
1717
import pytest
1818
import torch
1919
from compressed_tensors.transform.utils.hadamard import (
20-
_get_had12,
21-
_get_had20,
2220
deterministic_hadamard_matrix,
2321
random_hadamard_matrix,
2422
)
2523

2624

27-
@pytest.mark.parametrize(
28-
"had_func",
29-
[
30-
_get_had12,
31-
_get_had20,
32-
],
33-
)
34-
def test_packed_hadamard_compliant(had_func):
35-
had_matrix = had_func()
36-
size = had_matrix.size(0)
37-
# HH.T == nI
38-
product = had_matrix @ had_matrix.T
39-
assert torch.equal(product, size * torch.eye(size))
40-
41-
4225
@pytest.mark.parametrize(
4326
"size",
4427
[4096, 2048],

0 commit comments

Comments
 (0)