|
| 1 | +# Copyright (c) Alibaba, Inc. and its affiliates. |
| 2 | +import os |
| 3 | + |
| 4 | +# os.environ['CUDA_VISIBLE_DEVICES'] = '0' |
| 5 | +import torch |
| 6 | +from transformers import BitsAndBytesConfig, GenerationConfig, TextStreamer |
| 7 | +from utils import (InferArguments, get_dataset, get_model_tokenizer, |
| 8 | + get_preprocess) |
| 9 | +from llm_infer import llm_infer |
| 10 | +from swift import Swift, get_logger |
| 11 | +from swift.tuners import LoRA |
| 12 | +from swift.utils import inference, parse_args, seed_everything |
| 13 | + |
| 14 | +logger = get_logger() |
| 15 | + |
| 16 | + |
| 17 | +def merge_lora(args: InferArguments) -> None: |
| 18 | + assert args.sft_type == 'lora' |
| 19 | + args.init_argument() |
| 20 | + logger.info(f'device_count: {torch.cuda.device_count()}') |
| 21 | + |
| 22 | + # ### Loading Model and Tokenizer |
| 23 | + model, tokenizer = get_model_tokenizer( |
| 24 | + args.model_type, torch_dtype=args.torch_dtype, device_map='cpu') |
| 25 | + |
| 26 | + # ### Preparing LoRA |
| 27 | + model = Swift.from_pretrained(model, args.ckpt_dir, inference_mode=True) |
| 28 | + if not hasattr(model, 'peft_type'): |
| 29 | + LoRA.unpatch_lora(model, model.adapters['default'].config, 'default') |
| 30 | + else: |
| 31 | + model.merge_and_unload() |
| 32 | + |
| 33 | + new_ckpt_dir = os.path.abspath( |
| 34 | + os.path.join(args.ckpt_dir, '..', 'output_ckpt')) |
| 35 | + logger.info(f'new_ckpt_dir: `{new_ckpt_dir}`') |
| 36 | + logger.info("Setting args.sft_type: 'full'") |
| 37 | + logger.info(f'Setting args.ckpt_dir: {new_ckpt_dir}') |
| 38 | + args.ckpt_dir = new_ckpt_dir |
| 39 | + args.sft_type = 'full' |
| 40 | + if not os.path.exists(args.ckpt_dir): |
| 41 | + model.model.save_pretrained(args.ckpt_dir) |
| 42 | + tokenizer.save_pretrained(args.ckpt_dir) |
| 43 | + |
| 44 | + |
| 45 | +if __name__ == '__main__': |
| 46 | + args, remaining_argv = parse_args(InferArguments) |
| 47 | + if len(remaining_argv) > 0: |
| 48 | + if args.ignore_args_error: |
| 49 | + logger.warning(f'remaining_argv: {remaining_argv}') |
| 50 | + else: |
| 51 | + raise ValueError(f'remaining_argv: {remaining_argv}') |
| 52 | + merge_lora(args) |
| 53 | + llm_infer(args) |
0 commit comments