diff --git a/examples/deepseek/ptq.py b/examples/deepseek/ptq.py index 06407121..c213eb27 100644 --- a/examples/deepseek/ptq.py +++ b/examples/deepseek/ptq.py @@ -229,6 +229,15 @@ def load_deepseek_model(model_config: str, model_path: str, batch_size: int): return model +def end_process(): + world_size = int(os.getenv("WORLD_SIZE", "1")) + rank = int(os.getenv("RANK", "0")) + local_rank = int(os.getenv("LOCAL_RANK", "0")) + if world_size > 1: + if dist.is_initialized(): + dist.destroy_process_group() + + def ptq( model, tokenizer, @@ -364,3 +373,4 @@ def state_dict_filter(state_dict): ) model = ptq(model, tokenizer, args.quant_cfg, args.batch_size, args.calib_size) save_amax_and_quant_config(model, args.output_path, not args.disable_fp8_kvcache) + end_process()