diff --git a/swift/llm/argument/infer_args.py b/swift/llm/argument/infer_args.py
index 660b2dd385..ad8a3f108a 100644
--- a/swift/llm/argument/infer_args.py
+++ b/swift/llm/argument/infer_args.py
@@ -156,6 +156,7 @@ class InferArguments(MergeArguments, LmdeployArguments, SglangArguments, VllmArg
     metric: Literal['acc', 'rouge'] = None
     # for pt engine
     max_batch_size: int = 1
+    cache_impl: Optional[str] = None

     # only for inference
     val_dataset_sample: Optional[int] = None
diff --git a/swift/llm/infer/infer.py b/swift/llm/infer/infer.py
index 6747038384..f1063ca53e 100644
--- a/swift/llm/infer/infer.py
+++ b/swift/llm/infer/infer.py
@@ -32,8 +32,12 @@ def __init__(self, args: Optional[Union[List[str], InferArguments]] = None) -> N

         if args.infer_backend == 'pt':
             model, self.template = prepare_model_template(args)
-            self.infer_engine = PtEngine.from_model_template(model, self.template, max_batch_size=args.max_batch_size)
-            self.infer_engine.reranker_use_activation = args.reranker_use_activation
+            self.infer_engine = PtEngine.from_model_template(
+                model,
+                self.template,
+                max_batch_size=args.max_batch_size,
+                reranker_use_activation=args.reranker_use_activation,
+                cache_impl=args.cache_impl)
             logger.info(f'model: {self.infer_engine.model}')
         else:
             self.template = args.get_template(None)
@@ -64,8 +68,8 @@ def get_infer_engine(args: InferArguments, template=None, **kwargs):
             from .infer_engine import PtEngine
             infer_engine_cls = PtEngine
             kwargs.update(args.get_model_kwargs())
-            if hasattr(args, 'max_batch_size'):
-                kwargs.update({'max_batch_size': args.max_batch_size})
+            kwargs['max_batch_size'] = args.max_batch_size
+            kwargs['cache_impl'] = args.cache_impl
         elif infer_backend == 'vllm':
             from .infer_engine import VllmEngine
             infer_engine_cls = VllmEngine
@@ -180,7 +184,7 @@ def _prepare_val_dataset(self) -> HfDataset:
         args = self.args
         dataset_kwargs = args.get_dataset_kwargs()
         if args.cached_dataset or args.cached_val_dataset:
-            _, val_datasets = get_cached_dataset(self.args)
+            _, val_datasets = get_cached_dataset(args)
         else:
             val_datasets = []
             if len(args.val_dataset) > 0:
diff --git a/swift/llm/infer/infer_engine/pt_engine.py b/swift/llm/infer/infer_engine/pt_engine.py
index 03d73dca63..f981436bc5 100644
--- a/swift/llm/infer/infer_engine/pt_engine.py
+++ b/swift/llm/infer/infer_engine/pt_engine.py
@@ -62,6 +62,7 @@ def __init__(
             model_kwargs: Optional[Dict[str, Any]] = None,
             template: Optional[Template] = None,
             reranker_use_activation: bool = True,
+            cache_impl: Optional[str] = None,
             **kwargs):
         download_model = kwargs.pop('download_model', True)
         self.model, self.processor = get_model_tokenizer(
@@ -80,6 +81,7 @@
             model_kwargs=model_kwargs,
             **kwargs)
         self.reranker_use_activation = reranker_use_activation
+        self.cache_impl = cache_impl
         self.max_batch_size = max_batch_size
         if isinstance(adapters, str):
             adapters = [adapters]
@@ -151,11 +153,21 @@ def _add_adapter(self, adapter_path: str, adapter_name: Optional[str] = None) ->
         self.model = Swift.from_pretrained(self.model, adapter_path, adapter_name)

     @classmethod
-    def from_model_template(cls, model, template=None, *, max_batch_size: int = 1):
+    def from_model_template(
+            cls,
+            model,
+            template=None,
+            *,
+            max_batch_size: int = 1,
+            reranker_use_activation: bool = True,
+            cache_impl: Optional[str] = None,
+    ):
         self = super().__new__(cls)
         self.model = model
         self.processor = template.processor
         self.max_batch_size = max_batch_size
+        self.reranker_use_activation = reranker_use_activation
+        self.cache_impl = cache_impl
         self._post_init(template)
         return self

@@ -233,6 +245,8 @@ def _model_generate(**kwargs):
             template.generate(self.model, **kwargs)

         generate_kwargs = template.prepare_generate_kwargs(generate_kwargs, model=self.model)
+        if self.cache_impl is not None:
+            generate_kwargs['cache_implementation'] = self.cache_impl
         thread = Thread(target=_model_generate, kwargs=generate_kwargs)
         thread.start()
         batch_size = inputs['attention_mask'].shape[0]
@@ -392,6 +406,8 @@ def _infer_full(self, template: Template, inputs: Dict[str, Any], *, generation_
         generate_kwargs['adapter_names'] = adapter_names
         num_prompt_tokens = self._get_num_tokens(inputs)
         generate_kwargs = template.prepare_generate_kwargs(generate_kwargs, model=self.model)
+        if self.cache_impl is not None:
+            generate_kwargs['cache_implementation'] = self.cache_impl
         output = dict(template.generate(self.model, **generate_kwargs))
         output.pop('past_key_values', None)
         batched_generate_ids = output['sequences']
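
Usage sketch (not part of the diff): with this patch applied, `cache_impl` should be accepted by `PtEngine` directly and, since it is an `InferArguments` field, presumably also as a `--cache_impl` CLI flag. The model id below is a placeholder, not taken from this diff.

    from swift.llm import PtEngine, RequestConfig, InferRequest

    # cache_impl is stored on the engine and, per the hunks above, forwarded
    # to generate() as `cache_implementation` whenever it is not None.
    engine = PtEngine('Qwen/Qwen2.5-0.5B-Instruct', max_batch_size=2, cache_impl='static')
    req = InferRequest(messages=[{'role': 'user', 'content': 'Hello!'}])
    resp = engine.infer([req], RequestConfig(max_tokens=32))
    print(resp[0].choices[0].message.content)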