Skip to content

Commit cc1cc68

Browse files
committed
undo transform precision
Signed-off-by: Kyle Sayers <[email protected]>
1 parent e9a200f commit cc1cc68

File tree

4 files changed

+29
-21
lines changed

4 files changed

+29
-21
lines changed

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import math
1516
from typing import Optional
1617

17-
import math
1818
import torch
1919
from compressed_tensors.transform import TransformArgs, TransformScheme
2020
from compressed_tensors.transform.factory.base import TransformBase, TransformFactory
@@ -104,6 +104,7 @@ def forward(self, value: Tensor) -> Tensor:
104104
if self.args.inverse:
105105
weight = weight.T
106106

107-
return apply_transform_weight(
108-
weight, value, self.args.location, self.module_type
109-
) / self._scale
107+
return (
108+
apply_transform_weight(weight, value, self.args.location, self.module_type)
109+
/ self._scale
110+
)

src/compressed_tensors/transform/utils/matrix.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@
2121
__all__ = ["get_transform_size", "apply_transform_weight"]
2222

2323

24-
TRANSFORM_PRECISION = torch.float64
25-
26-
2724
def get_transform_size(
2825
module: torch.nn.Module,
2926
location: TransformLocation,
@@ -88,14 +85,8 @@ def apply_transform_weight(
8885
num_heads = value.shape[axis] // head_dim
8986
value = value.unflatten(axis, (num_heads, head_dim))
9087

91-
# cast to transform precision
92-
value_dtype = value.dtype
93-
9488
# apply transform
95-
value = fn(weight.to(TRANSFORM_PRECISION), value.to(TRANSFORM_PRECISION))
96-
97-
# [undo] cast to transform precision
98-
value = value.to(value_dtype)
89+
value = fn(weight, value)
9990

10091
# [undo] reshape for head_dim
10192
value = value.flatten(axis - 1, axis)

src/compressed_tensors/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616
from .helpers import *
1717
from .internal import *
18+
from .match import *
1819
from .offload import *
1920
from .permutations_24 import *
2021
from .permute import *
2122
from .safetensors_load import *
2223
from .semi_structured_conversions import *
23-
from .match import *

src/compressed_tensors/utils/match.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,24 @@
1-
from typing import Iterable, Tuple
2-
from collections.abc import Generator
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
314

15+
import logging
416
import re
17+
from collections.abc import Generator
18+
from typing import Iterable, Tuple
19+
520
import torch
6-
import logging
21+
722

823
_LOGGER: logging.Logger = logging.getLogger(__name__)
924

@@ -15,14 +30,14 @@ def match_named_modules(
1530
model: torch.nn.Module,
1631
targets: Iterable[str] = tuple(),
1732
ignore: Iterable[str] = tuple(),
18-
warn_on_fail: bool = True
33+
warn_on_fail: bool = True,
1934
) -> Generator[Tuple[str, torch.nn.Module], None, None]:
2035
unmatched_targets = set(targets)
2136
for name, module in model.named_modules():
2237
for target in targets:
2338
if is_match(name, module, target):
2439
unmatched_targets.remove(target)
25-
40+
2641
if not any(is_match(name, module, ign) for ign in ignore):
2742
yield name, module
2843

@@ -32,6 +47,7 @@ def match_named_modules(
3247
f"Could not match `{target}` in instance of {model.__class__.__name__}"
3348
)
3449

50+
3551
def is_match(name: str, module: torch.nn.Module, target: str) -> bool:
3652
return _match_name(name, target) or _match_class(module, target)
3753

@@ -46,7 +62,7 @@ def _match_name(name: str, target: str) -> bool:
4662
def _match_class(module: torch.nn.Module, target: str) -> bool:
4763
"""
4864
Will never match against a regex pattern since `:` is not allowed in class names
49-
65+
5066
"""
5167
return any(
5268
issubclass(cls, torch.nn.Module) and cls.__name__ == target

0 commit comments

Comments
 (0)