@@ -419,15 +419,25 @@ class PTQHookRecInOutLMv2(nn.Module):
     leave the special handling, e.g. reshape/cat/shuffling...etc, for later
     """

-    def __init__(self, qcfg, name=None, cls2rec=(nn.Conv2d,), recInOnly=False):
+    def __init__(
+        self,
+        qcfg,
+        name=None,
+        cls2rec=(nn.Conv2d, nn.Linear),
+        recInOnly=False,
+        stop_after_rec=False,
+        cache_dev="cuda",
+    ):
         super().__init__()
         self.name = name
         self.qcfg = qcfg
         self.cls2rec = cls2rec
         self.rec_input_only = recInOnly
         self.num_valid_input = -1
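+        # new args: stop_after_rec aborts the forward pass once recording is done;
+        # cache_dev is the preferred device for cached tensors ("cuda" or "cpu")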
+        self.stop_after_rec = stop_after_rec
+        self.cache_dev = cache_dev

-    def __call__(self, mod, inputs, output):
+    def __call__(self, mod, inputs, *args, **_kwargs):
         # make sure this module/block's ptqmode is not 'q_out'
         submods = [m for m in mod.modules() if isinstance(m, self.cls2rec)]
         if any(sm.ptqmode == "q_out" for sm in submods):
@@ -448,7 +458,7 @@ def __call__(self, mod, inputs, output):
         # check available GPU memory, cache on GPU if possible:
         GPUmem_available, _GPUmem_total = torch.cuda.mem_get_info()
         # 1 block for SQUAD/BERT 500 batches*12/batch = ~10G
-        if GPUmem_available / 1e9 > 20:
+        if self.cache_dev == "cuda" and GPUmem_available / 1e9 > 20:
             cache_device = "cuda"
         else:
             cache_device = "cpu"
@@ -461,13 +471,15 @@ def __call__(self, mod, inputs, output):
         )

         # output could be a tuple of a single tensor or simply a tensor?
-        assert isinstance(output, (torch.Tensor, tuple))
-        if not self.rec_input_only:
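+        # the module output is only present when this is registered as a full
+        # forward hook; as a forward_pre_hook no extra positional args arrive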
+        if not self.rec_input_only and len(args) > 0:
+            output = args[0]
+            assert isinstance(output, (torch.Tensor, tuple))
             self.qcfg["cached_output"].append(
                 output[0].detach().to(cache_device)
                 if isinstance(output, tuple)
                 else output.detach().to(cache_device)
             )
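+        # when stop_after_rec=True this raises on purpose, aborting the rest of
+        # the forward pass once recording is done (caller should catch the error)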
+        assert not self.stop_after_rec


 # this hook is meant for ptq_loss_func == 'fisher_diag' and to temp hold the "Q_out" of the module
@@ -2021,7 +2033,7 @@ def get_blocks(model, model_type=None):
20212033 "llama" : (
20222034 "model.layers" ,
20232035 "model.embed_tokens" ,
2024- None ,
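+            # newer HF llama keeps a shared rotary_emb module on the model itself,
+            # which each block needs when run standalone during PTQ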
2036+ "model.rotary_emb" ,
20252037 None ,
20262038 "model.norm" ,
20272039 "lm_head" ,
@@ -2111,20 +2123,16 @@ def cache_block0_inputs(
     model, dloader, qcfg, blocks, emb=None, emb_pos=None, emb_ln=None, dev="cpu"
 ):
     """
-    To cache the input to the first transformer block.
+    To cache the input to the first transformer block. Basically a "forward_pre_hook".
+    NOTE: caching was changed from a tensor to a list to allow varying input lengths,
+    at a slight memory cost due to the mask and alibi.
21152129 """
21162130 emb = emb .to (dev )
21172131 if emb_pos is not None :
21182132 emb_pos .to (dev )
21192133 if emb_ln is not None :
21202134 emb_ln = emb_ln .to (dev )
21212135 blocks [0 ] = blocks [0 ].to (dev )
-    # NOTE, change caching from tensor to list to allow varying input length, slightly
-    # increase memeory due to mask and alibi.
-    qcfg["cached_block0_input"] = []
-    qcfg["cache_id"] = 0
-    qcfg["cached_mask"] = []
-    qcfg["cached_alibi"] = []
     # move block0 to GPU and execute fwd() until block0 finishes
     if "fms" in qcfg["model_type"]:
         qcfg["kw_to_cache"] = {
@@ -2142,9 +2150,16 @@ def cache_block0_inputs(
         }
     blocks[0] = RunModule(blocks[0], qcfg)

+    # clear out any old cache entries left over from a previous run
+    qcfg["cached_block0_input"] = []
+    qcfg["cache_id"] = 0
+    for kw in qcfg["kw_to_cache"].values():
+        if kw in qcfg:
+            qcfg[kw] = []
+
     if isinstance(dloader, torch.utils.data.DataLoader):
         pbar = tqdm(
-            dloader, desc="Phase 0: PTQ caching block0 input", total=qcfg["ptq_nbatch"]
+            dloader, desc="Phase 0: Caching block0 inputs", total=qcfg["ptq_nbatch"]
         )
         for data_mb, _ in zip(pbar, range(qcfg["ptq_nbatch"])):
             try:
@@ -2310,9 +2325,8 @@ def freeze_layers(m, layer_list):

 @torch.no_grad()
 def calibration_llm_1GPU(qcfg, model, dloader):
2313- """
2314- calibration for large models that can not fit the whole model on 1 GPU.
2315- """
2328+ """Calibration for large models that can not fit on 1 GPU."""
2329+
     model.train()
     dev = "cuda"
     qcfg["batch_size"] = 1
@@ -2365,6 +2379,83 @@ def calibration_llm_1GPU(qcfg, model, dloader):
     logger.info("All blocks are calibrated")


+@torch.no_grad()
+def calibration_llm_1GPU_v2(qcfg, model, dloader):
+    """
+    Improved calibration for large language models that cannot fit on one GPU, using
+    the new (built-in) calibration mechanism.
+    NOTE:
+    1. Calibration only, NO update to weights!
+    2. Relies on an alternative "pre fwd hook" to cache all possible inputs.
+    3. Since calibration usually caches only a small number of samples, there is no
+       need to move each batch back and forth between GPU and CPU.
+    """
+
+    model.train()
+    dev = "cuda"
+    qcfg["batch_size"] = 1
+    qcfg["dtype"] = next(iter(model.parameters())).dtype
2398+ qcfg ["n_samples" ] = min (qcfg ["ptq_nbatch" ], qcfg ["qmodel_calibration_new" ])
+
+    assert "model_type" in qcfg, "Unknown model type. Please check before proceeding."
+    assert isinstance(
+        dloader, torch.utils.data.DataLoader
+    ), "Please provide a valid dataloader."
+    # --- Phase 0: cache the inputs of block0 ---
+    model.config.use_cache = False
+    blocks, emb, emb_pos, emb_ln, _, _ = get_blocks(model, qcfg["model_type"])
+
+    cache_block0_inputs(
+        model,
+        dloader,
+        qcfg,
+        blocks,
+        emb=emb,
+        emb_pos=emb_pos,
+        emb_ln=emb_ln,
+        dev="cpu",
+    )
+    logger.info("Done caching inputs to block0 for calibration")
+
+    # --- Phase 1: run each block (and the last linear layer) ---
+    pbar = tqdm(
+        blocks, desc="Phase 1: Calibration for each block", position=0, leave=True
+    )
+    qcfg["cached_input"] = [
+        inp.clone().detach().to(dev) for inp in qcfg["cached_block0_input"]
+    ]
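+    # keep only the cached kwargs that have exactly one entry per cached sample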
+    kw_to_use = {
+        kw_org: kw_new
+        for kw_org, kw_new in qcfg["kw_to_cache"].items()
+        if len(qcfg.get(kw_new, [])) == len(qcfg["cached_input"])
+    }
+    for _num_block, m in enumerate(pbar):
+        m.to(dev)
+        for i in tqdm(
+            range(qcfg["n_samples"]), desc="number of samples", position=1, leave=False
+        ):
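+            # bloom-style models cache mask/alibi explicitly and need a batch dim;
+            # other models simply replay whatever kwargs were recorded in kw_to_cache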
+            if qcfg.get("cached_alibi"):
+                cached_inp_prev_lay = qcfg["cached_input"][i].unsqueeze(0).to(dev)
+                data_mb = {
+                    "attention_mask": qcfg["cached_mask"][i].unsqueeze(0).to(dev),
+                    "alibi": qcfg["cached_alibi"][i].unsqueeze(0).to(dev),
+                }
+            else:
+                cached_inp_prev_lay = qcfg["cached_input"][i]
+                data_mb = {
+                    kw_org: move_to(qcfg[kw_new][i], dev)
+                    for kw_org, kw_new in kw_to_use.items()
+                }
+
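+            # run the block and overwrite the cached input in place, so block i's
+            # output becomes block i+1's input on the next pass over the samples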
+            with patch_torch_bmm(qcfg):
+                qcfg["cached_input"][i] = m(cached_inp_prev_lay, **data_mb)[0]
+
+        m.cpu()
+        torch.cuda.empty_cache()
+
+    logger.info("All blocks are calibrated")
+
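+# Minimal usage sketch (hypothetical qcfg values, adjust to your setup):
+#     qcfg.update(model_type="llama", ptq_nbatch=16, qmodel_calibration_new=16)
+#     calibration_llm_1GPU_v2(qcfg, model, calib_dloader)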
+
 @torch.no_grad()
 def activation_stats(name, tensor, act_scales):
     # TODO if 'QBmm' in name: reshape the tensor.
@@ -2498,8 +2589,8 @@ def get_act_scales_1gpu(model, dloader, qcfg):

     assert "model_type" in qcfg, "Unknown model type. please check before proceed."
     assert (
-        qcfg["loader_len"] == qcfg["ptq_nbatch"]
-    ), "set batch_size=1 and PTQ samples == Nbatches"
+        qcfg["loader_len"] >= qcfg["ptq_nbatch"]
+    ), "Please make sure the dataloader has enough data for PTQ (i.e., check qcfg['ptq_nbatch'])."
     # --- Phase 0 cache the inputs of the block0 ---
     blocks, emb, emb_pos, emb_ln, _, _ = get_blocks(model, qcfg["model_type"])
     cache_block0_inputs(