
Commit d00675c

merge issues
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 684db8b commit d00675c

File tree

2 files changed: +1 -6 lines changed


src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 0 additions & 2 deletions
@@ -86,15 +86,13 @@ def __init__(
         perm: Optional[Parameter],
         scheme: TransformScheme,
         args: TransformArgs,
-        precision: torch.dtype,
         module_type: type[torch.nn.Module],
     ):
         super().__init__()
         self.weight = weight
         self.perm = perm
         self.scheme = scheme
         self.args = args
-        self.precision = precision
         self.module_type = module_type
         self._scale = torch.tensor(weight.size(0), dtype=self.scheme.precision).sqrt()
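For orientation, a minimal standalone sketch of what the retained _scale line computes now that precision lives only on the scheme; the size and dtype values below are illustrative, not taken from the repository.

import torch

# Illustration only: _scale is the square root of the transform size,
# materialized directly in the scheme's precision rather than in a
# separately stored self.precision.
size = 64                        # stands in for weight.size(0)
precision = torch.float64        # stands in for self.scheme.precision
scale = torch.tensor(size, dtype=precision).sqrt()
print(scale)                     # tensor(8., dtype=torch.float64)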

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 1 addition & 4 deletions
@@ -54,13 +54,12 @@ def create_transform(self, module: Module, args: TransformArgs):
         size = get_transform_size(module, args.location, self.scheme.head_dim)
         dtype = self.scheme.precision
         device = get_offloaded_device(module)
-        precision = self.scheme.precision
 
         weight = self.weights[size, dtype, device]
         if args.inverse:
             weight = self.inverses[weight]
 
-        return RandomMatrixTransform(weight, args, precision, type(module))
+        return RandomMatrixTransform(weight, self.scheme, args, type(module))
 
     def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
         # TODO: verify that weight is invertible (has non-zero determinant)
@@ -81,14 +80,12 @@ def __init__(
         weight: Tensor,
         scheme: TransformScheme,
         args: TransformArgs,
-        precision: torch.dtype,
         module_type: type[torch.nn.Module],
     ):
         super().__init__()
         self.weight = weight  # is an inverse if args.inverse
         self.scheme = scheme
         self.args = args
-        self.precision = precision
         self.module_type = module_type
 
     def forward(self, value: Tensor) -> Parameter:
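For context, a minimal sketch of the constructor pattern this commit converges on: callers pass the whole TransformScheme and the transform reads precision from it, instead of threading a separate precision argument. The classes below are simplified stand-ins, not the library's actual TransformScheme or RandomMatrixTransform; the args handling is omitted and the forward math is an assumption.

from dataclasses import dataclass

import torch


@dataclass
class TransformScheme:                     # simplified stand-in for the real scheme
    precision: torch.dtype = torch.float64


class RandomMatrixTransform(torch.nn.Module):
    # Simplified: the real class also receives TransformArgs.
    def __init__(self, weight: torch.Tensor, scheme: TransformScheme,
                 module_type: type[torch.nn.Module]):
        super().__init__()
        self.weight = weight
        self.scheme = scheme               # precision is read from the scheme
        self.module_type = module_type

    def forward(self, value: torch.Tensor) -> torch.Tensor:
        # Assumed behavior: multiply in the scheme's precision, cast back.
        w = self.weight.to(self.scheme.precision)
        return (value.to(self.scheme.precision) @ w).to(value.dtype)


# Callers no longer build or pass a separate precision value:
scheme = TransformScheme(precision=torch.float32)
transform = RandomMatrixTransform(torch.randn(8, 8), scheme, torch.nn.Linear)
out = transform(torch.randn(2, 8))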

0 commit comments
