Commit c1ad0e7

add data type specification to avoid errors with models that don't support float16 (#4963)

Adds a data type specification (--dtype) for speculative decoding.

1 parent b218809 · commit c1ad0e7
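
In effect, the change replaces a hardcoded torch.float16 with a user-selectable dtype. A minimal standalone sketch of that pattern (not the full script; the argument definition mirrors the diff below):

import argparse
import torch

parser = argparse.ArgumentParser()
# choices are restricted so getattr() below can only resolve to a real torch dtype
parser.add_argument(
    "--dtype",
    type=str,
    choices=["float32", "bfloat16", "float16"],
    default="float32",
)
args = parser.parse_args()

# "bfloat16" -> torch.bfloat16, and so on; models without float16 kernels
# can now be run with --dtype float32 or --dtype bfloat16 instead
amp_dtype = getattr(torch, args.dtype)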

File tree: 1 file changed (+10 −2 lines)

examples/gpu/llm/inference/speculative_decoding_inf.py

Lines changed: 10 additions & 2 deletions
@@ -27,19 +27,27 @@
 )
 parser.add_argument("--native-transformers", action="store_true", help="using native transformers for speculative decoding")
 parser.add_argument("--turn-off-speculative-decoding", action="store_true", help="using origin hf text to generation path")
+parser.add_argument(
+    "--dtype",
+    type=str,
+    choices=["float32", "bfloat16", "float16"],
+    default="float32",
+    help="please set this parameter according to the model",
+)
 args = parser.parse_args()

 device = "xpu" if torch.xpu.is_available() else "cpu"
+amp_dtype = getattr(torch, args.dtype)

 print("start memory used total:", round(torch.xpu.memory_reserved() / 1024**3, 3), "GB")

 tokenizer = AutoTokenizer.from_pretrained(args.model_id)
 inputs = tokenizer("Once upon a time, there existed a little girl, who liked to have adventures.", return_tensors="pt").input_ids.to(device)

-model = AutoModelForCausalLM.from_pretrained(args.model_id, torch_dtype=torch.float16).to(device)
+model = AutoModelForCausalLM.from_pretrained(args.model_id, torch_dtype=amp_dtype).to(device)
 model = model.to(memory_format=torch.channels_last)

-assistant_model = AutoModelForCausalLM.from_pretrained(args.assistant_model_id, torch_dtype=torch.float16).to(device)
+assistant_model = AutoModelForCausalLM.from_pretrained(args.assistant_model_id, torch_dtype=amp_dtype).to(device)
 assistant_model = assistant_model.to(memory_format=torch.channels_last)

 generate_kwargs = dict(do_sample=True, temperature=0.5)
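
For readers unfamiliar with the script, here is a standalone sketch of how the patched flow presumably fits together. The model IDs are placeholders (the real script reads them from --model-id and --assistant-model-id), and it assumes a transformers version that supports assisted generation via the assistant_model argument to generate():

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# stand-in for getattr(torch, args.dtype) when invoked with --dtype bfloat16
dtype = torch.bfloat16
# guard with hasattr so the sketch also runs on builds without torch.xpu
device = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-6.7b")  # placeholder main model
inputs = tokenizer("Once upon a time,", return_tensors="pt").input_ids.to(device)

# load both models in the same user-selected dtype so draft and target agree
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-6.7b", torch_dtype=dtype).to(device)
assistant = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", torch_dtype=dtype).to(device)  # placeholder draft model

# Hugging Face assisted generation: the small model drafts tokens and the
# large model verifies them, matching the script's generate_kwargs above
output = model.generate(inputs, assistant_model=assistant,
                        do_sample=True, temperature=0.5, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))

With the patch applied, the equivalent invocation of the example script would be something like python speculative_decoding_inf.py --dtype bfloat16, falling back to the float32 default for models that support neither half-precision format.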
