
Commit 3c91e5b

reduce diff
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 86d504c commit 3c91e5b

File tree

9 files changed, +51 -104 lines


src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 1 addition & 2 deletions
@@ -221,8 +221,7 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
 
     model.apply(
         lambda module: initialize_module_for_quantization(
-            module,
-            force_zero_point=force_zero_point_init,
+            module, force_zero_point=force_zero_point_init
         )
     )
 

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 0 additions & 5 deletions
@@ -185,10 +185,6 @@ def _initialize_scale_zero_point(
 
         expected_shape = (observed_shape[-1], 1)
 
-<<<<<<< HEAD
-    # 3. Identify quantization scale and zp dtype
-    scale_dtype = module.weight.dtype
-=======
     elif strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
         assert quantization_args.group_size is not None
         if len(observed_shape) < 1:
@@ -218,7 +214,6 @@ def _initialize_scale_zero_point(
 
     # 2. Identify quantization scale and zp dtype
     scale_dtype = observed_dtype
->>>>>>> fde779c (refactor)
 
     if is_fp4(quantization_args=quantization_args):
         scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype

src/compressed_tensors/transform/apply.py

Lines changed: 0 additions & 35 deletions
@@ -12,10 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict
-
 import torch
-from accelerate.utils import has_offloaded_params
 from compressed_tensors import TRANSFORM_CONFIG_NAME
 from compressed_tensors.transform import TransformConfig, TransformFactory
 
@@ -37,35 +34,3 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
 
     # attach config to model for compression/serialization
     setattr(model, TRANSFORM_CONFIG_NAME, config)
-
-    # ensure that tied weight transforms can be serialized without aliases
-    # In the future, this could be done by transformers or model compressor
-    # which would make this more robust to changing dispatches after transforms
-    _tie_offloaded_tensors(model)
-
-
-def _tie_offloaded_tensors(model: torch.nn.Module):
-    """
-    When accelerate replaces tensors with meta tensors during offloading, the meta
-    tensors may not be identical, even if the offloaded values are identical.
-
-    However, transformers can only serialize correctly if meta tensors are identical
-    (see transformers#39263).
-
-    This function collects all meta tensors which have shared offloaded values and sets
-    those tensors to be identical so that they can be removed during serialization
-
-    :param model: model potentially containing offloaded meta tensors to fix
-    """
-
-    # ensure that if a location shares an offloaded tensor pointers, that the
-    # meta tensor is also identical (assigned to the first instance of parameter)
-    ptr_to_meta: Dict[int, torch.nn.Parameter] = dict()
-    for module in model.modules():
-        if has_offloaded_params(module):
-            for key, _ in module.named_parameters(recurse=False):
-                offloaded_ptr = module._hf_hook.weights_map[key].data_ptr()
-
-                if offloaded_ptr not in ptr_to_meta:
-                    ptr_to_meta[offloaded_ptr] = getattr(module, key)
-                setattr(module, key, ptr_to_meta[offloaded_ptr])
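
The removed _tie_offloaded_tensors helper relied on the fact that tied parameters share a single storage, so grouping parameters by data_ptr() recovers the tied groups. Below is a small, hedged standalone sketch of that idea on ordinary (non-offloaded) modules; the two Linear modules are illustrative stand-ins and not part of compressed-tensors.

# Illustrative sketch only: detect shared storage via data_ptr(), the same
# mechanism the removed _tie_offloaded_tensors helper applied to offloaded weights.
from collections import defaultdict

import torch

a = torch.nn.Linear(4, 4, bias=False)
b = torch.nn.Linear(4, 4, bias=False)
b.weight = a.weight  # tie the weights: both modules now reference one tensor

ptr_to_params = defaultdict(list)
for module_name, module in (("a", a), ("b", b)):
    for name, param in module.named_parameters(recurse=False):
        ptr_to_params[param.data_ptr()].append(f"{module_name}.{name}")

shared = [keys for keys in ptr_to_params.values() if len(keys) > 1]
print(shared)  # [['a.weight', 'b.weight']] -- a single shared group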

src/compressed_tensors/transform/factory/base.py

Lines changed: 36 additions & 2 deletions
@@ -13,7 +13,8 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from collections import defaultdict
+from typing import List, Optional, Set, Tuple
 
 import torch
 import torch.nn.utils.parametrize as P
@@ -56,6 +57,7 @@ def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = Non
         self.name = name
         self.scheme = scheme
         self.generator = torch.Generator()
+        self.transforms = list()
         if seed is not None:
             self.generator.manual_seed(seed)
 
@@ -99,6 +101,8 @@ def apply_to_model(self, model: Module, use_tqdm=True):
         for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)):
             self._apply_to_module(module, arg)
 
+        self._update_tied_weights()
+
     def _apply_to_module(self, module: Module, args: TransformArgs):
         """
         Create transforms and apply them to the module
@@ -116,6 +120,7 @@ def _apply_to_module(self, module: Module, args: TransformArgs):
         # create transform as submodule
         transform_name = f"{self.name}_{args.location}"
         transform = self.create_transform(module, args)
+        self.transforms.append(transform)
         register_offload_module(module, transform_name, transform)
 
         # register input transformation hook
@@ -160,6 +165,31 @@ def output_hook(_, _input, output):
         else:
             raise NotImplementedError()
 
+    def _update_tied_weights(self):
+        """
+        Populate the `_dynamic_tied_weights_keys` attribute of transforms,
+        which is used by transformers to detect and remove shared pointers
+        during saving
+        """
+        # map from data_ptrs to keys
+        ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
+        for transform in self.transforms:
+            for name, param in transform.named_parameters(recurse=False):
+                # NOTE: previously asserted that parent._hf_hook.place_submodules=False
+                if has_offloaded_params(transform):
+                    param = transform._hf_hook.weights_map[name]
+                ptr_to_keys[param.data_ptr()].append((transform, name))
+
+        # populate `_dynamic_tied_weights_keys` if there is more than one key
+        # and ensure that they share tensors
+        for shared_keys in ptr_to_keys.values():
+            if len(shared_keys) > 1:
+                tensor = getattr(shared_keys[0][0], shared_keys[0][1])
+
+                for transform, name in shared_keys:
+                    transform._dynamic_tied_weights_keys.add(name)
+                    setattr(transform, name, tensor)
+
 
 class TransformBase(InternalModule, ABC):
     """
@@ -168,7 +198,11 @@ class TransformBase(InternalModule, ABC):
 
     args: TransformArgs
    weight: Parameter
-    _dynamic_tied_weights_keys: List[str] = ["weight"]
+    _dynamic_tied_weights_keys: Set[str]
+
+    def __init__(self):
+        super().__init__()
+        self._dynamic_tied_weights_keys = set()
 
     @abstractmethod
     def forward(self, value: Tensor) -> Tensor:
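
The new TransformFactory._update_tied_weights records every transform parameter's data pointer (using the offloaded tensor when accelerate has dispatched the module) and, for each group that shares a pointer, adds the parameter name to _dynamic_tied_weights_keys and reassigns one shared tensor. A minimal, hedged sketch of that grouping step follows; the Transform class below is a stand-in for illustration, not the library's TransformBase.

# Minimal sketch of the pointer-grouping step performed by _update_tied_weights.
from collections import defaultdict
from typing import List, Tuple

import torch


class Transform(torch.nn.Module):  # hypothetical stand-in for TransformBase
    def __init__(self, weight: torch.nn.Parameter):
        super().__init__()
        self.weight = weight
        self._dynamic_tied_weights_keys = set()


shared_weight = torch.nn.Parameter(torch.eye(4))
transforms = [Transform(shared_weight), Transform(shared_weight)]

# group (transform, parameter name) pairs by the parameter's data pointer
ptr_to_keys: dict[int, List[Tuple[torch.nn.Module, str]]] = defaultdict(list)
for transform in transforms:
    for name, param in transform.named_parameters(recurse=False):
        ptr_to_keys[param.data_ptr()].append((transform, name))

# mark shared keys and make every member reference the exact same tensor
for shared_keys in ptr_to_keys.values():
    if len(shared_keys) > 1:
        tensor = getattr(shared_keys[0][0], shared_keys[0][1])
        for transform, name in shared_keys:
            transform._dynamic_tied_weights_keys.add(name)
            setattr(transform, name, tensor)

assert all(t._dynamic_tied_weights_keys == {"weight"} for t in transforms)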

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 1 addition & 3 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional
+from typing import Optional
 
 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
@@ -84,8 +84,6 @@ def _create_permutation(self, weight: Parameter) -> Parameter:
 
 
 class HadamardTransform(TransformBase):
-    _dynamic_tied_weights_keys: List[str] = ["weight", "perm"]
-
     def __init__(
         self,
         weight: Parameter,

src/compressed_tensors/transform/utils/hadamard.py

Lines changed: 2 additions & 5 deletions
@@ -115,16 +115,13 @@ def _fetch_hadamard_divisor(
     than forcing callers to manage the file open context
 
     :param n: size of known hadamard matrix
-    :param dtype: data type to move fetched hadamard to
-    :param device: device to move fetched hadamard to
     :return: a known hadamard matrix of size `n` if one exists, else None
     """
-    open_device = torch.device("cpu") if device.type == "meta" else device
-    with safe_open(file_path, framework="pt", device=str(open_device)) as file:
+    with safe_open(file_path, framework="pt", device=str(device)) as file:
         divisors = sorted((int(key) for key in file.keys()), reverse=True)
        for divisor in divisors:
             if n % divisor == 0 and is_pow2(n // divisor):
-                return file.get_tensor(str(divisor)).to(dtype=dtype, device=device)
+                return file.get_tensor(str(divisor)).to(dtype=dtype)
 
     return None
 
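
With this change, _fetch_hadamard_divisor opens the safetensors file directly on the requested device and materializes only the single matching entry. A hedged sketch of that lazy lookup pattern follows; the file name, its keys, and the identity matrices are made-up stand-ins for the real Hadamard table.

# Hedged sketch of the lazy safetensors lookup used by _fetch_hadamard_divisor.
# "hadamards.safetensors" and its contents are hypothetical placeholders.
import torch
from safetensors import safe_open
from safetensors.torch import save_file

# build a tiny lookup file whose keys are matrix sizes
save_file({"4": torch.eye(4), "12": torch.eye(12)}, "hadamards.safetensors")

n = 48  # want the largest known divisor of n whose cofactor is a power of two
base = None
with safe_open("hadamards.safetensors", framework="pt", device="cpu") as file:
    divisors = sorted((int(key) for key in file.keys()), reverse=True)
    for divisor in divisors:
        cofactor = n // divisor
        if n % divisor == 0 and (cofactor & (cofactor - 1)) == 0:
            base = file.get_tensor(str(divisor))  # only this tensor is loaded
            break

print(None if base is None else base.shape)  # torch.Size([12, 12]); 48 = 12 * 4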

src/compressed_tensors/utils/helpers.py

Lines changed: 6 additions & 1 deletion
@@ -13,12 +13,17 @@
 # limitations under the License.
 
 import contextlib
+import warnings
 from functools import wraps
 from types import MappingProxyType
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, TypeVar
 
 import numpy
 import torch
+<<<<<<< HEAD
+=======
+from frozendict import frozendict
+>>>>>>> 6672617 (reduce diff)
 from transformers import AutoConfig
 
 
@@ -194,7 +199,7 @@ def decorator(func: T) -> T:
 
         @wraps(func)
         def wrapped(*args, **kwargs):
-            logger.bind(log_once=True).warning(message)
+            warnings.warn(message, DeprecationWarning, stacklevel=2)
             return func(*args, **kwargs)
 
         return wrapped
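
The deprecated decorator now emits a standard DeprecationWarning through warnings.warn instead of logging via loguru. A simplified, hedged sketch of that pattern follows; this stand-in decorator mirrors the idea but is not the library's exact signature.

# Simplified sketch of a deprecation decorator built on warnings.warn,
# mirroring the switch away from logger.bind(log_once=True).warning(message).
import warnings
from functools import wraps


def deprecated(message: str):
    def decorator(func):
        @wraps(func)
        def wrapped(*args, **kwargs):
            # stacklevel=2 points the warning at the caller, not at this wrapper
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        return wrapped

    return decorator


@deprecated("use new_api() instead")
def old_api() -> int:
    return 42


old_api()  # emits: DeprecationWarning: use new_api() instead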

tests/test_transform/conftest.py

Lines changed: 0 additions & 2 deletions
@@ -19,8 +19,6 @@
 
 
 class TransformableModel(PreTrainedModel):
-    config_class = PretrainedConfig
-
     def __init__(self, *sizes):
         super().__init__(config=PretrainedConfig())
         self.fcs = torch.nn.ModuleList(

tests/test_transform/factory/test_serialization.py

Lines changed: 5 additions & 49 deletions
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
-
 import pytest
 import torch
 from compressed_tensors.transform import (
@@ -22,9 +20,7 @@
     apply_transform_config,
 )
 from compressed_tensors.utils import offloaded_dispatch
-from safetensors import safe_open
 from tests.testing_utils import requires_accelerate, requires_gpu
-from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
@@ -42,57 +38,17 @@ def test_serialization(type, randomize, model_apply, tmp_path, offload=False):
     apply_transform_config(model, config)
 
     # save model
-    model_path = os.path.join(tmp_path, "test_model_path")
-    model.save_pretrained(model_path)
-
-    # check that saved values match model values
-    # note that shared weights are only serialized once
-    safetensors_path = os.path.join(model_path, "model.safetensors")
-    with safe_open(safetensors_path, framework="pt", device="cpu") as file:
-        saved_keys = set(file.keys())
-        assert {
-            "fcs.0.weight",
-            "fcs.1.weight",
-            "fcs.2.weight",
-            "fcs.3.weight",
-            "fcs.4.weight",
-        } <= saved_keys
-        for key in saved_keys:
-            param = model.get_parameter(key)
-            saved_param = file.get_tensor(key)
+    model.save_pretrained(tmp_path)
 
-            if param.device.type != "meta":  # skip testing values in offload case
-                assert torch.equal(param, saved_param)
+    # TODO: reload model
 
 
+@pytest.mark.skip(reason="Requires changes in upstream transformers")
+# https://github.com/huggingface/transformers/pull/39280
+# https://github.com/huggingface/transformers/pull/39263
 @requires_gpu
 @requires_accelerate()
 @pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
 @pytest.mark.parametrize("randomize", (True, False))
 def test_serialization_offload(type, randomize, model_apply, tmp_path):
     test_serialization(type, randomize, model_apply, tmp_path, offload=True)
-
-
-@pytest.mark.skip("Requires transformers#40673")
-@requires_gpu
-@pytest.mark.parametrize(
-    "model_stub,exp_perplexity",
-    [
-        ("nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16", 10.0),
-        ("nm-testing/Llama-3.2-1B-Instruct-quip-w4a16", 10.0),
-    ],
-)
-def test_load_perplexity(model_stub, exp_perplexity):
-    model = AutoModelForCausalLM.from_pretrained(model_stub, device_map="cuda")
-    tokenizer = AutoTokenizer.from_pretrained(model_stub)
-
-    prompt = "The capital of France is Paris, the capital of Germany is Berlin"
-    inputs = tokenizer(prompt, return_tensors="pt")
-    inputs = {key: value.to(model.device) for key, value in inputs.items()}
-    labels = inputs["input_ids"]
-
-    with torch.no_grad():
-        outputs = model(**inputs, labels=labels)
-
-    perplexity = torch.exp(outputs.loss)
-    assert perplexity <= exp_perplexity
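
The deleted assertions relied on the constraint that safetensors refuses to write aliased tensors, which is why shared transform weights must be deduplicated (now via _dynamic_tied_weights_keys) before save_pretrained can serialize them. A small, hedged standalone sketch of that constraint, independent of the test fixtures:

# Hedged sketch: safetensors rejects tensors that share memory, so aliases must be
# collapsed to a single entry before serialization. File names are illustrative,
# and the exact exception type reflects current safetensors behavior.
import torch
from safetensors.torch import save_file

weight = torch.eye(4)
try:
    # two keys pointing at the same storage are refused
    save_file({"a.weight": weight, "b.weight": weight}, "aliased.safetensors")
except RuntimeError as err:
    print("shared tensors rejected:", err)

# keeping only one copy of the shared tensor serializes fine
save_file({"a.weight": weight}, "deduped.safetensors")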
