@@ -254,6 +254,17 @@ def mock_hessian_inverse(self, H: torch.Tensor):
254254 identity = torch .eye (H .shape [0 ], dtype = torch .float32 , device = H .device )
255255 return identity , damp
256256
def log_cpu_fallback(self, stage: str, source_device: torch.device) -> None:
    """Explain when a memory-heavy GPTQ step moves from CUDA to CPU."""

    # Lazy %-style args keep formatting off the hot path; `.warn` matches
    # the project logger's convention used elsewhere in this file.
    message = (
        "Quantization: Module `%s` -> CUDA OOM during %s on %s; falling back to CPU. "
        "Due to this fallback, the calculation may take much longer than normal."
    )
    log.warn(message, self.name, stage, source_device)
257268 def clone_module (self , copy = True , device : torch .device = None ):
258269 if not device :
259270 device = self .module .weight .data .device
@@ -886,6 +897,8 @@ def quantize(
886897 start = time .time ()
887898
888899 target_device = getattr (self .module , "target_device" , None )
900+ result_device = torch .device (self .module .weight .data .device )
901+ cpu_fallback_used = False
889902 from ..utils .fallback import resolve_fallback_strategy , resolve_threshold , should_use_fallback
890903
891904 resolved_strategy = resolve_fallback_strategy (self .fallback )
@@ -971,11 +984,8 @@ def quantize(
971984 if self .H .device .type != "cuda" or "out of memory" not in str (exc ).lower ():
972985 raise
973986
974- log .warn (
975- "Quantization: Module `%s` -> CUDA OOM during Hessian permutation on %s; retrying that module on CPU." ,
976- self .name ,
977- self .H .device ,
978- )
987+ self .log_cpu_fallback ("Hessian permutation" , self .H .device )
988+ cpu_fallback_used = True
979989 cpu_device = torch .device ("cpu" )
980990 perm = perm .to (device = cpu_device )
981991 W = W .to (device = cpu_device )[:, perm ]
@@ -1002,11 +1012,8 @@ def quantize(
10021012 if self .H .device .type != "cuda" or "out of memory" not in str (exc ).lower ():
10031013 raise
10041014
1005- log .warn (
1006- "Quantization: Module `%s` -> CUDA OOM during act-group Hessian permutation on %s; retrying that module on CPU." ,
1007- self .name ,
1008- self .H .device ,
1009- )
1015+ self .log_cpu_fallback ("act-group Hessian permutation" , self .H .device )
1016+ cpu_fallback_used = True
10101017 cpu_device = torch .device ("cpu" )
10111018 final_perm = final_perm .to (device = cpu_device )
10121019 W = W .to (device = cpu_device )[:, final_perm ]
@@ -1022,11 +1029,8 @@ def quantize(
10221029
10231030 # Full-attention blocks on very large models can exceed GPU memory during the
10241031 # dense Hessian inverse; finish that module on CPU instead of aborting the run.
1025- log .warn (
1026- "Quantization: Module `%s` -> CUDA OOM during Hessian inverse on %s; retrying quantization on CPU." ,
1027- self .name ,
1028- self .H .device ,
1029- )
1032+ self .log_cpu_fallback ("Hessian inverse" , self .H .device )
1033+ cpu_fallback_used = True
10301034 cpu_device = torch .device ("cpu" )
10311035 self .H = self .H .to (device = cpu_device )
10321036 W = W .to (device = cpu_device )
@@ -1233,12 +1237,13 @@ def quantize(
12331237 g_idx = torch .tensor (g_idx , dtype = torch .int32 , device = Q .device )
12341238
12351239 if self .qcfg .desc_act and use_hessian :
1240+ invperm = invperm .to (device = Q .device )
12361241 Q = Q [:, invperm ]
12371242 g_idx = g_idx [invperm ]
12381243 del perm , invperm
12391244
12401245 elif self .qcfg .act_group_aware and use_hessian :
1241- inv_final = invert_perm (final_perm )
1246+ inv_final = invert_perm (final_perm ). to ( device = Q . device )
12421247 Q = Q [:, inv_final ]
12431248 inv_global_perm = invert_perm (global_perm )
12441249 inv_global_perm_list = inv_global_perm .tolist ()
@@ -1273,7 +1278,14 @@ def quantize(
12731278 scale = self .truncate_last_dim (scale , valid_cols )
12741279 zero = self .truncate_last_dim (zero , valid_cols )
12751280
1276- Q = Q .to (device = self .module .weight .data .device , non_blocking = False )
1281+ if cpu_fallback_used and Q .device != result_device :
1282+ log .info (
1283+ "Quantization: Module `%s` -> CPU fallback complete; moving final quantized weights back to %s." ,
1284+ self .name ,
1285+ result_device ,
1286+ )
1287+
1288+ Q = Q .to (device = result_device , non_blocking = False )
12771289
12781290 duration = time .time () - start
12791291
0 commit comments