Commit 85419e2

Merge remote-tracking branch 'origin' into kylesayrs/transform_save

2 parents 4085613 + 5478b43

File tree: 31 files changed (+1328, −173 lines)

.github/actions/test/action.yml

Lines changed: 5 additions & 5 deletions
@@ -69,11 +69,11 @@ runs:
        echo "::endgroup::"

        if [[ "${ENABLE_COVERAGE}" == "true" ]]; then
-         echo "::group::consolidating coverage reports"
-         mkdir -p coverage-results
-         mv .coverage coverage-results/ || echo ".coverage file not found"
-         mv coverage-html coverage-results/ || echo "coverage-html folder not found"
-         mv coverage.json coverage-results/ || echo "coverage.json file not found"
+         echo "::group::check coverage reports"
+         if [ ! -d coverage-html ]; then
+           echo "ERROR: coverage-html folder not found"
+           exit 1
+         fi
          echo "::endgroup::"
        fi

.github/workflows/build-test.yml

Lines changed: 2 additions & 1 deletion
@@ -25,7 +25,7 @@ on:

       # test related parameters
       test_configs:
-        description: "python, label, timeout"
+        description: "python, label, timeout, etc"
         type: string
         required: true

@@ -53,6 +53,7 @@ jobs:
       python: ${{ matrix.test_config.python }}
       timeout: ${{ matrix.test_config.timeout }}
       whl: ${{ needs.BUILD.outputs.whl }}
+      code_coverage: ${{ matrix.test_config.code_coverage || false }}
       secrets: inherit

   UPLOAD:

.github/workflows/test.yml

Lines changed: 18 additions & 4 deletions
@@ -70,6 +70,10 @@ jobs:
     permissions:
       contents: 'read'
       id-token: 'write'
+      pages: 'write'
+    environment:
+      name: github-pages
+      url: ${{ steps.coverage.outputs.page_url }}

     steps:

@@ -134,6 +138,11 @@ jobs:
           suitename: test-${{ inputs.python }}-${{ inputs.test_label }}
           code_coverage: ${{ inputs.code_coverage }}

+      - name: extra info for summary
+        if: ${{ inputs.code_coverage }}
+        run: |
+          echo "EXTRA='Code Coverage: https://neuralmagic.github.io/compressed-tensors/'" >> $GITHUB_ENV
+
       - name: summary
         uses: neuralmagic/nm-actions/actions/[email protected]
         if: success() || failure()

@@ -143,6 +152,7 @@ jobs:
           python: ${{ inputs.python }}
           whl: ${{ inputs.whl }}
           test_status: ${{ steps.test.outputs.status }}
+          extra: ${{ env.EXTRA }}

       - name: copy results to GCP
         run: |

@@ -157,9 +167,13 @@ jobs:
           retention-days: 5

       - name: upload coverage report
-        uses: actions/upload-artifact@v4
-        if: (success() || failure()) && inputs.code_coverage
+        uses: actions/upload-pages-artifact@v3
+        if: ${{ inputs.code_coverage }}
         with:
-          name: coverage-results
-          path: coverage-results/*
+          path: coverage-html
           retention-days: 5
+
+      - name: deploy to Github Pages
+        id: coverage
+        uses: actions/deploy-pages@v4
+        if: ${{ inputs.code_coverage }}

.github/workflows/trigger-all.yml

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ jobs:
       wf_category: ${{ inputs.wf_category || 'NIGHTLY' }}
       gitref: ${{ inputs.gitref || 'main' }}
       push_to_pypi: ${{ (github.event.schedule == '30 0 * * *') || inputs.push_to_pypi || false }}
-      test_configs: '[{"python":"3.11.4","label":"ubuntu-24.04","timeout":"40"},
+      test_configs: '[{"python":"3.11.4","label":"ubuntu-24.04","timeout":"40","code_coverage":true},
         {"python":"3.10.12","label":"ubuntu-22.04","timeout":"40"},
         {"python":"3.9.17","label":"k8s-h100-solo","timeout":"40"},
         {"python":"3.12.6","label":"k8s-a100-duo","timeout":"40"}]'

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 18 additions & 8 deletions
@@ -392,15 +392,18 @@ def compress_model(self, model: Module):
         for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):

             if prefix in module_to_scheme or prefix in sparse_compression_targets:
-                module_device = get_execution_device(module).type
-                is_meta = (module_device == "meta")
+                module_device = get_execution_device(module)
+                is_meta = module_device.type == "meta"

                 exec_device = "meta" if is_meta else "cpu"
                 onloading_device = "meta" if is_meta else module_device

                 # in the future, support compression on same device
                 with align_module_device(module, execution_device=exec_device):
-                    state_dict = module.state_dict(prefix=f"{prefix}.")
+                    state_dict = {
+                        f"{prefix}.{name}": param
+                        for name, param in module.named_parameters(recurse=False)
+                    }

                 # quantization first
                 if prefix in module_to_scheme:
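For context on the state_dict change: torch's state_dict() recurses into submodules (and returns detached tensors), while named_parameters(recurse=False) yields only the module's own Parameter objects. A minimal standalone sketch with toy modules (not the compressed-tensors API):

    import torch
    from torch import nn

    parent = nn.Linear(4, 4)
    parent.child = nn.Linear(4, 4)  # attaching a Module registers it as a submodule

    # state_dict() recurses: the child's tensors are swept in too
    print(sorted(parent.state_dict(prefix="layer.").keys()))
    # ['layer.bias', 'layer.child.bias', 'layer.child.weight', 'layer.weight']

    # named_parameters(recurse=False) stays local to this module
    print(sorted(f"layer.{n}" for n, _ in parent.named_parameters(recurse=False)))
    # ['layer.bias', 'layer.weight']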
@@ -421,7 +424,7 @@ def compress_model(self, model: Module):

                 # remove any existing parameters
                 offload_device = get_offloaded_device(module)
-                for name, _ in list(module.named_parameters()):
+                for name, _ in list(module.named_parameters(recurse=False)):
                     delete_offload_parameter(module, name)

                 # replace with compressed parameters

@@ -458,7 +461,10 @@ def decompress_model(self, model: Module):
             if prefix in module_to_scheme or prefix in sparse_compression_targets:
                 # in the future, support decompression on same device
                 with align_module_device(module, execution_device="cpu"):
-                    state_dict = module.state_dict(prefix=f"{prefix}.")
+                    state_dict = {
+                        f"{prefix}.{name}": param
+                        for name, param in module.named_parameters(recurse=False)
+                    }

                 # sparsity first
                 if prefix in sparse_compression_targets:

@@ -483,7 +489,7 @@ def decompress_model(self, model: Module):
                 # remove any existing parameters
                 exec_device = get_execution_device(module)
                 offload_device = get_offloaded_device(module)
-                for name, _ in list(module.named_parameters()):
+                for name, _ in list(module.named_parameters(recurse=False)):
                     delete_offload_parameter(module, name)

                 # replace with decompressed parameters

@@ -747,12 +753,16 @@ def _replace_weights(self, dense_weight_generator, model: Module):

 def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
     """
-    Returns a dictionary which maps quantized module names to their quantization schemes
+    Returns a dictionary which maps quantized module names to their quantization
+    schemes. Only includes modules with weight quantization
     """
     return {
         fix_fsdp_module_name(name): module.quantization_scheme
         for name, module in model.named_modules()
-        if is_module_quantized(module)
+        if (
+            hasattr(module, "quantization_scheme")
+            and module.quantization_scheme.weights is not None
+        )
     }
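The new predicate drops modules whose scheme has no weight quantization. A toy illustration of that filter, with SimpleNamespace standing in for the real QuantizationScheme:

    from types import SimpleNamespace
    from torch import nn

    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8), nn.ReLU())
    # hypothetical schemes: weights=None means activation-only quantization
    model[0].quantization_scheme = SimpleNamespace(weights="W4A16", input_activations=None)
    model[1].quantization_scheme = SimpleNamespace(weights=None, input_activations="A8")

    mapping = {
        name: module.quantization_scheme
        for name, module in model.named_modules()
        if (
            hasattr(module, "quantization_scheme")
            and module.quantization_scheme.weights is not None
        )
    }
    print(list(mapping))  # ['0'] -- the activation-only module '1' is skipped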

src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py

Lines changed: 6 additions & 2 deletions
@@ -178,9 +178,13 @@ def sparse24_bitmask_compress(

     if tensor.is_meta:
         num_rows, num_cols = tensor.shape
-        compressed_values = torch.empty((num_rows, num_cols // 2), dtype=tensor.dtype, device="meta")
+        compressed_values = torch.empty(
+            (num_rows, num_cols // 2), dtype=tensor.dtype, device="meta"
+        )
         packed_cols = (num_cols + 7) // 8
-        bitmasks_packed = torch.empty((num_rows, packed_cols), dtype=torch.uint8, device="meta")
+        bitmasks_packed = torch.empty(
+            (num_rows, packed_cols), dtype=torch.uint8, device="meta"
+        )
         return compressed_values, bitmasks_packed

     bytemasks = get_24_bytemasks(tensor=tensor)
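The meta-device fast path only needs the output shapes to be right: 2:4 sparsity keeps 2 of every 4 values, so half the columns survive, and the bitmask stores one bit per column packed 8 to a uint8 byte. A quick check of that arithmetic (hypothetical sizes):

    import torch

    weight = torch.empty((128, 4096), dtype=torch.bfloat16, device="meta")
    num_rows, num_cols = weight.shape

    # 2:4 sparsity -> half the columns remain
    compressed_values = torch.empty(
        (num_rows, num_cols // 2), dtype=weight.dtype, device="meta"
    )
    # ceil division packs 8 mask bits per byte
    packed_cols = (num_cols + 7) // 8
    bitmask = torch.empty((num_rows, packed_cols), dtype=torch.uint8, device="meta")

    print(compressed_values.shape, bitmask.shape)
    # torch.Size([128, 2048]) torch.Size([128, 512])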

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 69 additions & 2 deletions
@@ -111,11 +111,22 @@ def dequantize(
     elif scale.ndim == 2:
         if scale.shape[1] == 1:
             args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
-        else:
+        # Scale height matches input or is 1 -> group quantization across columns
+        #
+        # Example 1: scale.shape[0] == 1
+        #   x_q: (4, 8), scale: (1, 4) -> 2 columns per group
+        #
+        # Example 2: scale.shape[0] == x_q.shape[0]
+        #   x_q: (4, 8), scale: (4, 4) -> 2 elements per group (per row)
+        elif (scale.shape[0] == 1) or (scale.shape[0] == x_q.shape[0]):
             group_size = int(x_q.shape[1] / scale.shape[1])
             args = QuantizationArgs(
                 strategy=QuantizationStrategy.GROUP, group_size=group_size
             )
+        else:
+            args = QuantizationArgs(
+                strategy=QuantizationStrategy.BLOCK, block_structure=scale.shape
+            )
     else:
         raise ValueError(
             f"Could not infer a quantization strategy from scale with {scale.ndim} "
@@ -189,7 +200,63 @@ def _process_quantization(
     q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size

-    if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+    # blockwise FP8: quantize per 2D block, supports block_structure for static block quant
+    if args.strategy == QuantizationStrategy.BLOCK:
+        original_shape = x.shape
+        rows, cols = x.shape[-2], x.shape[-1]
+        block_height, block_width = args.block_structure
+
+        # Ensure exact division (tensor dimensions must be divisible by block size)
+        if rows % block_height != 0:
+            raise ValueError(
+                f"Tensor height {rows} is not divisible by block_height {block_height}. "
+                f"Block quantization requires exact division."
+            )
+        if cols % block_width != 0:
+            raise ValueError(
+                f"Tensor width {cols} is not divisible by block_width {block_width}. "
+                f"Block quantization requires exact division."
+            )
+
+        # reshape into blocks and transpose to make each block contiguous
+        num_rows_blocks = rows // block_height
+        num_cols_blocks = cols // block_width
+        x_blocks = x.reshape(
+            num_rows_blocks,
+            block_height,
+            num_cols_blocks,
+            block_width,
+        ).transpose(1, 2)
+
+        # expand scale/zero_point for blocks
+        sb = scale.unsqueeze(-1).unsqueeze(-1)
+        zb = zero_point.unsqueeze(-1).unsqueeze(-1) if zero_point is not None else None
+        if do_quantize:
+            # quantize blocks
+            x_blocks = _quantize(
+                x=x_blocks,
+                scale=sb,
+                zero_point=zb,
+                q_min=q_min,
+                q_max=q_max,
+                args=args,
+                dtype=dtype,
+                global_scale=global_scale,
+            )
+        if do_dequantize:
+            # dequantize blocks
+            x_blocks = _dequantize(
+                x_q=x_blocks,
+                scale=sb,
+                zero_point=zb,
+                global_scale=global_scale,
+            )
+        # restore original shape
+        output = x_blocks.transpose(1, 2).reshape(original_shape)
+    elif args.strategy in (
+        QuantizationStrategy.GROUP,
+        QuantizationStrategy.TENSOR_GROUP,
+    ):
         n_dims = x.shape
         if len(n_dims) > 2:
             x = x.squeeze(0)
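A minimal sketch of the reshape/transpose round trip this hunk relies on, with plain scale division standing in for _quantize/_dequantize (hypothetical values; the real helpers also handle zero points, clamping, and dtype casts):

    import torch

    x = torch.arange(32.0).reshape(4, 8)            # rows=4, cols=8
    block_height, block_width = 2, 4                # -> 2 x 2 grid of blocks
    scale = torch.tensor([[0.5, 1.0], [2.0, 4.0]])  # one scale per block

    # (4, 8) -> (2, 2, 2, 4): dims are (row_block, col_block, block_height, block_width)
    x_blocks = x.reshape(2, block_height, 2, block_width).transpose(1, 2)
    sb = scale.unsqueeze(-1).unsqueeze(-1)          # (2, 2, 1, 1), broadcasts per block

    x_q = torch.round(x_blocks / sb)                # stand-in for _quantize
    x_dq = x_q * sb                                 # stand-in for _dequantize

    out = x_dq.transpose(1, 2).reshape(4, 8)        # restore original layout
    assert out.shape == x.shape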

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 37 additions & 2 deletions
@@ -15,6 +15,7 @@

 import logging
 import math
+import warnings
 from enum import Enum
 from typing import List, Optional

@@ -172,14 +173,43 @@ def _initialize_scale_zero_point(

     if base_name == "weight" and weight_shape is not None:
         if quantization_args.strategy == QuantizationStrategy.CHANNEL:
-            # (output_channels, 1)
+            # (output_channels, 1) - only for weights
             expected_shape = (weight_shape[0], 1)
         elif quantization_args.strategy in (
             QuantizationStrategy.TENSOR_GROUP,
             QuantizationStrategy.GROUP,
         ):
+            # GROUP/TENSOR_GROUP for both weights and activations
             num_groups = math.ceil(weight_shape[1] / quantization_args.group_size)
             expected_shape = (weight_shape[0], max(num_groups, 1))
+        elif quantization_args.strategy == QuantizationStrategy.BLOCK:
+            # For block quantization, scale shape should match number of blocks - only for weights
+            if quantization_args.block_structure is None:
+                raise ValueError(
+                    "Block quantization requires block_structure to be specified"
+                )
+            block_height, block_width = quantization_args.block_structure
+            rows, cols = weight_shape[-2], weight_shape[-1]
+            num_rows_blocks = math.ceil(rows / block_height)
+            num_cols_blocks = math.ceil(cols / block_width)
+
+            # Warn if dimensions don't divide evenly
+            if rows % block_height != 0 or cols % block_width != 0:
+                warnings.warn(
+                    f"Block quantization: tensor shape {weight_shape} does not divide evenly "
+                    f"by block structure {quantization_args.block_structure}. "
+                    f"Some blocks will be incomplete which may affect quantization quality.",
+                    UserWarning,
+                )
+
+            expected_shape = (num_rows_blocks, num_cols_blocks)
+    elif quantization_args.strategy == QuantizationStrategy.BLOCK:
+        warnings.warn(
+            f"BLOCK quantization not supported for {base_name} activations. "
+            f"Falling back to tensor-level quantization.",
+            UserWarning,
+        )
+        expected_shape = 1

     # 3. Identify quantization scale and zp dtype
     scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
@@ -189,7 +219,12 @@ def _initialize_scale_zero_point(
     else:
         # TODO: consider erroring out in the future as if the dtype if not one of these,
         # there is likely bug
-        if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32, torch.float64]:
+        if scale_dtype not in [
+            torch.float16,
+            torch.bfloat16,
+            torch.float32,
+            torch.float64,
+        ]:
             scale_dtype = torch.float16
         zp_dtype = quantization_args.pytorch_dtype()