
Commit 4f975b7

Author: Zhenhuan Chen

fix llama's tp acc problem without breaking meta device loading (#4699)

1 parent 81a0f96 · commit 4f975b7

File tree

1 file changed: +7 −1 lines changed


examples/gpu/llm/inference/run_generation_with_deepspeed.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -210,7 +210,7 @@ def print_mem_usage(msg):
 if args.benchmark:
     print_mem_usage("pre-from-pretrained")
 
-is_meta_support = model_type not in ["auto"] and not args.disable_optimize_transformers
+is_meta_support = model_type not in ["auto"]
 
 # Construct model with fake meta tensors, later will be replaced during ds-inference ckpt load
 with deepspeed.OnDevice(dtype=load_dtype, device="meta", enabled=is_meta_support):
```
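The deleted condition had turned off meta-device construction whenever `--disable_optimize_transformers` was set; the new code keeps meta loading on for all non-"auto" model types and instead repairs the affected buffer afterwards (second hunk below). As a minimal sketch of why that repair is needed, assuming only standard PyTorch meta-device semantics (this snippet is illustrative and not part of the commit):

```python
import torch

# Tensors built on the meta device carry shape/dtype metadata but no
# backing storage, so a registered buffer such as a causal mask holds
# no real values after meta-device construction.
mask = torch.empty(4, 4, device="meta")
print(mask.is_meta)   # True
print(mask.shape)     # torch.Size([4, 4])

# Materializing on a real device allocates *uninitialized* memory;
# the actual contents must be recomputed, which is what the fix does
# for llama's causal_mask after the DeepSpeed checkpoint load.
real = torch.empty_like(mask, device="cpu")
```

DeepSpeed's checkpoint load fills in the model's parameters, but a buffer like `causal_mask` is not part of the checkpoint and can come back empty, hence the explicit reinitialization below.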
```diff
@@ -287,6 +287,12 @@ def write_checkpoints_json():
 if isinstance(model, deepspeed.InferenceEngine):
     model = model.module
 
+# reinitialize some buffers that are left empty by meta device loading
+if args.disable_optimize_transformers:
+    if model_type == "llama" and isinstance(model, LlamaForCausalLM):
+        if hasattr(model.model, "causal_mask"):
+            model.model.causal_mask = torch.triu(torch.ones_like(model.model.causal_mask), diagonal=1)
+
 if args.num_beams is None:
     args.num_beams = 1 if args.greedy else 4
```
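What the added rebuild computes, shown as a standalone sketch (the 4×4 size is an arbitrary stand-in, not the buffer's real shape): `torch.triu(torch.ones_like(m), diagonal=1)` produces a strictly upper-triangular mask, with ones marking the future positions that causal attention must block.

```python
import torch

# Stand-in for a causal_mask buffer left empty by meta-device loading.
m = torch.zeros(4, 4)

# Ones strictly above the main diagonal (diagonal=1), zeros elsewhere.
rebuilt = torch.triu(torch.ones_like(m), diagonal=1)
print(rebuilt)
# tensor([[0., 1., 1., 1.],
#         [0., 0., 1., 1.],
#         [0., 0., 0., 1.],
#         [0., 0., 0., 0.]])
```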
