 import torch
 from torch.nn import functional as F
 
+
+def end_overlap(a, b):  # length of the longest suffix of a that is also a prefix of b
+    for i in reversed(range(1, len(a) + 1)):
+        if b.startswith(a[-i:]):
+            return i
+    return 0
+
 class PIPELINE_ARGS():
-    def __init__(self, temperature=1.0, top_p=0.85, top_k=0, alpha_frequency=0.2, alpha_presence=0.2, token_ban=[], token_stop=[], chunk_len=256):
+    def __init__(self,
+                 temperature=1.0,
+                 top_p=0.85,
+                 top_k=0,
+                 alpha_frequency=0.2,
+                 alpha_presence=0.2,
+                 token_ban=None,
+                 token_stop=None,
+                 stop_words=None,
+                 chunk_len=256
+                 ):
+
+        token_ban = token_ban or []  # None defaults avoid shared mutable default arguments
+        token_stop = token_stop or []
+        stop_words = stop_words or []
+
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
         self.alpha_frequency = alpha_frequency # Frequency Penalty (as in GPT-3)
         self.alpha_presence = alpha_presence # Presence Penalty (as in GPT-3)
         self.token_ban = token_ban # ban the generation of some tokens
         self.token_stop = token_stop # stop generation whenever you see any token here
+        self.stop_words = stop_words # stop generation whenever any of these strings appears in the output
         self.chunk_len = chunk_len # split input into chunks to save VRAM (shorter -> slower)
 
 class PIPELINE():
@@ -77,12 +100,23 @@ def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
         probs = probs ** (1.0 / temperature)
         out = torch.multinomial(probs, num_samples=1)[0]
         return int(out)
-
-    def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None):
+
+    def generate(self, *args, callback=None, **kwargs):  # blocking wrapper around igenerate()
+        outstr = []
+        for delta in self.igenerate(*args, **kwargs):
+            outstr.append(delta)
+            if callback:
+                callback(delta)
+        return ''.join(outstr)
+
+    def igenerate(self, ctx, token_count=100, args=PIPELINE_ARGS(), state=None):
         all_tokens = []
         out_last = 0
         out_str = ''
         occurrence = {}
+
+        stopword_checker = self.check_stopwords(args.stop_words)
+        next(stopword_checker)  # prime the coroutine so it can accept send()
         for i in range(token_count):
 
             # forward & adjust prob.
@@ -108,9 +142,57 @@ def generate(self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None):
 
             # output
             tmp = self.decode(all_tokens[out_last:])
+            if len(all_tokens) == 1:
+                tmp = tmp[1:] # strip leading space
+            if tmp == '':
+                continue
             if '\ufffd' not in tmp: # is valid utf-8 string?
-                if callback:
-                    callback(tmp)
-                out_str += tmp
+
+                try:
+                    tmp = stopword_checker.send(tmp)
+                except StopIteration:
+                    break  # a stop word was completed: stop generating
                 out_last = i + 1
-        return out_str
+
+                if tmp is None:
+                    continue  # text held back: it may still complete a stop word
+                yield tmp
+                out_str += tmp
+                out_last = i + 1
+
+    @staticmethod
+    def check_stopwords(stop_words):
+        # coroutine: receives text via send(); yields emit-safe text or None; returns once a stop word is matched
+        longest_stopword = 0 if len(stop_words) == 0 else max(map(len, stop_words))
+        chunk = ""
+        delta = True
+        # to_yield is what the next send() call will get back
+        to_yield = None
+        while delta:
+            delta = yield to_yield
+            chunk = chunk + delta
+
+            if longest_stopword == 0:
+                # nothing to check, just pass the text through
+                to_yield = delta
+                continue
+            if chunk == '':
+                to_yield = None
+                continue
+            if any(map(lambda stop_word: chunk.startswith(stop_word), stop_words)):
+                return  # stop word found: terminate the coroutine
+
+            if start_idx := max(map(lambda stop_word: end_overlap(chunk, stop_word), stop_words)):
+                if start_idx > longest_stopword:
+                    start_idx = longest_stopword # can no longer be a stopword, so cut it down
+                good, chunk = chunk[:-start_idx], chunk[-start_idx:]
+                if good:
+                    to_yield = good
+                    continue
+
+                to_yield = None
+                continue
+
+            out = chunk
+            chunk = ""
+            to_yield = out
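
Below, a minimal usage sketch (not part of the commit) of the new stop_words argument and the streaming igenerate(). It assumes `pipeline` is an already-constructed PIPELINE instance (model/tokenizer setup is outside this diff); the prompt and the stop word "\nQ:" are illustrative only.

# hypothetical usage; `pipeline` is a PIPELINE built elsewhere
args = PIPELINE_ARGS(temperature=1.0, top_p=0.85, stop_words=["\nQ:"])

# streaming: each delta is safe to display; text that might still grow into
# a stop word is held back, and generation ends before a stop word is emitted
for delta in pipeline.igenerate("Q: what is RWKV?\nA:", token_count=100, args=args):
    print(delta, end="", flush=True)

# blocking: generate() joins the streamed deltas; an optional callback
# receives each delta as it is produced
text = pipeline.generate("Q: what is RWKV?\nA:", token_count=100, args=args)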
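
And a self-contained sketch of the hold-back behaviour of the check_stopwords coroutine, driven directly the way igenerate() drives it; the stop word "\n\n" is again illustrative.

# drive the coroutine by hand (illustrative stop word "\n\n")
checker = PIPELINE.check_stopwords(["\n\n"])
next(checker)                      # prime it, as igenerate() does
print(checker.send("Hello"))       # -> "Hello"   (no overlap with "\n\n")
print(checker.send(" world\n"))    # -> " world"  (trailing "\n" is held back)
try:
    checker.send("\n more")        # held "\n" + this "\n" completes the stop word
except StopIteration:
    print("generation would stop here")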