@@ -38,24 +38,14 @@ def wrapper(*args, **kwargs):
3838 return result
3939 return wrapper
4040
41- def truncate_string (row , char_limit , variable_str , constant_str ):
42- if not (isinstance (row [variable_str ], str ) and isinstance (row [constant_str ], str )):
43- return ""
44- if len (row [constant_str ]) >= char_limit :
45- return ""
46- trimmed_length = char_limit - len (row [constant_str ])
47- return row [variable_str ][:trimmed_length ]
48-
def process_code_string(s):
    """Strip doctest-style interpreter prompts from *s*.

    Lines beginning with the prompts ``'>>> '`` or ``'... '`` have the prompt
    removed; any other line that starts with a non-space character is turned
    into a comment by prefixing ``'# '``.  Indented lines (continuation of a
    statement) are left untouched.  Strings containing no ``'>>>'`` prompt are
    returned unchanged.

    :param s: text that may contain interactive-session style code
    :return: the cleaned code string
    """
    if '>>>' not in s:
        return s

    def replace_line_prefix(match):
        # group(1) is whichever alternative matched at the line start.
        prefix = match.group(1)
        if prefix in [">>> ", "... "]:
            return ""
        return "# " + match.group(0)

    # Bug fix: the continuation-prompt dots must be escaped.  Unescaped,
    # '... ' matches ANY three characters followed by a space — e.g. the
    # start of an indented line — so lines the '\S' alternative deliberately
    # skips were being commented out.
    pattern = r"^(>>> |\.\.\. |\S+.*$)"
    return re.sub(pattern, replace_line_prefix, s, flags=re.MULTILINE)
6151
@@ -69,7 +59,7 @@ def extract_code_block(s, is_python):
6959 code = ''
7060 for match in matches :
7161 is_python = identify_lang (match ) if is_python is None else is_python
72- code += match [1 ] if is_python else re .sub (r'(?<!!)^ ' , '!' , match [1 ], flags = re .MULTILINE )
62+ code += match [1 ] if is_python else re .sub (r'^(?![!]) ' , '!' , match [1 ], flags = re .MULTILINE )
7363 return code .rstrip ()
7464
7565def process_markdown_data (df ):
@@ -80,6 +70,13 @@ def process_markdown_data(df):
8070 return df
8171
8272def process_docstr_data (df ):
def truncate_string(row, char_limit, variable_str, constant_str):
    """Trim ``row[variable_str]`` so it fits next to ``row[constant_str]``.

    Returns a prefix of the variable string whose length, added to the
    constant string's length, does not exceed *char_limit*.  Returns the
    empty string when either field is not a string or when the constant
    part alone already uses up the limit.
    """
    variable_part = row[variable_str]
    constant_part = row[constant_str]
    if not (isinstance(variable_part, str) and isinstance(constant_part, str)):
        return ""
    remaining = char_limit - len(constant_part)
    if remaining <= 0:
        return ""
    return variable_part[:remaining]
8380 df = df [df ['docstring' ].str .contains ('```' )]
8481 df = df [~ df ['filepath' ].apply (lambda x : x .split ('/' )[- 1 ]).str .startswith ('TF' )]
8582 df .reset_index (drop = True , inplace = True )
@@ -90,39 +87,33 @@ def process_docstr_data(df):
9087 df ['retrieved_docstr' ] = df .apply (lambda row : f"{ row ['type' ]} `{ row ['filepath' ].split ('/' )[- 1 ]} ` ({ row ['filepath' ]} ):\n '''\n { row ['docstring' ]} ...\n '''" , axis = 1 )
9188 return df
9289
# Default corpus configuration for the RM component (name suggests a
# retrieval module — confirm against the RM class, which is outside this
# view).  Maps a corpus name to the dataset file key it is stored under
# ('filename_key') and the preprocessing function applied to the loaded
# DataFrame ('process_data').
default_config_for_RM = {
    'markdown': {
        'filename_key': 'hfmd_20230927192215',
        'process_data': process_markdown_data
    },
    'huggingface': {
        'filename_key': 'hfds_20230927191331',
        'process_data': process_docstr_data
    },
}
93100
def edit_code_in_terminal(initial_text):
    """Let the user edit *initial_text* in an interactive terminal prompt.

    Opens a multiline prompt_toolkit session pre-filled with *initial_text*
    and Python syntax highlighting.  Pressing shift+tab captures the buffer
    and exits.  Returns the edited text (or *initial_text* if the session
    ends without the shift+tab binding firing).
    """
    bindings = KeyBindings()
    state = {'text': initial_text}

    @bindings.add('s-tab')
    def _finish(event):
        # Snapshot the buffer, then terminate the prompt session.
        state['text'] = event.app.current_buffer.text
        event.app.exit()

    prompt_style = Style.from_dict({
        '': '#ffad00',
        'prompt': 'bg:#ff0000 #ffff00',
    })
    session = PromptSession(lexer=PygmentsLexer(Python3Lexer), key_bindings=bindings, style=prompt_style)
    session.prompt('\n--- Press shift+tab when done ---\n', multiline=True, default=initial_text)
    return state['text']
