Commit 1ba6195

solidify dtype, add gpu tests
Signed-off-by: Kyle Sayers <[email protected]>
1 parent cde1066 commit 1ba6195

4 files changed: +40 −29 lines

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         return HadamardTransform(weight, args)

     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        data = deterministic_hadamard_matrix(size)
+        data = deterministic_hadamard_matrix(size, dtype=dtype)
         data = data.to(dtype=dtype, device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)

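The change above builds the Hadamard weight directly in the target dtype instead of constructing it in a fixed integer dtype and casting afterwards. A minimal sketch of the resulting pattern as a free function (create_hadamard_weight is a hypothetical name, not the factory's API; deterministic_hadamard_matrix is the repository utility patched below):

    import torch
    from torch.nn import Parameter
    from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix

    def create_hadamard_weight(size: int, dtype: torch.dtype, device: torch.device) -> Parameter:
        # the matrix is built in `dtype` from the start, so the .to() below is
        # only a device move rather than a lossy cast from int32;
        # requires_grad=False stands in for the factory's scheme setting
        data = deterministic_hadamard_matrix(size, dtype=dtype)
        data = data.to(dtype=dtype, device=device)
        return Parameter(data, requires_grad=False)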

src/compressed_tensors/transform/factory/random_hadamard.py

Lines changed: 1 addition & 1 deletion
@@ -29,6 +29,6 @@ class RandomHadamardFactory(HadamardFactory):
     """

     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        data = random_hadamard_matrix(size, self.generator)
+        data = random_hadamard_matrix(size, dtype=dtype, gen=self.generator)
         data = data.to(dtype=dtype, device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
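
Because the new dtype parameter sits between size and the generator in random_hadamard_matrix's signature, the generator must now be passed by keyword, as the factory does here. A short usage sketch under that signature (seed and size are illustrative):

    import torch
    from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix

    gen = torch.Generator().manual_seed(0)  # illustrative seed
    # returns a (4096, 4096) randomized Hadamard matrix built in bfloat16
    weight = random_hadamard_matrix(4096, dtype=torch.bfloat16, gen=gen)
    assert weight.dtype == torch.bfloat16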

src/compressed_tensors/transform/utils/hadamard.py

Lines changed: 15 additions & 9 deletions
@@ -21,7 +21,6 @@


 REPO_PATH = os.path.join(os.path.dirname(__file__), "hadamards.safetensors")
-DTYPE = torch.int32


 __all__ = ["random_hadamard_matrix", "deterministic_hadamard_matrix", "is_pow2"]
@@ -31,7 +30,9 @@
 # https://github.com/Dao-AILab/fast-hadamard-transform/tree/master


-def deterministic_hadamard_matrix(size: int) -> torch.Tensor:
+def deterministic_hadamard_matrix(
+    size: int, dtype: torch.dtype = torch.bfloat16
+) -> torch.Tensor:
     """
     Construct an n-by-n Hadamard matrix, using Sylvester's construction.
     `n` must be a power of 2.
@@ -44,11 +45,11 @@ def deterministic_hadamard_matrix(size: int) -> torch.Tensor:
     if size <= 0:
         raise ValueError("Cannot construct deterministic hadamard of size <= 0")

-    log2 = int(math.log(size, 2))
+    log2 = int(math.log2(size))
     if size != 2**log2:
         raise ValueError("Cannot construct deterministic hadamard of size != 2^n")

-    H = torch.tensor([[1]], dtype=DTYPE)
+    H = torch.tensor([[1]], dtype=dtype)

     # Sylvester's construction
     for _ in range(0, log2):
@@ -58,7 +59,9 @@ def deterministic_hadamard_matrix(size: int) -> torch.Tensor:


 def random_hadamard_matrix(
-    size: int, gen: Optional[torch.Generator] = None
+    size: int,
+    dtype: torch.dtype = torch.bfloat16,
+    gen: Optional[torch.Generator] = None,
 ) -> torch.Tensor:
     """
     Produces a randomly generated Hadamard matrix.
@@ -72,7 +75,7 @@ def random_hadamard_matrix(
     :return: randomly generated hadamard matrix
     """
     # Benefits: support other shapes / non powers of 2, support randomization
-    Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=DTYPE)
+    Q = torch.randint(low=0, high=2, size=(size,), generator=gen, dtype=dtype)
     Q = Q * 2 - 1
     Q = torch.diag(Q)
     return _matmul_hadU(Q) / math.sqrt(size)
@@ -82,7 +85,9 @@ def is_pow2(n: int) -> bool:
     return (n & (n - 1) == 0) and (n > 0)


-def _get_known_divisor(n: int, file_path: str = REPO_PATH) -> Optional[torch.Tensor]:
+def _get_known_divisor(
+    n: int, dtype: torch.dtype, file_path: str = REPO_PATH
+) -> Optional[torch.Tensor]:
     """
     Fetch a known hadamard matrix from the given file path. The returned matrix will
     be of of size `k` such that `n` divides `d` and `n / d` is a power of two. Return
@@ -100,16 +105,17 @@ def _get_known_divisor(n: int, file_path: str = REPO_PATH) -> Optional[torch.Tensor]:
         divisors = sorted([int(key) for key in file.keys()], reverse=True)
         for divisor in divisors:
             if n % divisor == 0 and is_pow2(n // divisor):
-                return file.get_tensor(str(divisor)).to(dtype=DTYPE)
+                return file.get_tensor(str(divisor)).to(dtype=dtype)

     return None


 def _matmul_hadU(X: torch.Tensor) -> torch.Tensor:
     size = X.shape[-1]
+    dtype = X.dtype

     # Check if we have the determined hadamard matrix
-    hadK = _get_known_divisor(size)
+    hadK = _get_known_divisor(size, dtype)
     if hadK is None:
         raise ValueError(f"Cannot construct random hadamard matrix of size {size}")
     K = hadK.size(0)
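
With this change the dtype is threaded end to end: Sylvester's construction starts from torch.tensor([[1]], dtype=dtype), the randomized path draws its ±1 diagonal in dtype, and the known divisor matrices fetched from hadamards.safetensors are cast to the same dtype in place of the removed int32 DTYPE constant. A quick sanity sketch mirroring the compliance tests below (size and tolerance are illustrative; float32 gives a tight numerical check):

    import torch
    from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix

    # size must be a power of two for Sylvester's construction
    H = deterministic_hadamard_matrix(1024, dtype=torch.float32)
    # (H / sqrt(n))(H.T / sqrt(n)) == I, per the tests in this commit
    product = H @ H.T
    assert torch.allclose(product, torch.eye(1024, dtype=product.dtype), atol=1e-5)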

tests/test_transform/utils/test_hadamard.py

Lines changed: 23 additions & 18 deletions
@@ -20,6 +20,7 @@
     is_pow2,
     random_hadamard_matrix,
 )
+from tests.testing_utils import requires_gpu


 _sizes_to_test = [
@@ -28,27 +29,29 @@
     1280,  # qwen_2_5_vl vision
     1600,  # gpt2 xl
     2048,  # gpt3 small
-    # 3584,  # qwen_2_5_vl
-    # 3840,  # qwen_2_5_vl vision qkv
-    # 4096,  # llama3
-    # 14336,  # llama3 intermediate
-    # 18944,  # qwen_2_5_vl intermediate
+    3584,  # qwen_2_5_vl
+    3840,  # qwen_2_5_vl vision qkv
+    4096,  # llama3
+    14336,  # llama3 intermediate
+    18944,  # qwen_2_5_vl intermediate
 ]


+@requires_gpu
 @pytest.mark.parametrize("size", _sizes_to_test)
 def test_random_hadamard_matrix_compliant(size):
     # (H / sqrt(n))(H.T / sqrt(n)) == I
-    had_matrix = random_hadamard_matrix(size)
-    product = torch.round(had_matrix @ had_matrix.T)
-    assert torch.allclose(product, torch.eye(size, dtype=product.dtype), atol=1e-5)
+    with torch.device("cuda"):
+        had_matrix = random_hadamard_matrix(size)
+        product = torch.round(had_matrix @ had_matrix.T)
+        assert torch.allclose(product, torch.eye(size, dtype=product.dtype), atol=1e-5)


 def test_random_hadamard_generator():
     # check that generation is deterministic with a seed
     generator = torch.Generator().manual_seed(42)
-    one = random_hadamard_matrix(2048, generator)
-    two = random_hadamard_matrix(2048, generator)
+    one = random_hadamard_matrix(2048, gen=generator)
+    two = random_hadamard_matrix(2048, gen=generator)

     one_true = torch.tensor(
         [
@@ -69,14 +72,16 @@ def test_random_hadamard_generator():
     assert torch.all(two[:3, :3].sign() == two_true.sign())


+@requires_gpu
 @pytest.mark.parametrize("size", _sizes_to_test)
 def test_deterministic_hadamard_compliant(size):
-    if not is_pow2(size):
-        with pytest.raises(ValueError):
-            had_matrix = deterministic_hadamard_matrix(size)
-        return
+    with torch.device("cuda"):
+        if not is_pow2(size):
+            with pytest.raises(ValueError):
+                had_matrix = deterministic_hadamard_matrix(size)
+            return

-    # (H / sqrt(n))(H.T / sqrt(n)) == I
-    had_matrix = deterministic_hadamard_matrix(size)
-    product = had_matrix @ had_matrix.T
-    assert torch.allclose(product, torch.eye(size, dtype=product.dtype), atol=1e-5)
+        # (H / sqrt(n))(H.T / sqrt(n)) == I
+        had_matrix = deterministic_hadamard_matrix(size)
+        product = had_matrix @ had_matrix.T
+        assert torch.allclose(product, torch.eye(size, dtype=product.dtype), atol=1e-5)
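
These tests enable the previously commented-out large sizes and wrap the test bodies in torch.device("cuda") as a context manager (PyTorch 2.0+), so factory calls such as torch.randint and torch.eye inside the utilities allocate on the GPU by default. The requires_gpu marker is imported from tests.testing_utils and its definition is not shown in this diff; a typical implementation of such a guard looks like the following sketch (an assumption, not the repository's actual code):

    import pytest
    import torch

    # hypothetical stand-in for tests.testing_utils.requires_gpu:
    # skip the decorated test when no CUDA device is available
    requires_gpu = pytest.mark.skipif(
        not torch.cuda.is_available(), reason="requires a CUDA-capable GPU"
    )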
