# limitations under the License.

import math
+import os
from typing import Optional, Tuple

import numpy
import torch
+from safetensors import safe_open
+
+
+REPO_PATH = os.path.join(os.path.dirname(__file__), "hadamards.safetensors")


__all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"]

-# adapted from:
-# https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py
+
+# note that hadamard matrix multiplication can be accelerated using a library such as
+# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master
+
+
def deterministic_hadamard_matrix(size: int) -> torch.Tensor:
    """
    Construct an n-by-n Hadamard matrix, using Sylvester's construction.
    `n` must be a power of 2.

+    Adapted from https://github.com/scipy/scipy/blob/v1.15.2/scipy/linalg/_special_matrices.py  # noqa: E501
+
    :param size: order of the matrix, must be a power of 2
    :return: hadamard matrix of size `size`
    """
@@ -41,20 +51,12 @@ def deterministic_hadamard_matrix(size: int) -> torch.Tensor:
    H = numpy.array([[1]], dtype=int)

    # Sylvester's construction
-    for i in range(0, log2):
+    for _ in range(0, log2):
        H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H))))

    return torch.from_numpy(H / math.sqrt(size))
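For reference, the construction can be sanity-checked directly. A minimal sketch, assuming `deterministic_hadamard_matrix` is imported from this module:

import torch

# Sylvester doubling from [[1]] yields a 4x4 matrix with entries +-1/sqrt(4)
H = deterministic_hadamard_matrix(4)
# the 1/sqrt(size) scaling makes the result orthonormal: H @ H.T == I
assert torch.allclose(H @ H.T, torch.eye(4, dtype=H.dtype))
assert set(H.abs().unique().tolist()) == {0.5}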


-# adapted from:
-# https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py
-
-# TODO: the following library exists for online rotations and should be considered
-# in the future:
-# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master
-
-
def random_hadamard_matrix(
    size: int, gen: Optional[torch.Generator] = None
) -> torch.Tensor:
@@ -63,6 +65,8 @@ def random_hadamard_matrix(
    See https://cornell-relaxml.github.io/quip-sharp/ ,
    Section "Randomized Hadamard Transformation"

+    Adapted from https://github.com/facebookresearch/SpinQuant/blob/main/utils/hadamard_utils.py  # noqa: E501
+
    :param size: The dimension of the hadamard matrix
    :param gen: Optional generator random values
    :return: randomly generated hadamard matrix
@@ -74,31 +78,39 @@ def random_hadamard_matrix(
    return _matmul_hadU(Q) / math.sqrt(size)


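A brief usage sketch for `random_hadamard_matrix` above, assuming it is imported from this module; a seeded generator makes the random sign flips reproducible, and the result stays orthonormal because the random +-1 signs only flip signs:

import torch

gen = torch.Generator().manual_seed(0)
H = random_hadamard_matrix(64, gen=gen)  # 64 is a power of 2
assert H.shape == (64, 64)
assert torch.allclose(H @ H.T, torch.eye(64, dtype=H.dtype), atol=1e-6)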
-def _get_hadK(n: int) -> Tuple[torch.Tensor, int]:
-    import os
+def _get_known_hadamard(n: int, file_path: str = REPO_PATH) -> Optional[torch.Tensor]:
+    """
+    Fetch a known hadamard matrix whose size evenly divides `n` from the hadamard repo path, if one exists

-    from safetensors import safe_open
+    Note: This function reopens the safetensors file every time it is called.
+    This is inefficient, but inconsequential because hadamards are typically
+    cached by size through the factory that produced them. This is also simpler
+    than forcing callers to manage the file open context.

-    file_path = os.path.join(os.path.dirname(__file__), "hadamards.safetensors")
+    :param n: size which must be divisible by the known hadamard matrix
+    :return: a known hadamard matrix whose size divides `n` if one exists, else None
+    """
    with safe_open(file_path, framework="pt", device="cpu") as file:
        for divisor in file.keys():
            if n % int(divisor) == 0:
-                return file.get_tensor(str(divisor)), int(divisor)
+                return file.get_tensor(divisor)
+
+    return None

-        else:
-            assert _is_pow2(n)
-            return None, 1

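To illustrate the lookup in `_get_known_hadamard`: the first stored size (in key iteration order) that divides `n` is returned, and the caller recovers the block size from the tensor itself. A short sketch; the "20" key used here is purely a hypothetical example of the file's contents:

hadK = _get_known_hadamard(40)  # hypothetical: if a "20" entry exists, 40 % 20 == 0
if hadK is not None:
    K = hadK.size(0)
    assert 40 % K == 0  # guaranteed by the divisor check in the loop above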
+def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
+    size = X.shape[-1]

-def _matmul_hadU(X) -> torch.Tensor:
-    n = X.shape[-1]
    # Check if we have the determined hadamard matrix
-    hadK, K = _get_hadK(n)
+    hadK = _get_known_hadamard(size)
+    K = hadK.size(0) if hadK is not None else 1
+    if hadK is None and not _is_pow2(size):
+        raise ValueError(f"Cannot construct random hadamard matrix of size {size}")
+
+    # For cases when hadK is not predetermined, determine hadamard matrix
    # Reshape diag matrix with randomized -1/+1
-    input = X.clone().view(-1, n, 1)
+    input = X.clone().view(-1, size, 1)
    output = input.clone()
-
-    # for cases when hadK is not predetermined, determine hadamard matrix
    while input.shape[1] > K:
        input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
        output = output.view(input.shape)
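The while loop above is the standard Hadamard butterfly: each pass folds the working dimension in half and replaces each pair (a, b) with (a + b, a - b), which is multiplication by a 2x2 Hadamard block, until only the K x K base case (hadK) remains. A self-contained sketch of that folding idea for the pure power-of-2 case, assuming `deterministic_hadamard_matrix` from this module is in scope; this is illustrative only, not the code in this file:

import math

import torch


def fwht_sketch(x: torch.Tensor) -> torch.Tensor:
    # unnormalized fast Walsh-Hadamard transform over the last dimension;
    # the length must be a power of 2
    n = x.shape[-1]
    y = x.reshape(-1, n).clone()
    h = 1
    while h < n:
        # pair adjacent blocks of length h and write their sums and differences
        y = y.view(-1, n // (2 * h), 2, h)
        a, b = y[:, :, 0, :], y[:, :, 1, :]
        y = torch.stack((a + b, a - b), dim=2).reshape(-1, n)
        h *= 2
    return y.view(x.shape)


# the butterfly plus rescaling matches the explicit matrix product
x = torch.randn(8, dtype=torch.float64)
H = deterministic_hadamard_matrix(8)  # already scaled by 1/sqrt(8)
assert torch.allclose(fwht_sketch(x) / math.sqrt(8), H @ x)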