@@ -74,29 +74,26 @@ def random_hadamard_matrix(
74
74
return _matmul_hadU (Q ) / math .sqrt (size )
75
75
76
76
77
- def _get_hadK (n : int , transpose : bool = False ) -> Tuple [torch .Tensor , int ]:
78
- # NOTE: we can easily extend the list of supported shapes/sizes
79
- # by adding to these methods
80
- hadK , K = None , None
81
- if n % 20 == 0 :
82
- assert _is_pow2 (n // 20 )
83
- K = 20
84
- hadK = _get_had20 ().T if transpose else _get_had20 ()
85
- elif n % 12 == 0 :
86
- assert _is_pow2 (n // 12 )
87
- K = 12
88
- hadK = _get_had12 ().T if transpose else _get_had12 ()
89
- else :
90
- assert _is_pow2 (n )
91
- K = 1
92
-
93
- return hadK , K
94
-
95
-
96
- def _matmul_hadU (X , transpose = False ) -> torch .Tensor :
77
+ def _get_hadK (n : int ) -> Tuple [torch .Tensor , int ]:
78
+ import os
79
+
80
+ from safetensors import safe_open
81
+
82
+ file_path = os .path .join (os .path .dirname (__file__ ), "hadamards.safetensors" )
83
+ with safe_open (file_path , framework = "pt" , device = "cpu" ) as file :
84
+ for divisor in file .keys ():
85
+ if n % int (divisor ) == 0 :
86
+ return file .get_tensor (str (divisor )), int (divisor )
87
+
88
+ else :
89
+ assert _is_pow2 (n )
90
+ return None , 1
91
+
92
+
93
+ def _matmul_hadU (X ) -> torch .Tensor :
97
94
n = X .shape [- 1 ]
98
95
# Check if we have the determined hadamard matrix
99
- hadK , K = _get_hadK (n , transpose )
96
+ hadK , K = _get_hadK (n )
100
97
# Reshape diag matrix with randomized -1/+1
101
98
input = X .clone ().view (- 1 , n , 1 )
102
99
output = input .clone ()
@@ -129,33 +126,3 @@ def _matmul_hadU(X, transpose=False) -> torch.Tensor:
129
126
130
127
def _is_pow2 (n : int ) -> bool :
131
128
return (n & (n - 1 ) == 0 ) and (n > 0 )
132
-
133
-
134
- def _reshape_bits (packed_bits : numpy .ndarray , original_size : int ) -> numpy .ndarray :
135
- had_unpacked = numpy .unpackbits (packed_bits )
136
- had_unpacked = [1 if x == 1 else - 1 for x in had_unpacked ]
137
- had_unpacked = numpy .array (had_unpacked ).reshape ((original_size , original_size ))
138
- return had_unpacked
139
-
140
-
141
- # http://www.neilsloane.com/hadamard/index.html
142
- def _get_had12 () -> torch .Tensor :
143
- # fmt: off
144
- had_12 = numpy .array ([128 , 13 , 29 , 232 , 235 , 71 , 218 ,
145
- 62 , 209 , 246 , 139 , 180 , 157 , 168 , 237 , 199 , 106 , 59 ], dtype = numpy .uint8 )
146
- # fmt: on
147
- # TODO: just unpack during apply
148
- had_12_unpacked = _reshape_bits (had_12 , original_size = 12 )
149
- return torch .tensor (had_12_unpacked )
150
-
151
-
152
- def _get_had20 () -> torch .Tensor :
153
- # fmt: off
154
- had_20 = numpy .array ([128 , 0 , 13 , 133 , 121 , 236 , 43 , 203 , 97 , 94 , 155 , 10 , 252 ,
155
- 216 , 87 , 230 , 194 , 191 , 54 , 21 , 249 , 176 , 171 , 205 , 133 , 222 , 108 , 42 , 243 ,
156
- 97 , 215 , 155 , 10 , 188 , 216 , 149 , 230 , 200 , 175 , 54 , 133 , 121 , 188 , 43 ,
157
- 205 , 225 , 94 , 107 , 10 , 243 ], dtype = numpy .uint8 )
158
- # fmt: on
159
- # TODO: just unpack during apply
160
- had_20_unpacked = _reshape_bits (had_20 , original_size = 20 )
161
- return torch .tensor (had_20_unpacked )
0 commit comments