
Commit d46f2f7

vllm fix sampling params (#625)
* fix
* Update src/lighteval/models/vllm/vllm_model.py
* Update src/lighteval/models/vllm/vllm_model.py
* fix
1 parent 0ae8136 commit d46f2f7

File tree

1 file changed (+4, -9 lines)


src/lighteval/models/vllm/vllm_model.py

Lines changed: 4 additions & 9 deletions
@@ -129,7 +129,6 @@ def __init__(
         self.precision = _get_dtype(config.dtype, config=self._config)
 
         self.model_info = ModelInfo(model_name=self.model_name, model_sha=self.model_sha)
-        self.sampling_params = SamplingParams(**config.generation_parameters.to_vllm_dict())
         self.pairwise_tokenization = config.pairwise_tokenization
 
     @property
@@ -139,8 +138,7 @@ def tokenizer(self):
     def cleanup(self):
         destroy_model_parallel()
         if self.model is not None:
-            del self.model.llm_engine.model_executor.driver_worker
-            self.model = None
+            del self.model
         gc.collect()
         ray.shutdown()
         destroy_distributed_environment()
@@ -247,11 +245,7 @@ def greedy_until(
         # the case! Because of that we only use batch size of 1
         stop_tokens = dataset[0].stop_sequence
 
-        max_new_tokens = (
-            dataset[0].generation_size
-            if self.sampling_params.max_tokens is None
-            else self.sampling_params.max_tokens
-        )
+        max_new_tokens = self._config.generation_parameters.max_new_tokens or dataset[0].generation_size
         returns_logits = dataset[0].use_logits
         num_samples = dataset[0].num_samples
 
@@ -322,7 +316,8 @@ def _generate(
         generate: bool = True,
     ) -> list[GenerativeResponse]:
         """Contains the actual logic of the generation."""
-        sampling_params = self.sampling_params.clone() or SamplingParams()
+        sampling_params = SamplingParams(**self._config.generation_parameters.to_vllm_dict())
+
         if generate:
            sampling_params.n = num_samples
            sampling_params.max_tokens = max_new_tokens
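
In short, the commit stops pre-building a shared SamplingParams in __init__ and instead constructs a fresh one from the configured generation parameters inside _generate, so per-request overrides such as n and max_tokens no longer mutate shared state between calls; max_new_tokens now prefers the configured generation_parameters.max_new_tokens and falls back to the task's generation_size. Below is a minimal sketch of the resulting pattern, assuming only vLLM's SamplingParams constructor and the to_vllm_dict()/max_new_tokens fields that appear in this diff; build_sampling_params is a hypothetical helper, not part of the lighteval codebase.

# Hypothetical helper illustrating the per-call construction pattern from this
# commit; a sketch, not the actual lighteval implementation.
from vllm import SamplingParams


def build_sampling_params(generation_parameters, generation_size: int,
                          num_samples: int, generate: bool = True) -> SamplingParams:
    # Prefer the explicitly configured cap, otherwise fall back to the task's
    # generation_size (mirrors the new max_new_tokens line in greedy_until).
    max_new_tokens = generation_parameters.max_new_tokens or generation_size

    # Build a fresh SamplingParams per call instead of cloning a shared one,
    # then apply the per-request overrides used in _generate.
    sampling_params = SamplingParams(**generation_parameters.to_vllm_dict())
    if generate:
        sampling_params.n = num_samples
        sampling_params.max_tokens = max_new_tokens
    return sampling_params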
