 )
 parser.add_argument("--native-transformers", action="store_true", help="using native transformers for speculative decoding")
 parser.add_argument("--turn-off-speculative-decoding", action="store_true", help="using origin hf text to generation path")
+parser.add_argument(
+    "--dtype",
+    type=str,
+    choices=["float32", "bfloat16", "float16"],
+    default="float32",
+    help="please set this parameter according to the model",
+)
 args = parser.parse_args()

 device = "xpu" if torch.xpu.is_available() else "cpu"
+amp_dtype = getattr(torch, args.dtype)

 print("start memory used total:", round(torch.xpu.memory_reserved() / 1024**3, 3), "GB")

 tokenizer = AutoTokenizer.from_pretrained(args.model_id)
 inputs = tokenizer("Once upon a time, there existed a little girl, who liked to have adventures.", return_tensors="pt").input_ids.to(device)

-model = AutoModelForCausalLM.from_pretrained(args.model_id, torch_dtype=torch.float16).to(device)
+model = AutoModelForCausalLM.from_pretrained(args.model_id, torch_dtype=amp_dtype).to(device)
 model = model.to(memory_format=torch.channels_last)

-assistant_model = AutoModelForCausalLM.from_pretrained(args.assistant_model_id, torch_dtype=torch.float16).to(device)
+assistant_model = AutoModelForCausalLM.from_pretrained(args.assistant_model_id, torch_dtype=amp_dtype).to(device)
 assistant_model = assistant_model.to(memory_format=torch.channels_last)

 generate_kwargs = dict(do_sample=True, temperature=0.5)
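For reference, a minimal self-contained sketch of the dtype plumbing this diff introduces: the CLI string is resolved to the corresponding torch dtype object via getattr and then handed to from_pretrained as torch_dtype. The parse_args input below is illustrative, not part of the diff.

```python
import argparse

import torch

# Mirrors the --dtype argument added in the diff above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--dtype",
    type=str,
    choices=["float32", "bfloat16", "float16"],
    default="float32",
)
args = parser.parse_args(["--dtype", "bfloat16"])  # illustrative input, not from the diff

# getattr(torch, "bfloat16") returns the torch.bfloat16 dtype object,
# so no explicit string-to-dtype mapping table is needed.
amp_dtype = getattr(torch, args.dtype)
assert amp_dtype is torch.bfloat16
```

Defaulting to float32 keeps the script safe for checkpoints that were not trained in half precision, while bfloat16 or float16 roughly halve the memory footprint of both the target and assistant models on devices that support them.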