Skip to content

Commit f061db9

Browse files
committed
add docstrings, cleanup
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 67675c3 commit f061db9

File tree

1 file changed

+37
-25
lines changed

1 file changed

+37
-25
lines changed

src/compressed_tensors/transform/utils/hadamard.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,31 @@
1313
# limitations under the License.
1414

1515
import math
16+
import os
1617
from typing import Optional, Tuple
1718

1819
import numpy
1920
import torch
21+
from safetensors import safe_open
22+
23+
24+
REPO_PATH = os.path.join(os.path.dirname(__file__), "hadamards.safetensors")
2025

2126

2227
__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"]
2328

24-
# adapted from:
25-
# https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py
29+
30+
# note that hadamard matrix multiplication can be accelerated using a library such as
31+
# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master
32+
33+
2634
def deterministic_hadamard_matrix(size: int) -> torch.Tensor:
2735
"""
2836
Construct an n-by-n Hadamard matrix, using Sylvester's construction.
2937
`n` must be a power of 2.
3038
39+
Adapted from https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py # noqa: E501
40+
3141
:param size: order of the matrix, must be a power of 2
3242
:return: hadamard matrix of size `size`
3343
"""
@@ -41,20 +51,12 @@ def deterministic_hadamard_matrix(size: int) -> torch.Tensor:
4151
H = numpy.array([[1]], dtype=int)
4252

4353
# Sylvester's construction
44-
for i in range(0, log2):
54+
for _ in range(0, log2):
4555
H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H))))
4656

4757
return torch.from_numpy(H / math.sqrt(size))
4858

4959

50-
# adapted from:
51-
# https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py
52-
53-
# TODO: the following library exists for online rotations and should be considered
54-
# in the future:
55-
# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master
56-
57-
5860
def random_hadamard_matrix(
5961
size: int, gen: Optional[torch.Generator] = None
6062
) -> torch.Tensor:
@@ -63,6 +65,8 @@ def random_hadamard_matrix(
6365
See https://cornell-relaxml.github.io/quip-sharp/ ,
6466
Section "Randomized Hadamard Transformation"
6567
68+
Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py # noqa: E501
69+
6670
:param size: The dimension of the hadamard matrix
6771
:param gen: Optional generator for random values
6872
:return: randomly generated hadamard matrix
@@ -74,31 +78,39 @@ def random_hadamard_matrix(
7478
return _matmul_hadU(Q) / math.sqrt(size)
7579

7680

77-
def _get_hadK(n: int) -> Tuple[torch.Tensor, int]:
78-
import os
81+
def _get_known_hadamard(n: int, file_path: str = REPO_PATH) -> Optional[torch.Tensor]:
82+
"""
83+
Fetch a known hadamard matrix whose size evenly divides `n` from the hadamard repo path, if one exists
7984
80-
from safetensors import safe_open
85+
Note: This function reopens the safetensors file every time it is called.
86+
This is inefficient, but inconsequential because hadamards are typically
87+
cached by size through the factory that produced them. This is also simpler
88+
than forcing callers to manage the file open context
8189
82-
file_path = os.path.join(os.path.dirname(__file__), "hadamards.safetensors")
90+
:param n: target size; the returned known hadamard matrix has a size that evenly divides `n`
91+
:return: a known hadamard matrix whose size evenly divides `n` if one exists, else None
92+
"""
8393
with safe_open(file_path, framework="pt", device="cpu") as file:
8494
for divisor in file.keys():
8595
if n % int(divisor) == 0:
86-
return file.get_tensor(str(divisor)), int(divisor)
96+
return file.get_tensor(divisor)
97+
98+
return None
8799

88-
else:
89-
assert _is_pow2(n)
90-
return None, 1
91100

101+
def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
102+
size = X.shape[-1]
92103

93-
def _matmul_hadU(X) -> torch.Tensor:
94-
n = X.shape[-1]
95104
# Check if we have the determined hadamard matrix
96-
hadK, K = _get_hadK(n)
105+
hadK = _get_known_hadamard(size)
106+
K = hadK.size(0) if hadK is not None else 1
107+
if hadK is None and not _is_pow2(size):
108+
raise ValueError(f"Cannot construct random hadamard matrix of size {size}")
109+
110+
# For cases when hadK is not predetermined, determine hadamard matrix
97111
# Reshape diag matrix with randomized -1/+1
98-
input = X.clone().view(-1, n, 1)
112+
input = X.clone().view(-1, size, 1)
99113
output = input.clone()
100-
101-
# for cases when hadK is not predetermined, determine hadamard matrix
102114
while input.shape[1] > K:
103115
input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
104116
output = output.view(input.shape)

0 commit comments

Comments
 (0)