
Commit a5b8fed

patrickhyw (Patrick Wang) and Ubuntu authored
Small QOL changes (#7)
* fixed transformers version
* installing small dataset directly
* now loading dataset properly
* fixed to make datagen work on gpu
* readme typo
* rem print in get datastore data
* added fschat back to req
* Revert "added fschat back to req"
  This reverts commit b8366db.
* add fs chat to req without removing data/
* main.rs
* modified some prints
* rm warnings
* wrote code to compute total three/four grams and write to a file
* moved stuff to exp
* removed extra stuff
* reverted get_datastore_chat.py
* gitignore
* lib.rs comments

---------

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Patrick Wang <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
1 parent 5b119c1 commit a5b8fed

File tree: 6 files changed (+22 lines, -9 lines)


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -1,3 +1,6 @@
+# Generated files
+*.idx
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]

DraftRetriever/src/lib.rs

Lines changed: 11 additions & 3 deletions
@@ -1,9 +1,7 @@
 // The code for retrival is adapted from https://github.com/Intsights/PySubstringSearch;
 // The code for drafft buffer is adapted from https://github.com/FasterDecoding/Medusa/blob/main/medusa/model/utils.py#L31-L124
 use ahash::AHashSet;
-use bstr::io::BufReadExt;
 use byteorder::{ReadBytesExt, WriteBytesExt, ByteOrder, LittleEndian};
-use memchr::memmem;
 use parking_lot::Mutex;
 use pyo3::exceptions;
 use pyo3::prelude::*;
@@ -21,7 +19,6 @@ use pyo3::types::PyList;
 use std::collections::BinaryHeap;
 use std::fs;
 use std::io::Cursor;
-use std::fs::OpenOptions;
 
 extern "C" {
     pub fn libsais_int(
@@ -217,6 +214,7 @@ impl Reader {
         long: Option<i32>,
     ) -> PyResult<(Vec<Vec<i32>>, Vec<Vec<i32>>, Vec<i32>, Vec<i32>, Vec<Vec<i32>>)> {
 
+        // substring_i32 is just a rust version of py_substring
        let mut substring_i32 = Vec::new();
         for item in py_substring.iter() {
             let num: i32 = item.extract()?;
@@ -225,19 +223,28 @@ impl Reader {
 
         let results = Arc::new(Mutex::new(Vec::new()));
 
+        // each sub index is a buffer/suffix pair
         self.sub_indexes.par_iter_mut().for_each(
             |sub_index| {
                 let mut start_of_indices = None;
                 let mut end_of_indices = None;
 
+                // since suffix arrays have the suffixes in sorted order, we do a binary search
+                // over the suffix array
+                // this binary search finds the start of the matching suffixes
                 let mut left_anchor = sub_index.suffixes_file_start;
                 let mut right_anchor = sub_index.suffixes_file_end - 4;
                 while left_anchor <= right_anchor {
                     let middle_anchor = left_anchor + ((right_anchor - left_anchor) / 4 / 2 * 4);
                     sub_index.index_file.seek(SeekFrom::Start(middle_anchor as u64)).unwrap();
+                    // data_index is the value at middle_anchor in the suffix array
                     let data_index = sub_index.index_file.read_i32::<LittleEndian>().unwrap();
+                    // line is the actual suffix
                     let line = &sub_index.data[(data_index) as usize..];
 
+                    // we don't use the entire suffix. we look for suffixes that start with the substring we're looking for
+                    // the suffix array sorts suffixes based on the start of the suffix, so this technique is sound
+                    // the "match length" is defined by the length of substring_i32. the suffix array doesn't need to worry about "match length"
                     if line.starts_with(&substring_i32) {
                         start_of_indices = Some(middle_anchor);
                         right_anchor = middle_anchor - 4;
@@ -253,6 +260,7 @@ impl Reader {
                     return;
                 }
 
+                // this binary search finds the end of the matching suffixes
                 let mut right_anchor = sub_index.suffixes_file_end - 4;
                 while left_anchor <= right_anchor {
                     let middle_anchor = left_anchor + ((right_anchor - left_anchor) / 4 / 2 * 4);
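The comments added above describe the core retrieval trick: because a suffix array lists suffixes in sorted order, the suffixes that begin with the query token sequence form one contiguous block, which two binary searches can bound (one for the first match, one for the last). The sketch below is not the crate's code; it is a minimal, self-contained Rust illustration of that technique over an in-memory token array, whereas the real `lib.rs` binary-searches over 4-byte anchors into an on-disk index file. The names `build_suffix_array` and `find_matching_range` are illustrative only.

```rust
/// Build a naive suffix array over `data`: suffix start offsets, sorted lexicographically.
fn build_suffix_array(data: &[i32]) -> Vec<usize> {
    let mut sa: Vec<usize> = (0..data.len()).collect();
    sa.sort_by(|&a, &b| data[a..].cmp(&data[b..]));
    sa
}

/// Return the half-open range [lo, hi) of positions in `sa` whose suffixes
/// start with `pattern`.
fn find_matching_range(data: &[i32], sa: &[usize], pattern: &[i32]) -> (usize, usize) {
    // First binary search: leftmost suffix >= pattern (start of the matching block).
    let lo = sa.partition_point(|&s| &data[s..] < pattern);
    // Second binary search: first suffix past lo that no longer starts with pattern.
    let hi = lo + sa[lo..].partition_point(|&s| data[s..].starts_with(pattern));
    (lo, hi)
}

fn main() {
    // Token ids standing in for the datastore contents.
    let data = vec![3, 1, 2, 3, 1, 2, 1, 2, 3];
    let sa = build_suffix_array(&data);
    let pattern = vec![1, 2, 3];
    let (lo, hi) = find_matching_range(&data, &sa, &pattern);
    // Each hit is the starting offset of a suffix that begins with the pattern.
    for &start in &sa[lo..hi] {
        println!("match at offset {start}: {:?}", &data[start..]);
    }
}
```

The on-disk version in the diff appears to do the same thing with explicit `while left_anchor <= right_anchor` loops, reading one `i32` suffix-array entry per probe instead of indexing a `Vec`.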

README.md

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@ python3 get_datastore_code.py --model-path codellama/CodeLlama-7b-instruct-hf --
 ### Inference on MT-Bench
 ```bash
 cd llm_judge
-RAYON_NUM_THREADS=6 CUDA_VISIBLE_DEVICES=0 python3 get_model_answer_rest.py --model-path lmsys/vicuna-7b-v1.5 --model-id vicuna-7b-v1.5 --datastore-path ../datastore/datastore_chat_small.idx
+RAYON_NUM_THREADS=6 CUDA_VISIBLE_DEVICES=0 python3 gen_model_answer_rest.py --model-path lmsys/vicuna-7b-v1.5 --model-id vicuna-7b-v1.5 --datastore-path ../datastore/datastore_chat_small.idx
 ```
 
 ### Inference on HumanEval

datastore/get_datastore_chat.py

Lines changed: 0 additions & 1 deletion
@@ -52,4 +52,3 @@
         writer.add_entry(token_list)
 
 writer.finalize()
-
llm_judge/gen_model_answer_rest.py

Lines changed: 1 addition & 1 deletion
@@ -269,7 +269,7 @@ def get_model_answers(
                     if conv.name == "xgen" and output.startswith("Assistant:"):
                         output = output.replace("Assistant:", "", 1).strip()
                 except RuntimeError as e:
-                    print("ERROR question ID: ", question["question_id"])
+                    print(f"question ID {question['question_id']} errored out with {e}")
                     output = "ERROR"
 
                 turns.append(output)

requirements.txt

Lines changed: 6 additions & 3 deletions
@@ -1,10 +1,13 @@
 torch
-"fschat[model_worker,webui]"
+fschat[model_worker,webui]
 maturin==0.12
 numpy==1.26.1
 tqdm==4.66.1
-transformers==4.34.1
+transformers
 accelerate==0.24.1
 datasets
 openai
-anthropic
+anthropic
+sentencepiece
+protobuf
+shortuuid
