
Commit ac75468

data generation for self-distillation
1 parent 0ac14da commit ac75468

File tree: 3 files changed, +253 -0 lines changed

data_generation/README.md

Lines changed: 27 additions & 0 deletions
# Generate chat data for self-distillation

We use vLLM to enable batched generation. First, install dependencies:

```bash
pip install vllm openai
```

## Start server

```bash
python -m vllm.entrypoints.openai.api_server \
    --model YOUR_MODEL_NAME --port 8000
```

You can also start multiple servers with different ports to enable parallel generation. In `generate.py`, we scan the ports from 8000 to 8009 to find available servers. You can modify the code to use other ports.
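For example, to run two servers in parallel on two GPUs (a minimal sketch; the `CUDA_VISIBLE_DEVICES` assignment is an assumption and depends on your hardware):

```bash
CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server \
    --model YOUR_MODEL_NAME --port 8000 &
CUDA_VISIBLE_DEVICES=1 python -m vllm.entrypoints.openai.api_server \
    --model YOUR_MODEL_NAME --port 8001 &
```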

## Generate data

The following command lets the model continue the first prompt from each sample in `DATA_PATH`. This is suitable for models that can play both roles in a conversation (e.g., Zephyr 7B). If you want to use all the prompts in each sample to repeatedly talk to the model, use `--chat` instead. `--chat` mode works for more models but may take longer to generate due to repeated computation (contributions of a better implementation are welcome).

```bash
python generate.py --data_path YOUR_DATA_PATH --output_path YOUR_OUTPUT_PATH --num_threads NUM_THREADS --max_tokens YOUR_MAX_TOKENS --temperature YOUR_TEMPERATURE
```
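
For reference, `generate.py` expects `DATA_PATH` to be a single JSON file containing a list of ShareGPT-style samples; a minimal illustrative example (the texts are made up) is shown below. Only the human turns (and an optional leading system message) are read from each sample; any existing gpt turns are ignored and regenerated by the served model.

```json
[
  {
    "conversations": [
      {"from": "human", "value": "What is self-distillation?"},
      {"from": "gpt", "value": "..."},
      {"from": "human", "value": "Can you give a concrete example?"},
      {"from": "gpt", "value": "..."}
    ]
  }
]
```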

## (Optional) Format data

When generated with `--chat`, the output file already follows the ShareGPT format ([example](https://github.com/lm-sys/FastChat/blob/main/data/dummy_conversation.json)).
You can use the following command to convert the text generated without `--chat` to the same format:

```bash
python convert_to_sharegpt.py --input_path YOUR_INPUT_PATH --model_name YOUR_MODEL_NAME --output_path YOUR_OUTPUT_PATH
```
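
Without `--chat`, each line of the raw output file is a single JSON object whose `text` field holds the full prompt plus the model's continuation, roughly like the illustrative line below (the exact role tags come from the model's conversation template; the Zephyr-style tags here are only an assumption):

```json
{"text": "<|user|>\nWhat is self-distillation?</s>\n<|assistant|>\nSelf-distillation trains a model on responses generated by itself..."}
```

`convert_to_sharegpt.py` splits this text back into alternating human/gpt messages using the same template's role tags and drops the final, possibly truncated turn.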
data_generation/convert_to_sharegpt.py

Lines changed: 70 additions & 0 deletions
import json
import os
import time
import concurrent.futures

import openai
import shortuuid
import tqdm

import argparse
import random

from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)

from fastchat.conversation import Conversation, SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template
from transformers import AutoTokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--input_path", type=str)
parser.add_argument("--model_name", type=str, default="HuggingFaceH4/zephyr-7b-beta")
# --output_path is optional; without it we keep the original behaviour of
# writing next to the input file (see the README command).
parser.add_argument("--output_path", type=str, default=None)
args = parser.parse_args()

# Conversation template of the generating model; its role tags and separator
# are used to split the raw text back into turns.
conv = get_conversation_template(args.model_name)
tokenizer = AutoTokenizer.from_pretrained(args.model_name)

# Each input line is a JSON object with a "text" field
# (the prompt plus generation produced by generate.py without --chat).
data = []
with open(args.input_path) as f:
    for line in f.readlines():
        data.append(json.loads(line))


def convert(text):
    # Split the raw generated text into ShareGPT-style human/gpt turns.
    messages = []

    for turn in text.split(conv.roles[0]):
        pairs = turn.split(conv.roles[1])
        if len(pairs) != 2:
            continue
        messages.append({
            "from": "human",
            "value": pairs[0].split(conv.sep)[0].strip()
        })
        messages.append({
            "from": "gpt",
            "value": pairs[1].split(conv.sep)[0].strip()
        })
    # pop the last message because it might be incomplete
    if len(messages) > 0:
        messages.pop()
    # make sure the number of messages is even
    if len(messages) % 2 == 1:
        messages.pop()
    return {"conversations": messages}


sharegpt_data = []
for d in tqdm.tqdm(data):
    sample = convert(d["text"])
    if len(sample["conversations"]) < 2:
        continue
    sharegpt_data.append(sample)

# dump to jsonl (default: alongside the input file with a "_sharegpt" suffix)
output_path = args.output_path or args.input_path.replace(".jsonl", "_sharegpt.jsonl")
with open(output_path, "w") as f:
    for d in sharegpt_data:
        f.write(json.dumps(d) + "\n")

data_generation/generate.py

Lines changed: 156 additions & 0 deletions
import json
import os
import time
import concurrent.futures

import openai
import shortuuid
import tqdm

import argparse
import random

from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)

from fastchat.conversation import Conversation, SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template

# Modify OpenAI's API key and API base to use vLLM's API server.
openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"

api_base_pool = []

# Probe ports 8000-8009 and register every vLLM server that responds to the
# "list models" API.
for i in range(10):
    openai.api_base = "http://localhost:800{}/v1".format(i)
    try:
        models = openai.Model.list()["data"][0]["id"]
        print(openai.api_base, models)
        api_base_pool.append(openai.api_base)
    except Exception:
        break

print("API base pool: ", api_base_pool)

parser = argparse.ArgumentParser()
parser.add_argument("--data_path", type=str)
parser.add_argument("--output_path", type=str)
parser.add_argument("--num_threads", type=int, default=256)
parser.add_argument("--temperature", type=float, default=0.3)
parser.add_argument("--max_tokens", type=int, default=2048)
parser.add_argument("--chat", action="store_true")
args = parser.parse_args()

# Assuming the ShareGPT format
data = json.load(open(args.data_path, "r"))


def generate_data(messages, idx):
    try:
        # Round-robin over the available servers (openai.api_base is
        # module-level, so this is best-effort load balancing).
        openai.api_base = api_base_pool[idx % len(api_base_pool)]
        model_name = openai.Model.list()["data"][0]["id"]

        if args.chat:
            converted_messages = []
            output_messages = []
            if messages[0]["from"] == "system":
                converted_messages.append(
                    {
                        "role": "system",
                        "content": messages[0]["text"],
                    }
                )
                output_messages.append(messages[0])
                messages = messages[1:]
            # Send every human turn in order and regenerate the assistant turns.
            for message in messages[::2]:
                if message["from"] != "human":
                    return
                converted_messages.append(
                    {
                        "role": "user",
                        "content": message["value"],
                    }
                )
                try:
                    response = openai.ChatCompletion.create(
                        model=model_name,
                        messages=converted_messages,
                        max_tokens=args.max_tokens,
                        temperature=args.temperature,
                    )
                    if response.choices[0]["finish_reason"] == "length":
                        break
                    response = response.choices[0]["message"]["content"].strip()
                    output_messages.append(message)
                    output_messages.append(
                        {
                            "from": "gpt",
                            "value": response,
                        }
                    )
                    converted_messages.append(
                        {
                            "role": "assistant",
                            "content": response,
                        }
                    )
                except Exception:
                    break
            if len(output_messages) == 0:
                return
            with open(args.output_path, "a") as f:
                # write in ShareGPT format
                f.write(json.dumps({"conversations": output_messages}) + "\n")
        else:
            # Only the first prompt is used; the model continues the conversation.
            conv = get_conversation_template(model_name)
            if messages[0]["from"] == "system":
                conv.system_message = messages[0]["text"]
                messages = messages[1:]
            conv.append_message(conv.roles[0], messages[0]["value"])
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            response = openai.Completion.create(
                model=model_name,
                prompt=prompt,
                max_tokens=args.max_tokens,
                temperature=args.temperature,
                ignore_eos=True,
                skip_special_tokens=False,
                spaces_between_special_tokens=False,
            )
            response = response.choices[0]["text"].strip()
            with open(args.output_path, "a") as f:
                # write the raw prompt + continuation; convert_to_sharegpt.py
                # can turn it into ShareGPT format later
                f.write(json.dumps({"text": prompt + response}) + "\n")
    except Exception as e:
        print(e)
        print("Failed to generate data")


# if output_path exists, count the number of lines and skip the first n data
start = 0
if os.path.exists(args.output_path):
    with open(args.output_path, "r") as f:
        start = len(f.readlines())
    print("Skip first {} data".format(start))

with concurrent.futures.ThreadPoolExecutor(max_workers=args.num_threads) as executor:
    futures = []
    for idx, sample in enumerate(data[start:]):
        future = executor.submit(
            generate_data,
            sample["conversations"],
            idx,
        )
        futures.append(future)

    for future in tqdm.tqdm(
        concurrent.futures.as_completed(futures), total=len(futures)
    ):
        future.result()

0 commit comments