@@ -1943,14 +1943,12 @@ def __init__(self, module, qcfg):
     def forward(self, inp, **kwargs):
         self.qcfg["cached_block0_input"].append(inp.cpu())
         self.qcfg["cache_id"] += 1
-        for k, v in kwargs.items():
-            if k == "attention_mask":
-                if v is not None:
-                    self.qcfg["cached_mask"].append(v.cpu())
-            if k == "alibi":
-                self.qcfg["cached_alibi"].append(v.cpu())
-            if k == "position_ids":
-                self.qcfg["position_ids"].append(v.cpu())
+        for kw_org, kw_qcfg in self.qcfg["kw_to_cache"].items():
+            if kw_qcfg not in self.qcfg:
+                self.qcfg[kw_qcfg] = []
+            v = kwargs.get(kw_org, None)
+            if v is not None:
+                self.qcfg[kw_qcfg].append(move_to(v, "cpu"))
         raise ValueError

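A minimal, self-contained sketch of the caching pattern the reworked hook implements (CacheBlock0, the toy qcfg, and the tensors are invented for illustration, and move_to is approximated; only the kw_to_cache-driven loop mirrors the diff): the wrapper records block 0's positional input plus any keyword tensors named in qcfg["kw_to_cache"], then raises ValueError so the rest of the model never runs during calibration.

import torch

def move_to(obj, device):
    # stand-in for the repo's move_to() helper: move tensors (or
    # tuples/lists of tensors) to a device, pass everything else through
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to(o, device) for o in obj)
    return obj

class CacheBlock0(torch.nn.Module):
    """Wrap block 0: cache its inputs/kwargs, then abort the forward pass."""

    def __init__(self, module, qcfg):
        super().__init__()
        self.module = module
        self.qcfg = qcfg

    def forward(self, inp, **kwargs):
        self.qcfg["cached_block0_input"].append(inp.cpu())
        self.qcfg["cache_id"] += 1
        for kw_org, kw_qcfg in self.qcfg["kw_to_cache"].items():
            if kw_qcfg not in self.qcfg:
                self.qcfg[kw_qcfg] = []
            v = kwargs.get(kw_org, None)
            if v is not None:
                self.qcfg[kw_qcfg].append(move_to(v, "cpu"))
        raise ValueError  # stop here: only block 0's inputs are needed

qcfg = {
    "cached_block0_input": [],
    "cache_id": 0,
    "kw_to_cache": {"attention_mask": "cached_mask", "position_ids": "position_ids"},
}
block0 = CacheBlock0(torch.nn.Identity(), qcfg)
try:
    block0(
        torch.randn(1, 4, 8),
        attention_mask=torch.ones(1, 4),
        position_ids=torch.arange(4).unsqueeze(0),
    )
except ValueError:
    pass
print(len(qcfg["cached_mask"]), len(qcfg["position_ids"]))  # 1 1
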
@@ -1965,14 +1963,15 @@ def __init__(self, module, qcfg):
         self.module = module

     def forward(self, **kwargs):
-        for k, v in kwargs.items():
-            if k == "x":
-                self.qcfg["cached_block0_input"][self.qcfg["cache_id"]] = v.cpu()
-                self.qcfg["cache_id"] += 1
-            if k == "mask":
-                self.qcfg["cached_mask"] = v.cpu()
-            if k == "rel_pos_bias":
-                self.qcfg["cached_pos_bias"] = v.cpu()
+        self.qcfg["cached_block0_input"][self.qcfg["cache_id"]] = kwargs["x"].cpu()
+        self.qcfg["cache_id"] += 1
+        for kw_org, kw_qcfg in self.qcfg["kw_to_cache"].items():
+            if kw_qcfg not in self.qcfg:
+                self.qcfg[kw_qcfg] = []
+            v = kwargs.get(kw_org, None)
+            if v is not None:
+                self.qcfg[kw_qcfg].append(v.cpu())
+
         raise ValueError

@@ -2126,13 +2125,21 @@ def cache_block0_inputs(
     qcfg["cache_id"] = 0
     qcfg["cached_mask"] = []
     qcfg["cached_alibi"] = []
-    qcfg[
-        "position_ids"
-    ] = []  # latest transformers requires pos_ids to be fed into fwd()
     # move block0 to GPU and execute fwd() until block0 finishes
     if "fms" in qcfg["model_type"]:
+        qcfg["kw_to_cache"] = {
+            "mask": "cached_mask",
+            "rel_pos_bias": "cached_pos_bias",
+        }
         blocks[0] = RunFMModule(blocks[0], qcfg)
     else:
+        # latest transformers requires pos_ids to be fed into fwd()
+        qcfg["kw_to_cache"] = {
+            "attention_mask": "cached_mask",
+            "alibi": "cached_alibi",
+            "position_ids": "position_ids",
+            "position_embeddings": "position_embeddings",
+        }
         blocks[0] = RunModule(blocks[0], qcfg)

     if isinstance(dloader, torch.utils.data.DataLoader):
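A small usage sketch of the mapping set above (select_kw_to_cache and the "llama" model_type value are invented; the two dictionaries copy the ones from this hunk): each entry maps a forward() keyword name to the qcfg key that will hold its cached copies, so supporting another model family only means supplying another mapping rather than a new wrapper class.

# illustrative only: a hypothetical helper mirroring the mappings chosen in
# cache_block0_inputs() above
def select_kw_to_cache(model_type: str) -> dict:
    if "fms" in model_type:
        # FMS-style blocks: fwd(x, mask=..., rel_pos_bias=...)
        return {"mask": "cached_mask", "rel_pos_bias": "cached_pos_bias"}
    # HF transformers-style blocks; newer versions also pass position_ids /
    # position_embeddings into the decoder layer's forward()
    return {
        "attention_mask": "cached_mask",
        "alibi": "cached_alibi",
        "position_ids": "position_ids",
        "position_embeddings": "position_embeddings",
    }

qcfg = {"model_type": "llama"}
qcfg["kw_to_cache"] = select_kw_to_cache(qcfg["model_type"])
# kwargs a given model never passes are simply skipped by the hook,
# so their qcfg cache lists stay empty or are never created
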
@@ -2464,12 +2471,13 @@ def get_module_act_scales(m, block_idx, qcfg, act_scales):
                 alibi=qcfg["cached_alibi"][i].unsqueeze(0).to(dev),
             )[0].cpu()
         else:
+            kwargs = {
+                kw_org: move_to(qcfg[kw_qcfg][i], dev) if qcfg[kw_qcfg] != [] else None
+                for kw_org, kw_qcfg in qcfg["kw_to_cache"].items()
+            }
             qcfg["cached_input"][i] = m(
                 qcfg["cached_input"][i].to(dev),
-                attention_mask=None
-                if qcfg["cached_mask"] == []
-                else qcfg["cached_mask"][i].to(dev),
-                position_ids=qcfg["position_ids"][i].to(dev),
+                **kwargs,
             )[0].cpu()
     for h in hooks:
         h.remove()
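A runnable sketch of the replay step in the else branch above (ToyBlock, replay_block, and all shapes are invented; plain Tensor.to() stands in for move_to): the cached kwargs are rebuilt into a dict keyed by the original forward() keyword names, empty caches become None, and the dict is splatted back into the block's forward().

import torch

def replay_block(m, qcfg, i, dev="cpu"):
    # rebuild forward kwargs for sample i from the cached copies; an empty
    # cache list means that kwarg was never seen, so pass None instead
    kwargs = {
        kw_org: qcfg[kw_qcfg][i].to(dev) if qcfg[kw_qcfg] != [] else None
        for kw_org, kw_qcfg in qcfg["kw_to_cache"].items()
    }
    return m(qcfg["cached_input"][i].to(dev), **kwargs)

class ToyBlock(torch.nn.Module):
    def forward(self, x, attention_mask=None, position_ids=None):
        if attention_mask is not None:
            x = x * attention_mask.unsqueeze(-1)
        return (x,)  # decoder blocks typically return a tuple

qcfg = {
    "kw_to_cache": {"attention_mask": "cached_mask", "position_ids": "position_ids"},
    "cached_input": [torch.randn(1, 4, 8)],
    "cached_mask": [torch.ones(1, 4)],
    "position_ids": [],  # never cached for this toy model -> replayed as None
}
out = replay_block(ToyBlock(), qcfg, i=0)[0]
print(out.shape)  # torch.Size([1, 4, 8])
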
@@ -2482,7 +2490,7 @@ def get_act_scales_1gpu(model, dloader, qcfg):
     """
     get activation blocks on 1gpu for very large models that cannot fit in 1gpu
     """
-    dev = "cuda:0"
+    dev = "cuda"
     qcfg["batch_size"] = 1
     qcfg["loader_len"] = len(dloader)
     qcfg["dtype"] = next(iter(model.parameters())).dtype