Skip to content

Commit aa7d21b

Browse files
committed
key inverses by weight
Signed-off-by: Kyle Sayers <[email protected]>
1 parent d77bcef commit aa7d21b

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,19 @@ def create_transform(self, module: Module, args: TransformArgs):
5353
dtype = module.weight.dtype
5454
device = get_offloaded_device(module)
5555

56-
if not args.inverse:
57-
weight = self.weights[size, dtype, device]
58-
else:
59-
weight = self.inverses[size, dtype, device]
56+
weight = self.weights[size, dtype, device]
57+
if args.inverse:
58+
weight = self.inverses[weight]
59+
6060
return RandomMatrixTransform(weight, args)
6161

6262
def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
6363
data = torch.rand((size, size), dtype=dtype, device=device)
6464
return Parameter(data, requires_grad=self.scheme.requires_grad)
6565

66-
def _create_inverse(self, size: int, dtype: dtype, device: device) -> Parameter:
67-
weight = self.weights[size, dtype, device]
68-
return Parameter(high_precision_invert(weight.data), requires_grad=False)
66+
def _create_inverse(self, weight: Parameter) -> Parameter:
67+
data = high_precision_invert(weight.data)
68+
return Parameter(data, requires_grad=False)
6969

7070

7171
class RandomMatrixTransform(TransformBase):

0 commit comments

Comments
 (0)