File tree Expand file tree Collapse file tree 1 file changed +7
-1
lines changed
examples/gpu/llm/inference Expand file tree Collapse file tree 1 file changed +7
-1
lines changed Original file line number Diff line number Diff line change @@ -210,7 +210,7 @@ def print_mem_usage(msg):
210210if args .benchmark :
211211 print_mem_usage ("pre-from-pretrained" )
212212
213- is_meta_support = model_type not in ["auto" ] and not args . disable_optimize_transformers
213+ is_meta_support = model_type not in ["auto" ]
214214
215215# Construct model with fake meta tensors, later will be replaced during ds-inference ckpt load
216216with deepspeed .OnDevice (dtype = load_dtype , device = "meta" , enabled = is_meta_support ):
@@ -287,6 +287,12 @@ def write_checkpoints_json():
287287if isinstance (model , deepspeed .InferenceEngine ):
288288 model = model .module
289289
290+ # reinitialize some buffers that are left empty by meta-device loading
291+ if args .disable_optimize_transformers :
292+ if model_type == "llama" and isinstance (model , LlamaForCausalLM ):
293+ if hasattr (model .model , "causal_mask" ):
294+ model .model .causal_mask = torch .triu (torch .ones_like (model .model .causal_mask ), diagonal = 1 )
295+
290296if args .num_beams is None :
291297 args .num_beams = 1 if args .greedy else 4
292298
You can’t perform that action at this time.
0 commit comments