Skip to content

Commit cde1066

Browse files
committed
remove numpy, add tests
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 4a84ce1 commit cde1066

File tree

2 files changed

+54
-43
lines changed

2 files changed

+54
-43
lines changed

src/compressed_tensors/transform/utils/hadamard.py

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,17 @@
1414

1515
import math
1616
import os
17-
from typing import Optional, Tuple
17+
from typing import Optional
1818

19-
import numpy
2019
import torch
2120
from safetensors import safe_open
2221

2322

2423
REPO_PATH = os.path.join(os.path.dirname(__file__), "hadamards.safetensors")
24+
DTYPE = torch.int32
2525

2626

27-
__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"]
27+
__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix", "is_pow2"]
2828

2929

3030
# note that hadamard matrix multiplication can be accelerated using a library such as
@@ -48,13 +48,13 @@ def deterministic_hadamard_matrix(size: int) -> torch.Tensor:
4848
if size != 2**log2:
4949
raise ValueError("Cannot construct deterministic hadamard of size != 2^n")
5050

51-
H = numpy.array([[1]], dtype=int)
51+
H = torch.tensor([[1]], dtype=DTYPE)
5252

5353
# Sylvester's construction
5454
for _ in range(0, log2):
55-
H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H))))
55+
H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))
5656

57-
return torch.from_numpy(H / math.sqrt(size))
57+
return H / math.sqrt(size)
5858

5959

