Skip to content

Commit 08ba523

Browse files
committed
high precision hadamard
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 84386fa commit 08ba523

File tree

1 file changed

+16
-3
lines changed
  • src/compressed_tensors/transform/utils

1 file changed

+16
-3
lines changed

src/compressed_tensors/transform/utils/matrix.py

Lines changed: 16 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,9 @@
2121
__all__ = ["get_transform_size", "apply_transform_weight"]
2222

2323

24+
TRANSFORM_PRECISION = torch.float64
25+
26+
2427
def get_transform_size(
2528
module: torch.nn.Module,
2629
location: TransformLocation,
@@ -77,14 +80,24 @@ def apply_transform_weight(
7780
The transform has to account for Linear's transposed weights
7881
:return: value after weight has been applied
7982
"""
83+
# get function used to apply transform
8084
fn, axis = _get_transform_method(module_type, location)
8185

82-
assert weight.shape[0] == weight.shape[1]
86+
# reshape for head_dim
8387
head_dim = weight.shape[0]
8488
num_heads = value.shape[axis] // head_dim
85-
8689
value = value.unflatten(axis, (num_heads, head_dim))
87-
value = fn(weight, value)
90+
91+
# cast to transform precision
92+
value_dtype = value.dtype
93+
94+
# apply transform
95+
value = fn(weight.to(TRANSFORM_PRECISION), value.to(TRANSFORM_PRECISION))
96+
97+
# [undo] cast to transform precision
98+
value = value.to(value_dtype)
99+
100+
# [undo] reshape for head_dim
88101
value = value.flatten(axis - 1, axis)
89102

90103
return value

0 commit comments

Comments (0)