11import torch
2- from transformers import LlamaForCausalLM , LlamaTokenizer , LlamaConfig
3- from accelerate import init_empty_weights
2+ from transformers import LlamaForCausalLM
43try :
54 from peft import PeftModel
65except :
76 PeftModel = None
87
98from .transformersbot import TransformersChatBOT
10- from .utils import load_checkpoint_and_dispatch_from_s3
119
1210class BaizeBOT (TransformersChatBOT ):
1311 def __init__ (self , config ):
@@ -17,7 +15,7 @@ def __init__(self, config):
1715 )
1816 if config .base_model is None :
1917 raise ValueError (
20- "Base model's path of Baize should be set."
18+ "Base model(llama) 's path of Baize should be set."
2119 )
2220 super (BaizeBOT , self ).__init__ (config )
2321
@@ -115,14 +113,8 @@ def process_response(self, response):
115113 response = response [: response .index ("[|Human|]" )].strip ()
116114 if "[|AI|]" in response :
117115 response = response [: response .index ("[|AI|]" )].strip ()
118-
119- return response .strip ()
120-
121- def load_tokenizer (self ):
122- self .tokenizer = LlamaTokenizer .from_pretrained (
123- self .config .tokenizer_path
124- )
125-
116+ return response .strip (" " )
117+
126118 def load_model (self ):
127119
128120 llama = self .model_cls .from_pretrained (
@@ -139,6 +131,9 @@ def load_from_s3(self):
139131 import io
140132 import json
141133 from petrel_client .client import Client
134+ from accelerate import init_empty_weights
135+ from transformers import LlamaConfig
136+ from .utils import load_checkpoint_and_dispatch_from_s3
142137 client = Client ()
143138
144139 # get config
0 commit comments