import math
import os
- from typing import Optional, Tuple
+ from typing import Optional

- import numpy
import torch
from safetensors import safe_open


REPO_PATH = os.path.join(os.path.dirname(__file__), "hadamards.safetensors")
+ DTYPE = torch.int32


- __all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix"]
+ __all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix", "is_pow2"]


# note that hadamard matrix multiplication can be accelerated using a library such as
@@ -48,13 +48,13 @@ def deterministic_hadamard_matrix(size: int) -> torch.Tensor:
    if size != 2**log2:
        raise ValueError("Cannot construct deterministic hadamard of size != 2^n")

-     H = numpy.array([[1]], dtype=int)
+     H = torch.tensor([[1]], dtype=DTYPE)

    # Sylvester's construction
    for _ in range(0, log2):
-         H = numpy.vstack((numpy.hstack((H, H)), numpy.hstack((H, -H))))
+         H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))

-     return torch.from_numpy(H / math.sqrt(size))
+     return H / math.sqrt(size)


def random_hadamard_matrix(
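Sylvester's construction doubles the matrix at every step, so the loop above yields a 2^n x 2^n matrix with mutually orthogonal rows. A standalone sanity check that mirrors the new torch-only loop (a sketch; it does not import the module itself):

```python
import torch

# build a 4x4 hadamard exactly as the loop in the diff does (2 doublings)
H = torch.tensor([[1]], dtype=torch.int32)
for _ in range(2):
    H = torch.vstack((torch.hstack((H, H)), torch.hstack((H, -H))))

# rows of a hadamard matrix are orthogonal: H @ H.T == n * I
Hf = H.to(torch.float32)  # cast first; integer matmul is not supported on all backends
n = H.shape[0]
assert torch.equal(Hf @ Hf.T, n * torch.eye(n))
```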
@@ -72,15 +72,21 @@ def random_hadamard_matrix(
    :return: randomly generated hadamard matrix
    """
    # Benefits: support other shapes / non powers of 2, support randomization
-     Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=torch.float64)
+     Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=DTYPE)
    Q = Q * 2 - 1
    Q = torch.diag(Q)
    return _matmul_hadU(Q) / math.sqrt(size)


- def _get_known_hadamard(n: int, file_path: str = REPO_PATH) -> Optional[torch.Tensor]:
+ def is_pow2(n: int) -> bool:
+     return (n & (n - 1) == 0) and (n > 0)
+
+
+ def _get_known_divisor(n: int, file_path: str = REPO_PATH) -> Optional[torch.Tensor]:
    """
-     Fetch a known hadamard matrix of size `n` from hadamard repo path if it exists
+     Fetch a known hadamard matrix from the given file path. The returned matrix will
+     be of size `d` such that `d` divides `n` and `n / d` is a power of two. Return
+     None if no such matrix exists.

    Note: This function reopens the safetensors file every time it is called.
        This is inefficient, but inconsequential because hadamards are typically
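Two small pieces introduced here are easy to check in isolation: the `randint(...) * 2 - 1` mapping that produces a random ±1 sign diagonal, and the `is_pow2` bit trick, where `n & (n - 1)` clears the lowest set bit and is therefore zero exactly for positive powers of two. A quick standalone sketch:

```python
import torch

# {0, 1} -> {-1, +1}: the random sign diagonal used to randomize the hadamard
gen = torch.Generator().manual_seed(0)
q = torch.randint(low=0, high=2, size=(8,), generator=gen, dtype=torch.int32) * 2 - 1
assert set(q.tolist()) <= {-1, 1}

# the power-of-two bit trick used by is_pow2
def is_pow2(n: int) -> bool:
    return (n & (n - 1) == 0) and (n > 0)

assert [n for n in range(1, 20) if is_pow2(n)] == [1, 2, 4, 8, 16]
assert not is_pow2(0)
```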
@@ -91,9 +97,10 @@ def _get_known_hadamard(n: int, file_path: str = REPO_PATH) -> Optional[torch.Te
    :return: a known hadamard matrix of size `n` if one exists, else None
    """
    with safe_open(file_path, framework="pt", device="cpu") as file:
-         for divisor in file.keys():
-             if n % int(divisor) == 0:
-                 return file.get_tensor(divisor)
+         divisors = sorted([int(key) for key in file.keys()], reverse=True)
+         for divisor in divisors:
+             if n % divisor == 0 and is_pow2(n // divisor):
+                 return file.get_tensor(str(divisor)).to(dtype=DTYPE)

    return None

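The reworked lookup prefers the largest stored size `d` that divides `n` and leaves a power-of-two quotient `n / d`, which is the condition the rest of `_matmul_hadU` relies on. A rough sketch of that selection logic, with a plain list standing in for the keys of `hadamards.safetensors` (the sizes below are illustrative, not the file's actual contents):

```python
def is_pow2(n: int) -> bool:
    return (n & (n - 1) == 0) and (n > 0)

known_sizes = [12, 20, 28, 40]  # hypothetical stored hadamard sizes

def pick_divisor(n: int):
    # largest divisor first, mirroring sorted(..., reverse=True) in the diff
    for d in sorted(known_sizes, reverse=True):
        if n % d == 0 and is_pow2(n // d):
            return d
    return None

assert pick_divisor(40) == 40    # 40 / 40 == 1 is a power of two
assert pick_divisor(24) == 12    # 24 / 12 == 2
assert pick_divisor(56) == 28    # 56 / 28 == 2
assert pick_divisor(36) is None  # 36 / 12 == 3 is not a power of two
```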
@@ -102,12 +109,11 @@ def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
    size = X.shape[-1]

    # Check if we have the determined hadamard matrix
-     hadK = _get_known_hadamard(size)
-     K = hadK.size(0) if hadK is not None else 1
-     if hadK is None and not _is_pow2(size):
+     hadK = _get_known_divisor(size)
+     if hadK is None:
        raise ValueError(f"Cannot construct random hadamard matrix of size {size}")
+     K = hadK.size(0)

-     # For cases when hadK is not predetermined, determine hadamard matrix
    # Reshape diag matrix with randomized -1/+1
    input = X.clone().view(-1, size, 1)
    output = input.clone()
@@ -120,21 +126,11 @@ def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
        (input, output) = (output, input)
    del output

-     # K == 1 when hadK is None; this happens when the size dim (n)
-     # is not comaptible with any of the maintained hadamard matrices
-
-     if K > 1:
-         # Do not explicitly repeat - OOM
-         # input = torch.bmm(
-         #     hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
-         # Use bcast instead
-
-         # for cases when hadK is pre-determined
-         input = hadK.view(1, K, K).to(input) @ input
+     # Do not explicitly repeat - OOM
+     # input = torch.bmm(
+     #     hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
+     # Use bcast instead
+     input = hadK.view(1, K, K).to(input) @ input

    # normalize
    return input.view(X.shape)
-
-
- def _is_pow2(n: int) -> bool:
-     return (n & (n - 1) == 0) and (n > 0)
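The retained comment block is about memory: explicitly repeating `hadK` across the batch materializes `B` copies before the `bmm`, whereas viewing it as `(1, K, K)` lets matmul broadcast over the batch dimension without extra allocation. A small standalone check of that equivalence (shapes chosen arbitrarily for illustration):

```python
import torch

B, K, m = 8, 4, 3        # batch, hadamard block size, columns per block
hadK = torch.randn(K, K)
x = torch.randn(B, K, m)

# explicit repeat: materializes B copies of hadK before the batched matmul
out_repeat = torch.bmm(hadK.repeat(B, 1, 1), x)

# broadcast: a (1, K, K) view is expanded lazily by matmul, no copies made
out_bcast = hadK.view(1, K, K) @ x

assert torch.allclose(out_repeat, out_bcast, atol=1e-6)
```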