Skip to content

Commit 89b5042

Browse files
authored
Merge pull request #2 from OpenLMLab/main
Release
2 parents 45e9d7b + b727c4a commit 89b5042

File tree

9 files changed

+75
-20
lines changed

9 files changed

+75
-20
lines changed

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ jobs:
5555
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
5656
with:
5757
tag_name: ui-v${{ steps.version.outputs.value }}
58-
release_name: Release refs/heads/ui
58+
release_name: Release refs/heads/ui-v${{ steps.version.outputs.value }}
5959
body: UI version ${{ steps.version.outputs.value }}.
6060
draft: false
6161
prerelease: false

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ python server.py --pretrained_path fnlp/moss-moon-003-sft
6262
- [GODEL](https://github.com/microsoft/GODEL)
6363
- [GODEL-v1_1-base-seq2seq](https://huggingface.co/microsoft/GODEL-v1_1-base-seq2seq)
6464
- [GODEL-v1_1-large-seq2seq](https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq)
65+
- [StableLM](https://github.com/Stability-AI/StableLM)
66+
- [stablelm-tuned-alpha-3b](https://huggingface.co/stabilityai/stablelm-tuned-alpha-3b)
67+
- [stablelm-tuned-alpha-7b](https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b)
6568

6669
### 添加自己的模型
6770

config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
"microsoft/GODEL-v1_1-large-seq2seq": "godel",
3030
# belle
3131
"BelleGroup/BELLE-7B-2M": "belle",
32+
# stablelm
33+
"stabilityai/stablelm-tuned-alpha-3b": "stablelm",
34+
"stabilityai/stablelm-tuned-alpha-7b": "stablelm",
3235
}
3336

3437
DTYPE_DICT = {

generator/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import importlib
23
import inspect
34

@@ -8,7 +9,11 @@ def choose_bot(config):
89
classes = inspect.getmembers(mod, inspect.isclass)
910
name, bot_cls = None, None
1011
for name, bot_cls in classes:
11-
if issubclass(bot_cls, ChatBOT):
12+
_, filename = os.path.split(inspect.getsourcefile(bot_cls))
13+
file_mod, _ = os.path.splitext(filename)
14+
# bot_cls may be class that is imported from other files
15+
# ex. ChatBOT
16+
if file_mod == config.type and issubclass(bot_cls, ChatBOT):
1217
break
1318

1419
print(f"Choose ChatBOT: {name}")

generator/baize.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
import torch
2-
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig
3-
from accelerate import init_empty_weights
2+
from transformers import LlamaForCausalLM
43
try:
54
from peft import PeftModel
65
except:
76
PeftModel = None
87

98
from .transformersbot import TransformersChatBOT
10-
from .utils import load_checkpoint_and_dispatch_from_s3
119

1210
class BaizeBOT(TransformersChatBOT):
1311
def __init__(self, config):
@@ -17,7 +15,7 @@ def __init__(self, config):
1715
)
1816
if config.base_model is None:
1917
raise ValueError(
20-
"Base model's path of Baize should be set."
18+
"Base model(llama)'s path of Baize should be set."
2119
)
2220
super(BaizeBOT, self).__init__(config)
2321

@@ -115,14 +113,8 @@ def process_response(self, response):
115113
response = response[: response.index("[|Human|]")].strip()
116114
if "[|AI|]" in response:
117115
response = response[: response.index("[|AI|]")].strip()
118-
119-
return response.strip()
120-
121-
def load_tokenizer(self):
122-
self.tokenizer = LlamaTokenizer.from_pretrained(
123-
self.config.tokenizer_path
124-
)
125-
116+
return response.strip(" ")
117+
126118
def load_model(self):
127119

128120
llama = self.model_cls.from_pretrained(
@@ -139,6 +131,9 @@ def load_from_s3(self):
139131
import io
140132
import json
141133
from petrel_client.client import Client
134+
from accelerate import init_empty_weights
135+
from transformers import LlamaConfig
136+
from .utils import load_checkpoint_and_dispatch_from_s3
142137
client = Client()
143138

144139
# get config

generator/belle.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import torch
2-
from transformers import BloomForCausalLM, AutoConfig
3-
from accelerate import init_empty_weights
2+
from transformers import BloomForCausalLM
43

54
from .transformersbot import TransformersChatBOT
6-
from .utils import load_checkpoint_and_dispatch_from_s3
75

86
class BELLEBOT(TransformersChatBOT):
97
def __init__(self, config):
@@ -42,7 +40,9 @@ def load_from_s3(self):
4240
import io
4341
import json
4442
from petrel_client.client import Client
45-
from tqdm import tqdm
43+
from accelerate import init_empty_weights
44+
from transformers import AutoConfig
45+
from .utils import load_checkpoint_and_dispatch_from_s3
4646
client = Client()
4747

4848
# get model_index

generator/stablelm.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from transformers import GPTNeoXForCausalLM, StoppingCriteria, StoppingCriteriaList
2+
3+
from .transformersbot import TransformersChatBOT
4+
5+
class StableLMBOT(TransformersChatBOT):
    """ChatBOT implementation for StabilityAI's StableLM tuned-alpha models."""

    def __init__(self, config):
        super(StableLMBOT, self).__init__(config)

    @property
    def model_cls(self):
        # StableLM tuned checkpoints use the GPT-NeoX architecture.
        return GPTNeoXForCausalLM

    def extra_settings(self):
        # Halt generation as soon as one of StableLM's stop tokens appears.
        return {"stopping_criteria": StoppingCriteriaList([StopOnTokens()])}

    def get_prompt(self, query):
        """Render the dialogue history into StableLM's prompt format.

        Each turn in ``query`` is a dict with a "role" ("HUMAN" or "BOT")
        and a "content" string; the returned prompt ends with
        ``<|ASSISTANT|>`` so the model generates the next assistant reply.
        """
        system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
        role_templates = {
            "BOT": "<|ASSISTANT|>{}",
            "HUMAN": "<|USER|>{}",
        }
        rendered_turns = [
            role_templates[turn["role"]].format(turn["content"])
            for turn in query
        ]
        return system_prompt + "".join(rendered_turns) + "<|ASSISTANT|>"

    @property
    def no_split_module_classes(self):
        # Keep each GPTNeoXLayer whole when accelerate shards the model
        # across devices.
        return ["GPTNeoXLayer"]
38+
39+
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation when the newest token is a stop token.

    Args:
        stop_ids: optional iterable of token ids that terminate generation.
            Defaults to StableLM tuned-alpha's special tokens. Keeping the
            default makes ``StopOnTokens()`` behave exactly as before while
            allowing reuse with other tokenizers.
    """

    # Default ids taken from the original hard-coded list; presumably
    # StableLM's <|USER|>/<|ASSISTANT|>/<|SYSTEM|> specials plus EOS/pad —
    # TODO(review): confirm against the StableLM tokenizer.
    DEFAULT_STOP_IDS = (50278, 50279, 50277, 1, 0)

    def __init__(self, stop_ids=None):
        super().__init__()
        # Store as a frozenset for O(1) membership tests.
        self.stop_ids = frozenset(
            self.DEFAULT_STOP_IDS if stop_ids is None else stop_ids
        )

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        """Return True if the last token of the first sequence is a stop id.

        ``int(...)`` unwraps the scalar tensor: tensors hash by identity,
        so set membership only works on the plain integer value.
        """
        return int(input_ids[0][-1]) in self.stop_ids

generator/transformersbot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@
33
import torch
44
from transformers import AutoTokenizer, AutoConfig
55
from transformers.models.auto.modeling_auto import _BaseAutoModelClass
6-
from accelerate import init_empty_weights
76

87
from .chatbot import ChatBOT
9-
from .utils import load_checkpoint_and_dispatch_from_s3
108

119
class TransformersChatBOT(ChatBOT):
1210
"""
@@ -115,6 +113,8 @@ def load_from_s3(self):
115113
import io
116114
import json
117115
from petrel_client.client import Client
116+
from accelerate import init_empty_weights
117+
from .utils import load_checkpoint_and_dispatch_from_s3
118118
client = Client()
119119

120120
# get model_index

requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
fastapi
2+
uvicorn
3+
transformers
4+
accelerate

0 commit comments

Comments
 (0)