Skip to content

Commit 073d8d7

Browse files
committed
perform div in 64
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 1a01f0c commit 073d8d7

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def __init__(
9393
self.perm = perm
9494
self.args = args
9595
self.module_type = module_type
96-
self._scale = math.sqrt(weight.size(0))
96+
self._scale = torch.tensor(weight.size(0), dtype=torch.float64).sqrt()
9797

9898
def forward(self, value: Tensor) -> Tensor:
9999
weight = self.weight
@@ -104,6 +104,9 @@ def forward(self, value: Tensor) -> Tensor:
104104
if self.args.inverse:
105105
weight = weight.T
106106

107-
return apply_transform_weight(
107+
tmp = apply_transform_weight(
108108
weight, value, self.args.location, self.module_type
109-
) / self._scale
109+
)
110+
111+
112+
return (tmp.to(torch.float64) / self._scale).to(tmp.dtype)

0 commit comments

Comments
 (0)