Commit 4b7475b
Match server scenario to standalone implementation (#2086)
1 parent c2a0117 · commit 4b7475b

File tree: 4 files changed, +28 −27 lines

.github/workflows/build_wheels.yml

Lines changed: 1 addition & 0 deletions

```diff
@@ -3,6 +3,7 @@ name: Build loadgen wheels and release them into PYPI
 on:
   release:
     types: [published]
+
   push:
     branches:
       - master
```

compliance/nvidia/TEST01/verify_performance.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -90,7 +90,8 @@ def main():
            test_mode = line.split(": ", 1)[1].strip()
            continue
        if test_mode == "SingleStream":
-            if re.match(".*Early stopping (90th|99.9th) percentile estimate", line):
+            if re.match(
+                    ".*Early stopping (90th|99.9th) percentile estimate", line):
                test_score = line.split(": ", 1)[1].strip()
                continue
 
```
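This change only reflows a long `if re.match(...)` onto two lines; the pattern and the score extraction are unchanged. For reference, a minimal sketch of what the check does to a loadgen summary line (the sample line and value below are illustrative, not taken from a real log):

```python
import re

# Illustrative summary line in the style verify_performance.py parses;
# the numeric value is made up.
line = "Early stopping 90th percentile estimate: 1043853"

# Same pattern and extraction as the script: match the early-stopping
# estimate line, then take everything after the first ": " as the score.
if re.match(".*Early stopping (90th|99.9th) percentile estimate", line):
    test_score = line.split(": ", 1)[1].strip()
    print(test_score)  # prints "1043853"
```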

language/mixtral-8x7b/README.md

Lines changed: 2 additions & 0 deletions

````diff
@@ -250,6 +250,8 @@ python -u evaluate-accuracy.py --checkpoint-path [path_to_model_checkpoint] \
 
 ## Accuracy Target
 
+**WARNING:** The full accuracy target was only verified with the standalone script. The reference implementation matches on a subset of the dataset, but hasn't been fully confirmed.
+
 Reference scores:
 Open Orca:
 ```json
````

language/mixtral-8x7b/SUT.py

Lines changed: 23 additions & 26 deletions

```diff
@@ -119,7 +119,7 @@ def put(self, value):
            self.first_token.put((value, self.response_ids[0]))
 
            self.is_first_token = False
-            return
+
 
        self.tokens_cache.append(value)
 
```
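This first hunk drops the early `return` in the token streamer's `put`, so the first generated token is now appended to `tokens_cache` in addition to being reported for time-to-first-token, matching the standalone implementation's output. A minimal sketch of the resulting behavior (class shape assumed from the diff, not the full SUT.py definition):

```python
import queue

class FirstTokenStreamerSketch:
    # Sketch of the streamer's put() after this commit; attribute names
    # are taken from the diff, everything else is assumed.
    def __init__(self, first_token: queue.Queue, response_ids: list):
        self.first_token = first_token
        self.response_ids = response_ids
        self.is_first_token = True
        self.tokens_cache = []

    def put(self, value):
        if self.is_first_token:
            # Report the first token for time-to-first-token latency.
            self.first_token.put((value, self.response_ids[0]))
            self.is_first_token = False
            # No early return anymore: fall through so the first token
            # is also part of the cached output sequence.
        self.tokens_cache.append(value)
```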

```diff
@@ -356,6 +356,7 @@ def __init__(
        total_sample_count=24576,
        dataset_path=None,
        workers=1,
+        **kwargs,
    ):
 
        super().__init__(
```
```diff
@@ -408,9 +409,13 @@ def process_queries(self):
            if qitem is None:
                break
 
-            input_ids_tensor = self.data_object.input_ids[qitem.index]
-            input_masks_tensor = self.data_object.attention_masks[qitem.index]
-            dataset = self.data_object.dataset_names[qitem.index]
+            input_dataset = [self.data_object.dataset_names[qitem.index]]
+
+            batch_texts = [self.data_object.input_texts[qitem.index]]
+            batch_ids = self.tokenizer.batch_encode_plus(
+                batch_texts, return_tensors="pt", padding=True)
+            batch_ids = batch_ids.to(self.device)
+            _, length = batch_ids.input_ids.shape
 
            # TODO: This PoC is super slow with significant overhead. Best to
            # create a patch to `generate`
```
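With this hunk the server path tokenizes the raw input text per query instead of reusing pre-tokenized tensors, mirroring the standalone script. A minimal sketch of the Hugging Face calls involved (the checkpoint name and prompt are illustrative; the SUT uses whichever tokenizer it was constructed with):

```python
import torch
from transformers import AutoTokenizer

# Illustrative checkpoint, not necessarily the one the SUT loads.
tokenizer = AutoTokenizer.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1", padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # required before padding

batch_texts = ["What is the capital of France?"]  # hypothetical prompt
batch_ids = tokenizer.batch_encode_plus(
    batch_texts, return_tensors="pt", padding=True)
batch_ids = batch_ids.to("cuda" if torch.cuda.is_available() else "cpu")

# input_ids has shape (batch, sequence_length); the SUT keeps the
# length so the prompt can be separated from generated tokens later.
_, length = batch_ids.input_ids.shape
```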
```diff
@@ -422,32 +427,24 @@ def process_queries(self):
                response_ids=[qitem.id],
            )
 
-            logits_processor = LogitsProcessorList(
-                [StopAfterSequence(
-                    self.tokenizer.eos_token_id, device=self.device)]
+
+            _ = self.model.generate(
+                **batch_ids,
+                num_return_sequences=1,
+                streamer=tokens_streamer,
+                **gen_kwargs,
            )
-            if dataset == "MBXP":
-                _ = self.model.generate(
-                    input_ids=input_ids_tensor,
-                    attention_mask=input_masks_tensor,
-                    pad_token_id=self.tokenizer.pad_token_id,
-                    streamer=tokens_streamer,
-                    logits_processor=logits_processor,
-                    **gen_kwargs,
-                )
-            else:
-                _ = self.model.generate(
-                    input_ids=input_ids_tensor,
-                    attention_mask=input_masks_tensor,
-                    pad_token_id=self.tokenizer.pad_token_id,
-                    streamer=tokens_streamer,
-                    **gen_kwargs,
-                )
 
            output_tokens = tokens_streamer.get_out_tokens()
-            n_tokens = len(output_tokens)
+            processed_output = self.data_object.postProcess(
+                torch.tensor([output_tokens], dtype=torch.int64),
+                length=0,
+                query_id_list=[qitem.index],
+                dataset_list=input_dataset,
+            )
+            n_tokens = len(processed_output[0])
            response_array = array.array(
-                "B", np.array(output_tokens, np.int32).tobytes()
+                "B", np.array(processed_output[0], np.int32).tobytes()
            )
            bi = response_array.buffer_info()
            response = [
```
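The last hunk removes the MBXP-only `StopAfterSequence` logits processor and the branch around `generate`: both datasets now go through a single `generate` call, and dataset-specific stop handling happens afterwards in `Dataset.postProcess`, as in the standalone implementation. A minimal sketch of that kind of after-the-fact trimming (a hypothetical helper for illustration; the real logic lives in the benchmark's dataset class):

```python
import torch

def trim_at_stop_sequence(output_ids: torch.Tensor,
                          stop_ids: list) -> torch.Tensor:
    """Cut a 1-D tensor of generated token ids at the first occurrence
    of stop_ids (inclusive). Hypothetical illustration of
    post-generation stop handling, not the actual postProcess code."""
    ids = output_ids.tolist()
    n = len(stop_ids)
    for i in range(len(ids) - n + 1):
        if ids[i:i + n] == stop_ids:
            return output_ids[:i + n]
    return output_ids

# Usage with made-up token ids, where [99, 98] is the stop sequence:
tokens = torch.tensor([5, 7, 99, 98, 3, 4], dtype=torch.int64)
print(trim_at_stop_sequence(tokens, [99, 98]))  # tensor([ 5,  7, 99, 98])
```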