114116
115- default_config_for_RM = {
116- 'markdown' : {
117- 'filename_key' : 'hfmd_20230927192215' ,
118- 'process_data' : process_markdown_data
119- },
120- 'huggingface' : {
121- 'filename_key' : 'hfds_20230927191331' ,
122- 'process_data' : process_docstr_data
123- },
124- }
125-
126117def identify_lang (match ): # stub
127118 if 'py' in match [0 ]:
128119 is_python = True
@@ -138,9 +129,8 @@ def identify_lang(match): # stub
138129 is_python = True
139130 return is_python
140131
141-
142132class VirtualEnvironment :
143- def __init__ (self , venv_path = 'venv4gen ' ):
133+ def __init__ (self , venv_path = 'venvRoy ' ):
144134 self .venv_path = venv_path
145135 try :
146136 if not os .path .exists (self .venv_path ):
@@ -151,6 +141,8 @@ def __init__(self, venv_path='venv4gen'):
151141 else :
152142 self .python_executable = os .path .join (venv_path , "bin" , "python" )
153143 self .pip_executable = os .path .join (venv_path , "bin" , "pip" )
144+ subprocess .run (f'{ self .python_executable } -V' )
145+ subprocess .run (f'{ self .pip_executable } -V' )
154146 except :
155147 log ("Warning: Failed to create or locate virtual environment. Using default system python and pip." )
156148 self .python_executable = "python"
@@ -277,7 +269,6 @@ def _constrained_beam(self, input_beam, constraint, prohibits, num_beams, cache_
277269 best_postfixed = (None , None , float ('-inf' ), None )
278270 best_compatibility = float ('-inf' )
279271 for i in range (max_new_tokens ):
280- best_voluntary = (None , None , float ('-inf' ), None )
281272 new_beams = []
282273 for beam in beams :
283274 beam_input_ids , beam_output_tokens , beam_score , beam_kv = beam
@@ -302,17 +293,14 @@ def _constrained_beam(self, input_beam, constraint, prohibits, num_beams, cache_
302293 new_input_ids = next_token_id .unsqueeze (0 ).unsqueeze (0 )
303294 new_output_tokens = beam_output_tokens + [next_token_id .item ()]
304295 new_score = ((beam_score * (len (beam_output_tokens ) + norm_factor )) + next_score .item ()) / (len (new_output_tokens ) + norm_factor )
305- if next_token_id == self .tokenizer .eos_token_id :
306- continue
307- elif all (new_output_tokens [- len (p ):] != p for p in prohibits ):
296+ if all (new_output_tokens [- len (p ):] != p for p in prohibits ) and (next_token_id != self .tokenizer .eos_token_id ):
308297 new_beam = (new_input_ids , new_output_tokens , new_score , new_kv )
309298 new_beams .append (new_beam )
310299 new_beams = sorted (new_beams , key = lambda x : x [2 ], reverse = True )[:num_beams ]
311- if any (new_output_tokens [- len (sublist ):] == sublist for sublist in required_tokens ) and (new_score > best_voluntary [2 ]):
312- best_voluntary = new_beam
313- if best_voluntary [2 ] >= new_beams [- 1 ][2 ]:
314- torch .save (best_voluntary [- 1 ], fn_to_save )
315- return (best_voluntary [0 ], best_voluntary [1 ][adhoc :], best_voluntary [2 ], cache_fn )
300+ for new_beam in new_beams :
301+ if any (new_beam [1 ][- len (sublist ):] == sublist for sublist in required_tokens ):
302+ torch .save (new_beam [- 1 ], fn_to_save )
303+ return (new_beam [0 ], new_beam [1 ][adhoc :], new_beam [2 ], cache_fn )
316304 beams = new_beams
317305 torch .save (best_postfixed [- 1 ], fn_to_save )
318306 return (best_postfixed [0 ], best_postfixed [1 ][adhoc :], best_postfixed [2 ], cache_fn )
@@ -337,19 +325,15 @@ def _unconstrained_beam(self, input_beam, max_new_tokens, prohibits, num_beams,
337325 new_logits = new_outputs .logits [:, - 1 , :]
338326 new_kv = new_outputs .past_key_values
339327 topk = torch .topk (new_logits , num_beams )
340- list_next_token_id = topk .indices [0 ]
341- list_next_score = topk .values [0 ]
342- if (self .tokenizer .eos_token_id in list_next_token_id ) and (new_score > best_eos [2 ]):
343- best_eos = beam
344- patience = patience_limit
345- continue
346- for next_token_id , next_score in zip (list_next_token_id , list_next_score ):
328+ for next_token_id , next_score in zip (topk .indices [0 ], topk .values [0 ]):
347329 new_input_ids = next_token_id .unsqueeze (0 ).unsqueeze (0 )
348330 new_output_tokens = beam_output_tokens + [next_token_id .item ()]
349331 new_score = ((beam_score * (len (beam_output_tokens ) + norm_factor )) + next_score .item ()) / (len (new_output_tokens ) + norm_factor )
350- if all (new_output_tokens [- len (p ):] != p for p in prohibits ):
351- new_beam = (new_input_ids , new_output_tokens , new_score , new_kv )
352- new_beams .append (new_beam )
332+ if (next_token_id == self .tokenizer .eos_token_id ) and (new_score > best_eos [2 ]):
333+ best_eos = beam
334+ patience = patience_limit
335+ elif all (new_output_tokens [- len (p ):] != p for p in prohibits ):
336+ new_beams .append ((new_input_ids , new_output_tokens , new_score , new_kv ))
353337 new_beams = sorted (new_beams , key = lambda x : x [2 ], reverse = True )[:num_beams ]
354338 beams = new_beams
355339 result = max ([best_eos ] + beams , key = lambda x :x [2 ])
@@ -366,7 +350,6 @@ def tokenize_constraints(s):
366350 input_ids = [sub [1 :] if sub and sub [0 ] == 29871 else sub for sub in input_ids ]
367351 input_ids [len (s ):] = [i [1 :] for i in input_ids [len (s ):]]
368352 return [list (x ) for x in set (tuple (x ) for x in input_ids )]
369-
370353 if len (template ) == 1 :
371354 if isinstance (template [0 ], int ):
372355 return [(template [0 ], None )]
@@ -403,7 +386,6 @@ def generate(self, input_txt, template = (('\n```python', '\n```sh'), '\n```'),
403386 i_beam = self ._unconstrained_beam (i_beam , max_new_tokens = constraint [0 ], prohibits = prohibits , num_beams = num_beams , cache_fn = cache_fn )
404387 torch .cuda .empty_cache ()
405388 result += i_beam [1 ]
406-
407389 else :
408390 i_beam = self ._constrained_beam (i_beam , constraint = constraint , prohibits = prohibits , num_beams = num_beams , cache_fn = cache_fn )
409391 torch .cuda .empty_cache ()
@@ -414,13 +396,10 @@ class Roy:
414396 def __init__ (self , config = None ):
415397 if config is None :
416398 config = {}
417-
418399 self .template = config .get ('template' , "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n \n ### Instruction:\n {instruction}\n \n ### Response:\n " )
419-
420400 self ._venv = None
421401 self ._lm = None
422402 self ._rm = None
423-
424403 self .execute = trace_method (config .get ('execute' , self .venv .execute ))
425404 self .generate = trace_method (config .get ('generate' , self .lm .generate ))
426405 self .retrieve = trace_method (config .get ('retrieve' , self .rm .retrieve ))
@@ -453,4 +432,4 @@ def format(self, instruction, data={}):
453432 elif isinstance (data , (dict , pd .Series )):
454433 return template .format (** data )
455434 else :
456- raise ValueError ("Unsupported data type. Data must be a dict, Series, or DataFrame." )
435+ raise ValueError ("Unsupported data type. Data must be a dict, Series, or DataFrame." )
0 commit comments