@@ -129,7 +129,6 @@ def __init__(
         self.precision = _get_dtype(config.dtype, config=self._config)

         self.model_info = ModelInfo(model_name=self.model_name, model_sha=self.model_sha)
-        self.sampling_params = SamplingParams(**config.generation_parameters.to_vllm_dict())
         self.pairwise_tokenization = config.pairwise_tokenization

     @property
@@ -139,8 +138,7 @@ def tokenizer(self):
     def cleanup(self):
         destroy_model_parallel()
         if self.model is not None:
-            del self.model.llm_engine.model_executor.driver_worker
-            self.model = None
+            del self.model
         gc.collect()
         ray.shutdown()
         destroy_distributed_environment()
@@ -247,11 +245,7 @@ def greedy_until(
         # the case! Because of that we only use batch size of 1
         stop_tokens = dataset[0].stop_sequence

-        max_new_tokens = (
-            dataset[0].generation_size
-            if self.sampling_params.max_tokens is None
-            else self.sampling_params.max_tokens
-        )
+        max_new_tokens = self._config.generation_parameters.max_new_tokens or dataset[0].generation_size
         returns_logits = dataset[0].use_logits
         num_samples = dataset[0].num_samples
@@ -322,7 +316,8 @@ def _generate(
         generate: bool = True,
     ) -> list[GenerativeResponse]:
         """Contains the actual logic of the generation."""
-        sampling_params = self.sampling_params.clone() or SamplingParams()
+        sampling_params = SamplingParams(**self._config.generation_parameters.to_vllm_dict())
+
         if generate:
             sampling_params.n = num_samples
             sampling_params.max_tokens = max_new_tokens
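
For context, a minimal sketch of the pattern these hunks converge on: keep only the generation parameters on the config and build a fresh `SamplingParams` per generation call, so per-request mutations (`n`, `max_tokens`) never leak between calls. Only the names `GenerationParameters`, `to_vllm_dict`, `max_new_tokens`, and the `SamplingParams(**...)` call mirror the diff; the class body below is an illustrative assumption, not the actual config class.

```python
# Minimal sketch, assuming a dataclass-style config; the fields chosen here
# are hypothetical, only the names mirrored from the diff are real.
from dataclasses import dataclass
from typing import Optional

from vllm import SamplingParams


@dataclass
class GenerationParameters:
    temperature: float = 1.0
    top_p: float = 1.0
    max_new_tokens: Optional[int] = None

    def to_vllm_dict(self) -> dict:
        # Rename max_new_tokens to vLLM's max_tokens; drop unset values so
        # SamplingParams keeps its own defaults.
        out = {"temperature": self.temperature, "top_p": self.top_p}
        if self.max_new_tokens is not None:
            out["max_tokens"] = self.max_new_tokens
        return out


params = GenerationParameters(temperature=0.0, max_new_tokens=256)

# Build a fresh, mutable SamplingParams for each call: later per-request
# tweaks such as `sampling_params.n = num_samples` cannot bleed into other
# requests, which is why the SamplingParams cached in __init__ can go away.
sampling_params = SamplingParams(**params.to_vllm_dict())
```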