
Commit d586c39

Mbatch p3l (ROCm#401)
* Enabling P3L.py & P3L_mling.py tests to run with multiple batched queries.

  This alteration adds minimal measurement noise. The underlying testing material is the same, and the resulting measurements are comparable to the old (BS=1) testing runs.

  Signed-off-by: Alexei V. Ivanov <[email protected]>

* Making linters happy.

  Signed-off-by: Alexei V. Ivanov <[email protected]>

* Changed the device specification for the 'forced_sample' tensor.

  The resulting implementation produces identical measurements and is actually faster (3.21 s/it vs 3.42 s/it with the previous commit).

  Signed-off-by: Alexei V. Ivanov <[email protected]>

* Fixing reporting to reflect processed intervals.

  Signed-off-by: Alexei V. Ivanov <[email protected]>

---------

Signed-off-by: Alexei V. Ivanov <[email protected]>
1 parent b43c8d1 commit d586c39
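The comparability claim above (batched runs vs. the old BS=1 runs) follows from how the benchmark aggregates results: cumulative log-probabilities and generated-token counts are summed across patches, and those sums do not depend on how the patches are grouped into batches. Below is a minimal sketch of that aggregation; the numbers are made up and stand in for what vllm_predict would return per patch.

```python
import math

# Hypothetical per-patch results: (cumulative_logprob, generated_token_count).
patches = [(-1200.5, 512), (-1175.0, 512), (-1230.8, 512), (-1190.2, 512)]

def ppl(nll, tokens):
    """Perplexity from accumulated negative log-likelihood and token count."""
    return math.exp(nll / tokens)

# BS=1: accumulate patch by patch, mirroring `my_ppl -= cumulative_logprob`.
nll_bs1 = -sum(lp for lp, _ in patches)
tok_bs1 = sum(n for _, n in patches)

# BS=2: the same patches grouped into batches; the partial sums add up to the
# same totals, so the final perplexity is unchanged.
batches = [patches[0:2], patches[2:4]]
nll_bs2 = sum(-sum(lp for lp, _ in batch) for batch in batches)
tok_bs2 = sum(sum(n for _, n in batch) for batch in batches)

assert math.isclose(ppl(nll_bs1, tok_bs1), ppl(nll_bs2, tok_bs2))
print(ppl(nll_bs1, tok_bs1))
```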

File tree

4 files changed: +97 −51 lines changed

benchmarks/P3L.py
benchmarks/P3L_mling.py
vllm/model_executor/layers/sampler.py
vllm/sampling_params.py


benchmarks/P3L.py

Lines changed: 40 additions & 21 deletions
@@ -40,6 +40,9 @@
 )
 should result in PPL ~ PPL=3.8968611189957523
 
+Running the script with multiple batches is possible
+by specifying the --batch-size parameter.
+
 """
 
 import argparse
@@ -140,36 +143,55 @@ def main(args: argparse.Namespace):
 
     logger.info(MESSAGE)
     print(MESSAGE)
-    for c in range(my_n_patches):
+
+    my_batchsize = args.batch_size
+
+    for c in range(0, my_n_patches, my_batchsize):
+
         CONTEXT = []
         my_sampl_par.future_context = []
-        CONTEXT.append(
-            my_test_enc['input_ids'][c * my_n_samples:c * my_n_samples +
-                                     args.context_size])
-        upper_boundary = min((c + 1) * my_n_samples + args.context_size,
-                             len(my_test_enc['input_ids']))
-        my_sampl_par.future_context.append(
-            my_test_enc['input_ids'][c * my_n_samples +
-                                     args.context_size:upper_boundary])
-        my_sampl_par.max_tokens = len(my_sampl_par.future_context[0])
-        my_sampl_par.cntr = c
+        my_sampl_par.cntr = []
+
+        for b in range(my_batchsize):
+            if (c + b) < my_n_patches:
+                upper_boundary = min(
+                    (c + b + 1) * my_n_samples + args.context_size,
+                    len(my_test_enc['input_ids']))
+                CONTEXT.append(
+                    my_test_enc['input_ids'][(c + b) * my_n_samples:(c + b) *
+                                             my_n_samples + args.context_size])
+
+                my_sampl_par.future_context.append(
+                    my_test_enc['input_ids'][(c + b) * my_n_samples +
+                                             args.context_size:upper_boundary])
+
+                my_sampl_par.cntr.append(c + b)
+
+        my_sampl_par.max_tokens = max(
+            len(my_sampl_par.future_context[b]) for b in range(len(CONTEXT)))
+
         LOGPROBS = vllm_predict(CONTEXT, my_llm, my_sampl_par)
-        num_tokens_generated += len(LOGPROBS[0].outputs[0].token_ids)
-        if (num_tokens_generated < my_n_samples):
+        for b in range(len(CONTEXT)):
+            num_tokens_generated += len(LOGPROBS[b].outputs[0].token_ids)
+            my_ppl -= LOGPROBS[b].outputs[0].cumulative_logprob
+
+        if (num_tokens_generated < my_n_samples * len(CONTEXT)):
             MESSAGE = (f"Warning: The number of generated tokens is" \
-                       f"less than requested ({num_tokens_generated}" \
-                       f" < {my_n_samples}).")
+                       f"less than requested ({num_tokens_generated}" \
+                       f" < {my_n_samples*len(CONTEXT)}).")
             logger.info(MESSAGE)
             print(MESSAGE)
-        my_ppl -= LOGPROBS[0].outputs[0].cumulative_logprob
-        MESSAGE = (f"Iteration {c+1} of {my_n_patches} Intermediate" \
+
+        MESSAGE = (f"Iterations {c+1} through {c+len(CONTEXT)}" \
+                   " of {my_n_patches} Intermediate" \
                    "Estimates:\n" \
                    f"\tCross-entropy_intermediate={my_ppl/num_tokens_generated}\n" \
                    f"\tPerplexity_intermediate=" \
                    f"{math.exp(my_ppl/num_tokens_generated)}")
 
         logger.info(MESSAGE)
         print(MESSAGE)
+
     ending_time = datetime.datetime.now()
     MESSAGE = (f"Done @ {ending_time} after processing for" \
                f" {ending_time-starting_time}" \
@@ -199,12 +221,9 @@ def main(args: argparse.Namespace):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description='Measure the PPPL (P3L) score of a given model.')
-    parser.add_argument(
-        '--data',
-        type=str,
-        default='./wikitext/wikitext-2-v1/test-00000-of-00001.parquet')
     parser.add_argument('--context-size', type=int, default=4096)
     parser.add_argument('--sample-size', type=int, default=512)
+    parser.add_argument('--batch-size', type=int, default=1)
     parser.add_argument('--patch-size', type=int, default=None)
     parser.add_argument(
         '--output-json',
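A quick sanity check of the new flag wiring, mirroring only the argparse lines shown above (the real script adds further engine arguments on top of these):

```python
import argparse

parser = argparse.ArgumentParser(
    description='Measure the PPPL (P3L) score of a given model.')
parser.add_argument('--context-size', type=int, default=4096)
parser.add_argument('--sample-size', type=int, default=512)
parser.add_argument('--batch-size', type=int, default=1)
parser.add_argument('--patch-size', type=int, default=None)

# Omitting --batch-size keeps the old single-query behaviour (default=1).
args = parser.parse_args(['--batch-size', '8'])
print(args.batch_size, args.context_size)  # 8 4096
```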

benchmarks/P3L_mling.py

Lines changed: 39 additions & 17 deletions
@@ -52,6 +52,8 @@
 
 for the complete set of possible language-scripture choices.
 
+Running the script with multiple batches is possible
+by specifying the --batch-size parameter.
 
 """
 
@@ -172,36 +174,55 @@ def main(args: argparse.Namespace):
 
     logger.info(MESSAGE)
     print(MESSAGE)
-    for c in range(my_n_patches):
+
+    my_batchsize = args.batch_size
+
+    for c in range(0, my_n_patches, my_batchsize):
+
         CONTEXT = []
         my_sampl_par.future_context = []
-        CONTEXT.append(
-            my_test_enc['input_ids'][c * my_n_samples:c * my_n_samples +
-                                     args.context_size])
-        upper_boundary = min((c + 1) * my_n_samples + args.context_size,
-                             len(my_test_enc['input_ids']))
-        my_sampl_par.future_context.append(
-            my_test_enc['input_ids'][c * my_n_samples +
-                                     args.context_size:upper_boundary])
-        my_sampl_par.max_tokens = len(my_sampl_par.future_context[0])
-        my_sampl_par.cntr = c
+        my_sampl_par.cntr = []
+
+        for b in range(my_batchsize):
+            if (c + b) < my_n_patches:
+                upper_boundary = min(
+                    (c + b + 1) * my_n_samples + args.context_size,
+                    len(my_test_enc['input_ids']))
+                CONTEXT.append(
+                    my_test_enc['input_ids'][(c + b) * my_n_samples:(c + b) *
+                                             my_n_samples + args.context_size])
+
+                my_sampl_par.future_context.append(
+                    my_test_enc['input_ids'][(c + b) * my_n_samples +
+                                             args.context_size:upper_boundary])
+
+                my_sampl_par.cntr.append(c + b)
+
+        my_sampl_par.max_tokens = max(
+            len(my_sampl_par.future_context[b]) for b in range(len(CONTEXT)))
+
         LOGPROBS = vllm_predict(CONTEXT, my_llm, my_sampl_par)
-        num_tokens_generated += len(LOGPROBS[0].outputs[0].token_ids)
-        if (num_tokens_generated < my_n_samples):
+        for b in range(len(CONTEXT)):
+            num_tokens_generated += len(LOGPROBS[b].outputs[0].token_ids)
+            my_ppl -= LOGPROBS[b].outputs[0].cumulative_logprob
+
+        if (num_tokens_generated < my_n_samples * len(CONTEXT)):
             MESSAGE = (f"Warning: The number of generated tokens is" \
-                       f"less than requested ({num_tokens_generated}" \
-                       f" < {my_n_samples}).")
+                       f"less than requested ({num_tokens_generated}" \
+                       f" < {my_n_samples*len(CONTEXT)}).")
             logger.info(MESSAGE)
             print(MESSAGE)
-        my_ppl -= LOGPROBS[0].outputs[0].cumulative_logprob
-        MESSAGE = (f"Iteration {c+1} of {my_n_patches} Intermediate" \
+
+        MESSAGE = (f"Iterations {c+1} through {c+len(CONTEXT)}" \
+                   " of {my_n_patches} Intermediate" \
                    "Estimates:\n" \
                    f"\tCross-entropy_intermediate={my_ppl/num_tokens_generated}\n" \
                    f"\tPerplexity_intermediate=" \
                    f"{math.exp(my_ppl/num_tokens_generated)}")
 
         logger.info(MESSAGE)
         print(MESSAGE)
+
     ending_time = datetime.datetime.now()
     MESSAGE = (f"Done @ {ending_time} after processing for" \
               f" {ending_time-starting_time}" \
@@ -237,6 +258,7 @@ def main(args: argparse.Namespace):
         default='./wikitext/wikitext-2-v1/test-00000-of-00001.parquet')
     parser.add_argument('--context-size', type=int, default=4096)
     parser.add_argument('--sample-size', type=int, default=512)
+    parser.add_argument('--batch-size', type=int, default=1)
     parser.add_argument('--patch-size', type=int, default=None)
     parser.add_argument('--lang-script', type=str, default="eng_Latn")
     parser.add_argument(

vllm/model_executor/layers/sampler.py

Lines changed: 17 additions & 12 deletions
@@ -799,7 +799,6 @@ def _sample_with_torch(
         if sampling_type == SamplingType.GREEDY:
             greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                           dim=-1)
-
             if sampled_token_ids_tensor is not None:
                 # Store sampled tokens in output tensor.
                 sampled_token_ids_tensor[
@@ -842,17 +841,23 @@ def _sample_with_torch(
             sampled_token_ids_tensor[long_sample_indices] = \
                 multinomial_samples[sampling_type].to(torch.long)
         elif sampling_type == SamplingType.FORCED:
-            if (seq_groups[0].sampling_params.future_context is not None):
-                forced_samples = torch.tensor([
-                    seq_groups[0].sampling_params.future_context[0][min(
-                        len(sampling_metadata.seq_groups[0].seq_data[
-                            sampling_params.cntr].output_token_ids),
-                        len(seq_groups[0].sampling_params.future_context[0]) -
-                        1)]
-                ])
-            else:
-                forced_samples = torch.argmax(logprobs[long_sample_indices],
-                                              dim=-1)
+            forced_samples = torch.tensor([], dtype=torch.int32)
+            for sgidx in range(len(seq_groups)):
+                if (seq_groups[sgidx].sampling_params.future_context
+                        is not None):
+                    forced_sample = torch.tensor([
+                        seq_groups[sgidx].sampling_params.future_context[sgidx]
+                        [min(
+                            len(sampling_metadata.seq_groups[sgidx].seq_data[
+                                sampling_params.cntr[sgidx]].output_token_ids),
+                            len(seq_groups[sgidx].sampling_params.
+                                future_context[sgidx]) - 1)]
+                    ])
+                else:
+                    forced_sample = torch.argmax(logprobs[long_sample_indices],
+                                                 dim=-1)
+                forced_samples = torch.cat([forced_samples, forced_sample])
+
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
         else:
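The rewritten FORCED branch builds one forced token per sequence group and concatenates them, instead of handling only seq_groups[0]. Below is a toy reproduction of that control flow, not the vLLM code itself; future_contexts and num_output_tokens are hypothetical stand-ins for the per-group sampling params and sequence data.

```python
import torch

future_contexts = [[11, 12, 13], [21, 22], None]   # None -> fall back to argmax
num_output_tokens = [1, 5, 0]                       # tokens generated so far
logprobs = torch.randn(3, 8)                        # fake [groups, vocab] scores

forced_samples = torch.tensor([], dtype=torch.int32)
for idx, ctx in enumerate(future_contexts):
    if ctx is not None:
        # Pick the next ground-truth token, clamping at the end of the context.
        pos = min(num_output_tokens[idx], len(ctx) - 1)
        forced_sample = torch.tensor([ctx[pos]], dtype=torch.int32)
    else:
        # No future context for this group: take the model's argmax instead.
        forced_sample = torch.argmax(logprobs[idx:idx + 1], dim=-1).to(torch.int32)
    forced_samples = torch.cat([forced_samples, forced_sample])

print(forced_samples)  # one forced token id per sequence group
```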

vllm/sampling_params.py

Lines changed: 1 addition & 1 deletion
@@ -183,7 +183,7 @@ class SamplingParams(
     min_p: float = 0.0
     ppl_measurement: bool = False
     future_context: Optional[List[int]] = None
-    cntr: Optional[int] = None
+    cntr: Optional[List[int]] = None
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = None
     stop_token_ids: Optional[List[int]] = None
