
Commit e98372d

Disable AMP for Baichuan fp16 inference on single tile mode. (#4992) (#4996)

1 parent: be45f16

2 files changed, +3 -2 lines

examples/gpu/llm/inference/run_benchmark.sh (1 addition, 1 deletion)

@@ -162,7 +162,7 @@ Run_benchmark_baichuan2-13b-chat() {
     sub_model_name=baichuan2-13b
     dir=perf/${model}/beam${beam}_bs${bs}_input${input}_out${out}
     mkdir -p ${dir}
-    python -u run_generation.py --benchmark -m ${model} --sub-model-name ${sub_model_name} --use-static-cache --num-beams ${beam} --num-iter ${iter} --batch-size ${bs} --input-tokens ${input} --max-new-tokens ${out} --device xpu --ipex --dtype float16 --token-latency 2>&1 | tee log_e2e
+    python -u run_generation.py --benchmark -m ${model} --sub-model-name ${sub_model_name} --use-static-cache --num-beams ${beam} --num-iter ${iter} --batch-size ${bs} --input-tokens ${input} --max-new-tokens ${out} --device xpu --ipex --dtype float16 --token-latency --disable-auto-cast 2>&1 | tee log_e2e
     mv log_e2e ${dir}
     PROFILE=1 python -u run_generation.py --benchmark -m ${model} --sub-model-name ${sub_model_name} --use-static-cache --num-beams ${beam} --num-iter ${iter} --batch-size ${bs} --input-tokens ${input} --max-new-tokens ${out} --device xpu --ipex --dtype float16
     mv profile*pt ${dir}
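Note that only the end-to-end latency run in this hunk picks up the new --disable-auto-cast flag; the PROFILE=1 invocation is left unchanged, so the profiling pass still runs with auto-mixed-precision enabled.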

examples/gpu/llm/inference/run_generation.py (2 additions, 1 deletion)

@@ -115,6 +115,7 @@
 parser.add_argument("--acc-iter", default=-1, type=int)
 parser.add_argument("--use-static-cache", default=False, action="store_true", help="use static kv cache")
 parser.add_argument("--use-hf-code", default=True, action="store_false", help="use hf transformers code")
+parser.add_argument("--disable-auto-cast", default=False, action="store_true", help="whether to disable auto-mixed-precision feature")
 args = parser.parse_args()
 print(args)

@@ -144,7 +145,7 @@ def get_memory_usage(name, args):
 # torch._C._jit_set_texpr_fuser_enabled(False)

 # dtype
-amp_enabled = True if args.dtype != "float32" else False
+amp_enabled = True if args.dtype != "float32" and not args.disable_auto_cast else False
 amp_dtype = getattr(torch, args.dtype)

 # load model
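Downstream, amp_enabled and amp_dtype typically feed a torch.autocast region around generation. Below is a minimal sketch of how the new flag takes effect, assuming the script's parsed args namespace, an XPU build with intel_extension_for_pytorch available, and hypothetical model and input_ids objects standing in for whatever the script loads; the actual call site in run_generation.py may differ:

    import torch
    import intel_extension_for_pytorch  # noqa: F401  # registers the "xpu" device and its autocast backend

    # Equivalent to the committed expression
    # `True if args.dtype != "float32" and not args.disable_auto_cast else False`:
    # AMP stays on only for non-fp32 dtypes when --disable-auto-cast is absent.
    amp_enabled = args.dtype != "float32" and not args.disable_auto_cast
    amp_dtype = getattr(torch, args.dtype)  # e.g. torch.float16 for --dtype float16

    # With --disable-auto-cast, enabled=False turns autocast into a no-op,
    # so an fp16 model runs in plain fp16 without AMP's per-op dtype casting.
    with torch.inference_mode(), torch.autocast("xpu", enabled=amp_enabled, dtype=amp_dtype):
        output = model.generate(input_ids, max_new_tokens=args.max_new_tokens)

This is what the commit message refers to: with --disable-auto-cast added to the fp16 benchmark command, Baichuan inference on a single tile bypasses AMP entirely instead of re-casting ops inside an autocast context.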

0 commit comments