1 change: 1 addition & 0 deletions swift/llm/argument/infer_args.py
@@ -156,6 +156,7 @@ class InferArguments(MergeArguments, LmdeployArguments, SglangArguments, VllmArg
metric: Literal['acc', 'rouge'] = None
# for pt engine
max_batch_size: int = 1
cache_impl: str = None
Contributor review comment (severity: medium):

For better type clarity and consistency, the type hint for cache_impl should be Optional[str] since its default value is None. This aligns with its usage and its definition in other parts of the codebase, such as in PtEngine.

Suggested change
cache_impl: str = None
cache_impl: Optional[str] = None


# only for inference
val_dataset_sample: Optional[int] = None
14 changes: 9 additions & 5 deletions swift/llm/infer/infer.py
@@ -32,8 +32,12 @@ def __init__(self, args: Optional[Union[List[str], InferArguments]] = None) -> N

if args.infer_backend == 'pt':
model, self.template = prepare_model_template(args)
self.infer_engine = PtEngine.from_model_template(model, self.template, max_batch_size=args.max_batch_size)
self.infer_engine.reranker_use_activation = args.reranker_use_activation
self.infer_engine = PtEngine.from_model_template(
model,
self.template,
max_batch_size=args.max_batch_size,
reranker_use_activation=args.reranker_use_activation,
cache_impl=args.cache_impl)
logger.info(f'model: {self.infer_engine.model}')
else:
self.template = args.get_template(None)
@@ -64,8 +68,8 @@ def get_infer_engine(args: InferArguments, template=None, **kwargs):
from .infer_engine import PtEngine
infer_engine_cls = PtEngine
kwargs.update(args.get_model_kwargs())
if hasattr(args, 'max_batch_size'):
kwargs.update({'max_batch_size': args.max_batch_size})
kwargs['max_batch_size'] = args.max_batch_size
kwargs['cache_impl'] = args.cache_impl
elif infer_backend == 'vllm':
from .infer_engine import VllmEngine
infer_engine_cls = VllmEngine
@@ -180,7 +184,7 @@ def _prepare_val_dataset(self) -> HfDataset:
args = self.args
dataset_kwargs = args.get_dataset_kwargs()
if args.cached_dataset or args.cached_val_dataset:
_, val_datasets = get_cached_dataset(self.args)
_, val_datasets = get_cached_dataset(args)
else:
val_datasets = []
if len(args.val_dataset) > 0:
18 changes: 17 additions & 1 deletion swift/llm/infer/infer_engine/pt_engine.py
@@ -62,6 +62,7 @@ def __init__(
model_kwargs: Optional[Dict[str, Any]] = None,
template: Optional[Template] = None,
reranker_use_activation: bool = True,
cache_impl: Optional[str] = None,
**kwargs):
download_model = kwargs.pop('download_model', True)
self.model, self.processor = get_model_tokenizer(
@@ -80,6 +81,7 @@ def __init__(
model_kwargs=model_kwargs,
**kwargs)
self.reranker_use_activation = reranker_use_activation
self.cache_impl = cache_impl
self.max_batch_size = max_batch_size
if isinstance(adapters, str):
adapters = [adapters]
@@ -151,11 +153,21 @@ def _add_adapter(self, adapter_path: str, adapter_name: Optional[str] = None) ->
self.model = Swift.from_pretrained(self.model, adapter_path, adapter_name)

@classmethod
def from_model_template(cls, model, template=None, *, max_batch_size: int = 1):
def from_model_template(
cls,
model,
template=None,
*,
max_batch_size: int = 1,
reranker_use_activation: bool = True,
cache_impl: Optional[str] = None,
):
self = super().__new__(cls)
self.model = model
self.processor = template.processor
self.max_batch_size = max_batch_size
self.reranker_use_activation = reranker_use_activation
self.cache_impl = cache_impl
self._post_init(template)
return self

@@ -233,6 +245,8 @@ def _model_generate(**kwargs):
template.generate(self.model, **kwargs)

generate_kwargs = template.prepare_generate_kwargs(generate_kwargs, model=self.model)
if self.cache_impl is not None:
generate_kwargs['cache_implementation'] = self.cache_impl
Contributor review comment on lines 247 to +249 (severity: medium):

This block of code, which prepares generate_kwargs by adding the cache_implementation, is duplicated in the _infer_full method at lines 408-410. To improve maintainability and adhere to the DRY (Don't Repeat Yourself) principle, I recommend extracting this logic into a private helper method within the PtEngine class. This would centralize the logic for preparing these arguments.
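
A minimal sketch of such a helper, assuming a hypothetical method name _prepare_generate_kwargs on PtEngine (both call sites would then delegate to it):

    def _prepare_generate_kwargs(self, template: Template, generate_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        # Shared preparation used by both the streaming and the full inference paths.
        generate_kwargs = template.prepare_generate_kwargs(generate_kwargs, model=self.model)
        # Forward the user-selected cache implementation to generate() only when it was set.
        if self.cache_impl is not None:
            generate_kwargs['cache_implementation'] = self.cache_impl
        return generate_kwargs

Each duplicated block would then reduce to generate_kwargs = self._prepare_generate_kwargs(template, generate_kwargs).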

thread = Thread(target=_model_generate, kwargs=generate_kwargs)
thread.start()
batch_size = inputs['attention_mask'].shape[0]
@@ -392,6 +406,8 @@ def _infer_full(self, template: Template, inputs: Dict[str, Any], *, generation_
generate_kwargs['adapter_names'] = adapter_names
num_prompt_tokens = self._get_num_tokens(inputs)
generate_kwargs = template.prepare_generate_kwargs(generate_kwargs, model=self.model)
if self.cache_impl is not None:
generate_kwargs['cache_implementation'] = self.cache_impl
output = dict(template.generate(self.model, **generate_kwargs))
output.pop('past_key_values', None)
batched_generate_ids = output['sequences']
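
For reference, a rough sketch of how the new option flows through the API after this change; model and template construction is elided, and 'static' is assumed to be one of the cache implementations accepted by transformers' generate():

    # model and template are assumed to have been built elsewhere,
    # e.g. via prepare_model_template(args) as in infer.py above.
    engine = PtEngine.from_model_template(
        model,
        template,
        max_batch_size=8,
        reranker_use_activation=True,
        cache_impl='static',
    )
    # During inference, PtEngine now forwards this as
    # generate_kwargs['cache_implementation'] = 'static' before calling template.generate(...).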