We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 1a01f0c commit 073d8d7Copy full SHA for 073d8d7
src/compressed_tensors/transform/factory/hadamard.py
@@ -93,7 +93,7 @@ def __init__(
93
self.perm = perm
94
self.args = args
95
self.module_type = module_type
96
- self._scale = math.sqrt(weight.size(0))
+ self._scale = torch.tensor(weight.size(0), dtype=torch.float64).sqrt()
97
98
def forward(self, value: Tensor) -> Tensor:
99
weight = self.weight
@@ -104,6 +104,9 @@ def forward(self, value: Tensor) -> Tensor:
104
if self.args.inverse:
105
weight = weight.T
106
107
- return apply_transform_weight(
+ tmp = apply_transform_weight(
108
weight, value, self.args.location, self.module_type
109
- ) / self._scale
+ )
110
+
111
112
+ return (tmp.to(torch.float64) / self._scale).to(tmp.dtype)
0 commit comments