Skip to content

Commit 58257eb

Browse files
merge main
2 parents 9e2b530 + 0731aa5 commit 58257eb

File tree

6 files changed

+20
-12
lines changed

6 files changed

+20
-12
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def from_compression_config(
147147

148148
sparsity_config = cls.parse_sparsity_config(compression_config)
149149
quantization_config = cls.parse_quantization_config(compression_config)
150-
# NOTE: transfrom config is not support by ctconfig yet
150+
# TODO: transform config is not supported by CompressedTensorsConfig yet
151151

152152
if sparsity_config is None and quantization_config is None:
153153
return None
@@ -207,7 +207,7 @@ def from_pretrained_model(
207207

208208
@staticmethod
209209
def parse_sparsity_config(
210-
compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
210+
compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"],
211211
) -> Union[Dict[str, Any], None]:
212212
"""
213213
Parse sparsity config from quantization/compression config. Sparsity
@@ -227,7 +227,7 @@ def parse_sparsity_config(
227227

228228
@staticmethod
229229
def parse_quantization_config(
230-
compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
230+
compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"],
231231
) -> Union[Dict[str, Any], None]:
232232
"""
233233
Parse quantization config from quantization/compression config. The

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def __init__(
9595
self.args = args
9696
self.module_type = module_type
9797
self._scale = torch.tensor(weight.size(0), dtype=self.scheme.precision).sqrt()
98+
self._precision = scheme.precision if args.is_online() else torch.float64
9899

99100
def forward(self, value: Tensor) -> Tensor:
100101
weight = self.weight
@@ -107,8 +108,8 @@ def forward(self, value: Tensor) -> Tensor:
107108

108109
return (
109110
apply_transform_weight(
110-
weight.to(self.scheme.precision),
111-
value.to(self.scheme.precision),
111+
weight.to(self._precision),
112+
value.to(self._precision),
112113
self.args.location,
113114
self.module_type,
114115
)

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,20 +87,21 @@ def __init__(
8787
self.scheme = scheme
8888
self.args = args
8989
self.module_type = module_type
90+
self._precision = scheme.precision if args.is_online() else torch.float64
9091

9192
def forward(self, value: Tensor) -> Parameter:
9293
return apply_transform_weight(
93-
self.weight.to(self.scheme.precision),
94-
value.to(self.scheme.precision),
94+
self.weight.to(self._precision),
95+
value.to(self._precision),
9596
self.args.location,
9697
self.module_type,
9798
).to(value.dtype)
9899

99100
def right_inverse(self, value: Tensor) -> Tensor:
100101
inverse = high_precision_invert(self.weight)
101102
return apply_transform_weight(
102-
inverse.to(self.scheme.precision),
103-
value.to(self.scheme.precision),
103+
inverse.to(self._precision),
104+
value.to(self._precision),
104105
self.args.location,
105106
self.module_type,
106107
).to(value.dtype)

src/compressed_tensors/transform/transform_args.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,9 @@ def wrap_singleton(cls, value):
6868
if isinstance(value, str):
6969
return [value]
7070
return value
71+
72+
def is_online(self) -> bool:
73+
return self.location not in (
74+
TransformLocation.WEIGHT_INPUT,
75+
TransformLocation.WEIGHT_OUTPUT,
76+
)

src/compressed_tensors/transform/transform_scheme.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ class TransformScheme(BaseModel):
3636
:param randomize: True if uniquely randomized transform weights should be used,
3737
otherwise use identical transform weights where applicable
3838
:param requires_grad: True if weights include gradients for training
39-
:param precision: Precision at which this transform should be applied. This applies
40-
to both weight fusing and online rotations
39+
:param precision: Precision at which this transform should be applied during online
40+
rotations. Fused (offline) rotations are always performed in float64
4141
"""
4242

4343
type: str

src/compressed_tensors/utils/type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def validate_from_str(name: str) -> torch.dtype:
3636
try:
3737
value = getattr(torch, name)
3838
assert isinstance(value, torch.dtype)
39-
except AttributeError:
39+
except Exception:
4040
raise ValueError(f"No such torch dtype `torch.{name}`")
4141

4242
return value

0 commit comments

Comments
 (0)