# Copyright (c) Alibaba, Inc. and its affiliates.
import json
from typing import Optional

import torch
from modelscope import GenerationConfig

from swift.tuners import Swift
from swift.utils import (get_logger, print_model_info, seed_everything,
                         show_layers)
from ..tuners.rome import RomeConfig
from .utils import (RomeArguments, Template, get_dataset, get_model_tokenizer,
                    get_template, inference)

logger = get_logger()


def rome_infer(args: RomeArguments) -> None:
    logger.info(f'args: {args}')
    logger.info(
        'ROME does not support quantization for now; all quantization args will be ignored.'
    )
    logger.info(f'device_count: {torch.cuda.device_count()}')
    seed_everything(args.seed)

    # ### Loading Model and Tokenizer
    model_kwargs = {'low_cpu_mem_usage': True, 'device_map': 'auto'}
    kwargs = {'use_flash_attn': args.use_flash_attn}
    model, tokenizer = get_model_tokenizer(args.model_type, args.torch_dtype,
                                           model_kwargs, **kwargs)

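    # Load the ROME edit requests from args.rome_request_file; the parsed JSON
    # is handed to RomeConfig unchanged as the knowledge to inject.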
    with open(args.rome_request_file, 'r') as f:
        request = json.load(f)

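    # Map the concrete model_type to the ROME hyperparameter profile; only the
    # llama / llama2 7B and 13B families are covered here, so rome_type stays
    # None for any other model.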
    rome_type: Optional[str] = None
    if args.model_type in ('llama2-13b-chat', 'llama2-13b', 'llama-13b-chat',
                           'llama-13b'):
        rome_type = 'llama-13b'
    elif args.model_type in ('llama2-7b-chat', 'llama2-7b', 'llama-7b-chat',
                             'llama-7b'):
        rome_type = 'llama-7b'

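    # Build the ROME config from the edit requests and apply the edit to the
    # model via Swift.prepare_model (inference only, no training).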
    config = RomeConfig(
        model_type=rome_type,
        knowledge=request,
        tokenizer=tokenizer,
    )
    model = Swift.prepare_model(model, config, inference_mode=True)

    show_layers(model)
    print_model_info(model)

    # ### Inference
    template: Template = get_template(args.template_type, tokenizer,
                                      args.system, args.max_length)
    generation_config = GenerationConfig(
        max_length=None,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        do_sample=args.do_sample,
        repetition_penalty=args.repetition_penalty,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id)
    logger.info(f'generation_config: {generation_config}')
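    # Optionally persist the generation config to the checkpoint directory.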
    if args.overwrite_generation_config:
        generation_config.save_pretrained(args.ckpt_dir)
    model.generation_config = generation_config

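    # Two evaluation modes: an interactive prompt loop (eval_human), or
    # inference over a small slice of the validation split of args.dataset.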
    if args.eval_human:
        while True:
            query = input('<<< ')
            data = {'query': query}
            input_ids = template.encode(data)['input_ids']
            inference(input_ids, model, tokenizer, args.stream)
    else:
        _, val_dataset = get_dataset(args.dataset, args.dataset_test_ratio,
                                     args.dataset_seed)
        mini_val_dataset = val_dataset.select(
            range(min(args.show_dataset_sample, val_dataset.shape[0])))
        for data in mini_val_dataset:
            response = data['response']
            data['response'] = None
            input_ids = template.encode(data)['input_ids']
            inference(input_ids, model, tokenizer, args.stream)
            print()
            print(f'[LABELS]{response}')
            print('-' * 80)
            # input('next[ENTER]')