Skip to content

Commit f4608ab

Browse files
authored
fix cpu fallback device restore (#2664)
1 parent 25dd8ea commit f4608ab

File tree

2 files changed

+76
-18
lines changed

2 files changed

+76
-18
lines changed

gptqmodel/quantization/gptq.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,17 @@ def mock_hessian_inverse(self, H: torch.Tensor):
254254
identity = torch.eye(H.shape[0], dtype=torch.float32, device=H.device)
255255
return identity, damp
256256

257+
def log_cpu_fallback(self, stage: str, source_device: torch.device) -> None:
    """Warn that a memory-heavy GPTQ step is falling back from CUDA to CPU.

    Args:
        stage: Human-readable name of the step that hit CUDA OOM
            (e.g. "Hessian inverse", "Hessian permutation").
        source_device: Device the computation was running on before the
            fallback was triggered.
    """
    # Keep %-style lazy args so formatting only happens if the record
    # is actually emitted by the logger.
    template = (
        "Quantization: Module `%s` -> CUDA OOM during %s on %s; falling back to CPU. "
        "Due to this fallback, the calculation may take much longer than normal."
    )
    log.warn(template, self.name, stage, source_device)
267+
257268
def clone_module(self, copy=True, device: torch.device = None):
258269
if not device:
259270
device = self.module.weight.data.device
@@ -886,6 +897,8 @@ def quantize(
886897
start = time.time()
887898

888899
target_device = getattr(self.module, "target_device", None)
900+
result_device = torch.device(self.module.weight.data.device)
901+
cpu_fallback_used = False
889902
from ..utils.fallback import resolve_fallback_strategy, resolve_threshold, should_use_fallback
890903

891904
resolved_strategy = resolve_fallback_strategy(self.fallback)
@@ -971,11 +984,8 @@ def quantize(
971984
if self.H.device.type != "cuda" or "out of memory" not in str(exc).lower():
972985
raise
973986

974-
log.warn(
975-
"Quantization: Module `%s` -> CUDA OOM during Hessian permutation on %s; retrying that module on CPU.",
976-
self.name,
977-
self.H.device,
978-
)
987+
self.log_cpu_fallback("Hessian permutation", self.H.device)
988+
cpu_fallback_used = True
979989
cpu_device = torch.device("cpu")
980990
perm = perm.to(device=cpu_device)
981991
W = W.to(device=cpu_device)[:, perm]
@@ -1002,11 +1012,8 @@ def quantize(
10021012
if self.H.device.type != "cuda" or "out of memory" not in str(exc).lower():
10031013
raise
10041014

1005-
log.warn(
1006-
"Quantization: Module `%s` -> CUDA OOM during act-group Hessian permutation on %s; retrying that module on CPU.",
1007-
self.name,
1008-
self.H.device,
1009-
)
1015+
self.log_cpu_fallback("act-group Hessian permutation", self.H.device)
1016+
cpu_fallback_used = True
10101017
cpu_device = torch.device("cpu")
10111018
final_perm = final_perm.to(device=cpu_device)
10121019
W = W.to(device=cpu_device)[:, final_perm]
@@ -1022,11 +1029,8 @@ def quantize(
10221029

10231030
# Full-attention blocks on very large models can exceed GPU memory during the
10241031
# dense Hessian inverse; finish that module on CPU instead of aborting the run.
1025-
log.warn(
1026-
"Quantization: Module `%s` -> CUDA OOM during Hessian inverse on %s; retrying quantization on CPU.",
1027-
self.name,
1028-
self.H.device,
1029-
)
1032+
self.log_cpu_fallback("Hessian inverse", self.H.device)
1033+
cpu_fallback_used = True
10301034
cpu_device = torch.device("cpu")
10311035
self.H = self.H.to(device=cpu_device)
10321036
W = W.to(device=cpu_device)
@@ -1233,12 +1237,13 @@ def quantize(
12331237
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
12341238

12351239
if self.qcfg.desc_act and use_hessian:
1240+
invperm = invperm.to(device=Q.device)
12361241
Q = Q[:, invperm]
12371242
g_idx = g_idx[invperm]
12381243
del perm, invperm
12391244

12401245
elif self.qcfg.act_group_aware and use_hessian:
1241-
inv_final = invert_perm(final_perm)
1246+
inv_final = invert_perm(final_perm).to(device=Q.device)
12421247
Q = Q[:, inv_final]
12431248
inv_global_perm = invert_perm(global_perm)
12441249
inv_global_perm_list = inv_global_perm.tolist()
@@ -1273,7 +1278,14 @@ def quantize(
12731278
scale = self.truncate_last_dim(scale, valid_cols)
12741279
zero = self.truncate_last_dim(zero, valid_cols)
12751280

1276-
Q = Q.to(device=self.module.weight.data.device, non_blocking=False)
1281+
if cpu_fallback_used and Q.device != result_device:
1282+
log.info(
1283+
"Quantization: Module `%s` -> CPU fallback complete; moving final quantized weights back to %s.",
1284+
self.name,
1285+
result_device,
1286+
)
1287+
1288+
Q = Q.to(device=result_device, non_blocking=False)
12771289

12781290
duration = time.time() - start
12791291

tests/test_gptq.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,53 @@ def _run_batch(idx: int) -> None:
8585
return PathStats(per_batch_seconds=per_batch, total_seconds=total, peak_bytes=peak_bytes, batches_measured=measured)
8686

8787

88+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for CPU fallback regression coverage")
def test_gptq_cpu_hessian_fallback_returns_quantized_weights_to_original_cuda_device(monkeypatch):
    """Regression test for the CPU-fallback device restore.

    Simulates a CUDA OOM during the Hessian inverse so `quantize()` retries on
    CPU, then checks that the final quantized weights are moved back to the
    original CUDA device and that both the fallback warning and the
    device-restore info message are logged.
    """
    device = torch.device("cuda", 0)
    torch.cuda.set_device(device)
    # Deterministic weights/activations for the tiny test module.
    torch.manual_seed(0)

    layer = _make_module(hidden_dim=8, device=device)
    qcfg = QuantizeConfig(bits=4, group_size=2, act_group_aware=True)
    gptq = GPTQ(layer, qcfg=qcfg)
    gptq.quantizer.configure(perchannel=True)

    inp = _generate_input(batch_size=1, seq_len=4, hidden_dim=8, device=device)
    gptq.add_batch(inp, None)

    # Track how many times the (patched) Hessian inverse runs on each device.
    calls = {"cuda": 0, "cpu": 0}

    def _patched_hessian_inverse(self, hessian: torch.Tensor):
        # First call arrives on CUDA: raise a fake OOM whose message contains
        # "out of memory" so quantize() takes its CPU fallback path.
        if hessian.device.type == "cuda":
            calls["cuda"] += 1
            raise RuntimeError("CUDA out of memory. simulated for regression test")

        # CPU retry: return a trivial identity "inverse" plus the damp factor.
        calls["cpu"] += 1
        identity = torch.eye(hessian.shape[0], dtype=torch.float32, device=hessian.device)
        return identity, self.qcfg.damp_percent

    monkeypatch.setattr(GPTQ, "hessian_inverse", _patched_hessian_inverse)
    log_messages = []

    # Capture both log.warn (fallback notice) and log.info (restore notice);
    # render %-style lazy args so the assertions below can match full text.
    def _capture_warn(message, *args, **kwargs):
        log_messages.append(message % args if args else message)

    def _capture_info(message, *args, **kwargs):
        log_messages.append(message % args if args else message)

    monkeypatch.setattr(gptq_mod.log, "warn", _capture_warn)
    monkeypatch.setattr(gptq_mod.log, "info", _capture_info)

    qweight, _, _, _, *_ = gptq.quantize(blocksize=4)

    # Exactly one failed CUDA attempt followed by one successful CPU retry.
    assert calls == {"cuda": 1, "cpu": 1}
    # Core regression check: result lands back on the original CUDA device.
    assert qweight.device == device
    joined_logs = "\n".join(log_messages)
    assert "falling back to CPU" in joined_logs
    assert "may take much longer than normal" in joined_logs
    assert "moving final quantized weights back" in joined_logs
133+
134+
88135
class TestGPTQAddBatchCPU(ModelTest):
89136
######### test_gptq_add_batch_cpu.py ###########
90137
pytestmark = pytest.mark.skipif(
@@ -331,4 +378,3 @@ def get_random_word(self):
331378
pytest.skip(
332379
f"Streaming event helper subprocess unavailable: rc={result.returncode}, stderr={result.stderr.strip()}"
333380
)
334-

0 commit comments

Comments
 (0)