21
21
22
22
23
23
REPO_PATH = os .path .join (os .path .dirname (__file__ ), "hadamards.safetensors" )
24
- DTYPE = torch .int32
25
24
26
25
27
26
__all__ = ["random_hadamard_matrix" , "deterministic_hadamard_matrix" , "is_pow2" ]
31
30
# https://github.com/Dao-AILab/fast-hadamard-transform/tree/master
32
31
33
32
34
- def deterministic_hadamard_matrix (size : int ) -> torch .Tensor :
33
+ def deterministic_hadamard_matrix (
34
+ size : int , dtype : torch .dtype = torch .bfloat16
35
+ ) -> torch .Tensor :
35
36
"""
36
37
Construct an n-by-n Hadamard matrix, using Sylvester's construction.
37
38
`n` must be a power of 2.
@@ -44,11 +45,11 @@ def deterministic_hadamard_matrix(size: int) -> torch.Tensor:
44
45
if size <= 0 :
45
46
raise ValueError ("Cannot construct deterministic hadamard of size <= 0" )
46
47
47
- log2 = int (math .log (size , 2 ))
48
+ log2 = int (math .log2 (size ))
48
49
if size != 2 ** log2 :
49
50
raise ValueError ("Cannot construct deterministic hadamard of size != 2^n" )
50
51
51
- H = torch .tensor ([[1 ]], dtype = DTYPE )
52
+ H = torch .tensor ([[1 ]], dtype = dtype )
52
53
53
54
# Sylvester's construction
54
55
for _ in range (0 , log2 ):
@@ -58,7 +59,9 @@ def deterministic_hadamard_matrix(size: int) -> torch.Tensor:
58
59
59
60
60
61
def random_hadamard_matrix (
61
- size : int , gen : Optional [torch .Generator ] = None
62
+ size : int ,
63
+ dtype : torch .dtype = torch .bfloat16 ,
64
+ gen : Optional [torch .Generator ] = None ,
62
65
) -> torch .Tensor :
63
66
"""
64
67
Produces a randomly generated Hadamard matrix.
@@ -72,7 +75,7 @@ def random_hadamard_matrix(
72
75
:return: randomly generated hadamard matrix
73
76
"""
74
77
# Benefits: support other shapes / non powers of 2, support randomization
75
- Q = torch .randint (low = 0 , high = 2 , size = (size ,), generator = gen , dtype = DTYPE )
78
+ Q = torch .randint (low = 0 , high = 2 , size = (size ,), generator = gen , dtype = dtype )
76
79
Q = Q * 2 - 1
77
80
Q = torch .diag (Q )
78
81
return _matmul_hadU (Q ) / math .sqrt (size )
@@ -82,7 +85,9 @@ def is_pow2(n: int) -> bool:
82
85
return (n & (n - 1 ) == 0 ) and (n > 0 )
83
86
84
87
85
- def _get_known_divisor (n : int , file_path : str = REPO_PATH ) -> Optional [torch .Tensor ]:
88
+ def _get_known_divisor (
89
+ n : int , dtype : torch .dtype , file_path : str = REPO_PATH
90
+ ) -> Optional [torch .Tensor ]:
86
91
"""
87
92
Fetch a known hadamard matrix from the given file path. The returned matrix will
88
93
be of size `k` such that `k` divides `n` and `n / k` is a power of two. Return
@@ -100,16 +105,17 @@ def _get_known_divisor(n: int, file_path: str = REPO_PATH) -> Optional[torch.Ten
100
105
divisors = sorted ([int (key ) for key in file .keys ()], reverse = True )
101
106
for divisor in divisors :
102
107
if n % divisor == 0 and is_pow2 (n // divisor ):
103
- return file .get_tensor (str (divisor )).to (dtype = DTYPE )
108
+ return file .get_tensor (str (divisor )).to (dtype = dtype )
104
109
105
110
return None
106
111
107
112
108
113
def _matmul_hadU (X : torch .Tensor ) -> torch .Tensor :
109
114
size = X .shape [- 1 ]
115
+ dtype = X .dtype
110
116
111
117
# Check if we have the determined hadamard matrix
112
- hadK = _get_known_divisor (size )
118
+ hadK = _get_known_divisor (size , dtype )
113
119
if hadK is None :
114
120
raise ValueError (f"Cannot construct random hadamard matrix of size { size } " )
115
121
K = hadK .size (0 )
0 commit comments