6060
def random_hadamard_matrix(
@@ -72,15 +72,21 @@ def random_hadamard_matrix(
7272
:return: randomly generated hadamard matrix
7373
"""
7474
# Benefits: support other shapes / non powers of 2, support randomization
75-
Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=torch.float64)
75+
Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=DTYPE)
7676
Q = Q * 2 - 1
7777
Q = torch.diag(Q)
7878
return _matmul_hadU(Q) / math.sqrt(size)
7979

8080

81-
def _get_known_hadamard(n: int, file_path: str = REPO_PATH) -> Optional[torch.Tensor]:
81+
def is_pow2(n: int) -> bool:
82+
return (n & (n - 1) == 0) and (n > 0)
83+
84+
85+
def _get_known_divisor(n: int, file_path: str = REPO_PATH) -> Optional[torch.Tensor]:
8286
"""
83-
Fetch a known hadamard matrix of size `n` from hadamard repo path if it exists
87+
Fetch a known hadamard matrix from the given file path. The returned matrix will
88+
be of of size `k` such that `n` divides `d` and `n / d` is a power of two. Return
89+
None if no such matrix exists.
8490
8591
Note: This function reopens the safetensors file every time it is called.
8692
This is inefficient, but inconsequential because hadamards are typically
@@ -91,9 +97,10 @@ def _get_known_hadamard(n: int, file_path: str = REPO_PATH) -> Optional[torch.Te
9197
:return: a known hadamard matrix of size `n` if one exists, else None
9298
"""
9399
with safe_open(file_path, framework="pt", device="cpu") as file:
94-
for divisor in file.keys():
95-
if n % int(divisor) == 0:
96-
return file.get_tensor(divisor)
100+
divisors = sorted([int(key) for key in file.keys()], reverse=True)
101+
for divisor in divisors:
102+
if n % divisor == 0 and is_pow2(n // divisor):
103+
return file.get_tensor(str(divisor)).to(dtype=DTYPE)
97104

98105
return None
99106

@@ -102,12 +109,11 @@ def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
102109
size = X.shape[-1]
103110

104111
# Check if we have the determined hadamard matrix
105-
hadK = _get_known_hadamard(size)
106-
K = hadK.size(0) if hadK is not None else 1
107-
if hadK is None and not _is_pow2(size):
112+
hadK = _get_known_divisor(size)
113+
if hadK is None:
108114
raise ValueError(f"Cannot construct random hadamard matrix of size {size}")
115+
K = hadK.size(0)
109116

110-
# For cases when hadK is not predetermined, determine hadamard matrix
111117
# Reshape diag matrix with randomized -1/+1
112118
input = X.clone().view(-1, size, 1)
113119
output = input.clone()
@@ -120,21 +126,11 @@ def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
120126
(input, output) = (output, input)
121127
del output
122128

123-
# K == 1 when hadK is None; this happens when the size dim (n)
124-
# is not comaptible with any of the maintained hadamard matrices
125-
126-
if K > 1:
127-
# Do not explicitly repeat - OOM
128-
# input = torch.bmm(
129-
# hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
130-
# Use bcast instead
131-
132-
# for cases when hadK is pre-determined
133-
input = hadK.view(1, K, K).to(input) @ input
129+
# Do not explicitly repeat - OOM
130+
# input = torch.bmm(
131+
# hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
132+
# Use bcast instead
133+
input = hadK.view(1, K, K).to(input) @ input
134134

135135
# normalize
136136
return input.view(X.shape)
137-
138-
139-
def _is_pow2(n: int) -> bool:
140-
return (n & (n - 1) == 0) and (n > 0)

tests/test_transform/utils/test_hadamard.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,39 @@
1313
# limitations under the License.
1414

1515

16-
import numpy
1716
import pytest
1817
import torch
1918
from compressed_tensors.transform.utils.hadamard import (
2019
deterministic_hadamard_matrix,
20+
is_pow2,
2121
random_hadamard_matrix,
2222
)
2323

2424

25-
@pytest.mark.parametrize(
26-
"size",
27-
[4096, 2048],
28-
)
25+
_sizes_to_test = [
26+
768, # gpt2 small
27+
1024, # gpt2 medium
28+
1280, # qwen_2_5_vl vision
29+
1600, # gpt2 xl
30+
2048, # gpt3 small
31+
# 3584, # qwen_2_5_vl
32+
# 3840, # qwen_2_5_vl vision qkv
33+
# 4096, # llama3
34+
# 14336, # llama3 intermediate
35+
# 18944, # qwen_2_5_vl intermediate
36+
]
37+
38+
39+
@pytest.mark.parametrize("size", _sizes_to_test)
2940
def test_random_hadamard_matrix_compliant(size):
41+
# (H / sqrt(n))(H.T / sqrt(n)) == I
3042
had_matrix = random_hadamard_matrix(size)
3143
product = torch.round(had_matrix @ had_matrix.T)
32-
assert torch.equal(product, torch.eye(size))
44+
assert torch.allclose(product, torch.eye(size, dtype=product.dtype), atol=1e-5)
3345

3446

3547
def test_random_hadamard_generator():
48+
# check that generation is deterministic with a seed
3649
generator = torch.Generator().manual_seed(42)
3750
one = random_hadamard_matrix(2048, generator)
3851
two = random_hadamard_matrix(2048, generator)
@@ -56,12 +69,14 @@ def test_random_hadamard_generator():
5669
assert torch.all(two[:3, :3].sign() == two_true.sign())
5770

5871

59-
@pytest.mark.parametrize(
60-
"size",
61-
[1024],
62-
)
72+
@pytest.mark.parametrize("size", _sizes_to_test)
6373
def test_deterministic_hadamard_compliant(size):
64-
had_matrix = deterministic_hadamard_matrix(size)
74+
if not is_pow2(size):
75+
with pytest.raises(ValueError):
76+
had_matrix = deterministic_hadamard_matrix(size)
77+
return
78+
6579
# (H / sqrt(n))(H.T / sqrt(n)) == I
80+
had_matrix = deterministic_hadamard_matrix(size)
6681
product = had_matrix @ had_matrix.T
67-
assert numpy.array_equal(product, numpy.eye(size))
82+
assert torch.allclose(product, torch.eye(size, dtype=product.dtype), atol=1e-5)

0 commit comments

Comments
 (0)