Skip to content

Commit bb87a01

Browse files
authored
Update roy.py
1 parent 9ccc039 commit bb87a01

File tree

1 file changed

+33
-54
lines changed

1 file changed

+33
-54
lines changed

roy.py

Lines changed: 33 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -38,24 +38,14 @@ def wrapper(*args, **kwargs):
3838
return result
3939
return wrapper
4040

41-
def truncate_string(row, char_limit, variable_str, constant_str):
42-
if not (isinstance(row[variable_str], str) and isinstance(row[constant_str], str)):
43-
return ""
44-
if len(row[constant_str]) >= char_limit:
45-
return ""
46-
trimmed_length = char_limit - len(row[constant_str])
47-
return row[variable_str][:trimmed_length]
48-
4941
def process_code_string(s):
    """Convert a doctest-style code string into plain runnable code.

    If *s* contains no '>>>' prompt it is returned unchanged. Otherwise each
    line is rewritten: '>>> ' and '... ' prompt prefixes are stripped (the
    code after them is kept), and any other non-blank, non-indented line —
    i.e. doctest expected output — is commented out with '# '.
    """
    if '>>>' not in s:
        return s

    def replace_line_prefix(match):
        prefix = match.group(1)
        if prefix in [">>> ", "... "]:
            return ""  # drop the doctest prompt, keep the code that follows
        # The third alternative matched the whole line: comment it out.
        return "# " + match.group(0)

    # BUG FIX: the continuation prompt's dots must be escaped. The original
    # pattern used '... ', where each unescaped '.' matches ANY character, so
    # e.g. a line indented by four spaces matched and got spuriously
    # commented out. '\.\.\. ' matches only a literal '... ' prompt.
    pattern = r"^(>>> |\.\.\. |\S+.*$)"
    return re.sub(pattern, replace_line_prefix, s, flags=re.MULTILINE)
6151

@@ -69,7 +59,7 @@ def extract_code_block(s, is_python):
6959
code = ''
7060
for match in matches:
7161
is_python = identify_lang(match) if is_python is None else is_python
72-
code += match[1] if is_python else re.sub(r'(?<!!)^', '!', match[1], flags=re.MULTILINE)
62+
code += match[1] if is_python else re.sub(r'^(?![!])', '!', match[1], flags=re.MULTILINE)
7363
return code.rstrip()
7464

7565
def process_markdown_data(df):
@@ -80,6 +70,13 @@ def process_markdown_data(df):
8070
return df
8171

8272
def process_docstr_data(df):
73+
def truncate_string(row, char_limit, variable_str, constant_str):
    """Return a prefix of row[variable_str] sized so that it plus
    row[constant_str] fits within char_limit characters.

    Returns "" when either field is not a string, or when the constant
    part alone already reaches the limit.
    """
    variable_part = row[variable_str]
    constant_part = row[constant_str]
    if not (isinstance(variable_part, str) and isinstance(constant_part, str)):
        return ""
    remaining = char_limit - len(constant_part)
    if remaining <= 0:
        return ""
    return variable_part[:remaining]
8380
df = df[df['docstring'].str.contains('```')]
8481
df = df[~df['filepath'].apply(lambda x: x.split('/')[-1]).str.startswith('TF')]
8582
df.reset_index(drop=True, inplace=True)
@@ -90,39 +87,33 @@ def process_docstr_data(df):
9087
df['retrieved_docstr'] = df.apply(lambda row: f"{row['type']} `{row['filepath'].split('/')[-1]}` ({row['filepath']}):\n'''\n{row['docstring']}...\n'''", axis=1)
9188
return df
9289

90+
# Default retrieval-model (RM) configuration: maps a corpus name to the
# dataset filename key it loads and the function that post-processes the
# loaded DataFrame. Both process functions are defined above in this file.
default_config_for_RM = {
    'markdown': {
        'filename_key': 'hfmd_20230927192215',  # markdown corpus snapshot id
        'process_data': process_markdown_data
    },
    'huggingface': {
        'filename_key': 'hfds_20230927191331',  # docstring corpus snapshot id
        'process_data': process_docstr_data
    },
}
93100

94101
def edit_code_in_terminal(initial_text):
    """Open an interactive terminal editor pre-filled with *initial_text*.

    Shows a multiline prompt_toolkit session with Python syntax highlighting;
    the user edits the buffer and presses Shift+Tab to finish. Returns the
    edited text, or *initial_text* unchanged if the session exits without the
    Shift+Tab binding firing (result is pre-seeded with the initial text).
    """
    kb = KeyBindings()
    # Mutable holder so the key-binding closure can write the result back.
    result = {'text': initial_text}
    @kb.add('s-tab')
    def _(event):
        # Capture the buffer contents at the moment Shift+Tab is pressed,
        # then terminate the prompt session.
        result['text'] = event.app.current_buffer.text
        event.app.exit()
    style = Style.from_dict({
        '': '#ffad00',
        'prompt': 'bg:#ff0000 #ffff00',
    })
    session = PromptSession(lexer=PygmentsLexer(Python3Lexer), key_bindings=kb, style=style)
    # The prompt's return value is deliberately ignored: exiting via
    # event.app.exit() makes the binding, not the return value, carry the text.
    session.prompt('\n--- Press shift+tab when done ---\n', multiline=True, default=initial_text)
    result_text = result['text']
    return result_text
114116

115-
default_config_for_RM = {
116-
'markdown': {
117-
'filename_key': 'hfmd_20230927192215',
118-
'process_data': process_markdown_data
119-
},
120-
'huggingface': {
121-
'filename_key': 'hfds_20230927191331',
122-
'process_data': process_docstr_data
123-
},
124-
}
125-
126117
def identify_lang(match): # stub
127118
if 'py' in match[0]:
128119
is_python = True
@@ -138,9 +129,8 @@ def identify_lang(match): # stub
138129
is_python = True
139130
return is_python
140131

141-
142132
class VirtualEnvironment:
143-
def __init__(self, venv_path='venv4gen'):
133+
def __init__(self, venv_path='venvRoy'):
144134
self.venv_path = venv_path
145135
try:
146136
if not os.path.exists(self.venv_path):
@@ -151,6 +141,8 @@ def __init__(self, venv_path='venv4gen'):
151141
else:
152142
self.python_executable = os.path.join(venv_path, "bin", "python")
153143
self.pip_executable = os.path.join(venv_path, "bin", "pip")
144+
subprocess.run(f'{self.python_executable} -V')
145+
subprocess.run(f'{self.pip_executable} -V')
154146
except:
155147
log("Warning: Failed to create or locate virtual environment. Using default system python and pip.")
156148
self.python_executable = "python"
@@ -277,7 +269,6 @@ def _constrained_beam(self, input_beam, constraint, prohibits, num_beams, cache_
277269
best_postfixed = (None, None, float('-inf'), None)
278270
best_compatibility = float('-inf')
279271
for i in range(max_new_tokens):
280-
best_voluntary = (None, None, float('-inf'), None)
281272
new_beams = []
282273
for beam in beams:
283274
beam_input_ids, beam_output_tokens, beam_score, beam_kv = beam
@@ -302,17 +293,14 @@ def _constrained_beam(self, input_beam, constraint, prohibits, num_beams, cache_
302293
new_input_ids = next_token_id.unsqueeze(0).unsqueeze(0)
303294
new_output_tokens = beam_output_tokens + [next_token_id.item()]
304295
new_score = ((beam_score * (len(beam_output_tokens) + norm_factor)) + next_score.item()) / (len(new_output_tokens) + norm_factor)
305-
if next_token_id == self.tokenizer.eos_token_id:
306-
continue
307-
elif all(new_output_tokens[-len(p):] != p for p in prohibits):
296+
if all(new_output_tokens[-len(p):] != p for p in prohibits) and (next_token_id != self.tokenizer.eos_token_id):
308297
new_beam = (new_input_ids, new_output_tokens, new_score, new_kv)
309298
new_beams.append(new_beam)
310299
new_beams = sorted(new_beams, key=lambda x: x[2], reverse=True)[:num_beams]
311-
if any(new_output_tokens[-len(sublist):] == sublist for sublist in required_tokens) and (new_score > best_voluntary[2]):
312-
best_voluntary = new_beam
313-
if best_voluntary[2] >= new_beams[-1][2]:
314-
torch.save(best_voluntary[-1], fn_to_save)
315-
return (best_voluntary[0], best_voluntary[1][adhoc:], best_voluntary[2], cache_fn)
300+
for new_beam in new_beams:
301+
if any(new_beam[1][-len(sublist):] == sublist for sublist in required_tokens):
302+
torch.save(new_beam[-1], fn_to_save)
303+
return (new_beam[0], new_beam[1][adhoc:], new_beam[2], cache_fn)
316304
beams = new_beams
317305
torch.save(best_postfixed[-1], fn_to_save)
318306
return (best_postfixed[0], best_postfixed[1][adhoc:], best_postfixed[2], cache_fn)
@@ -337,19 +325,15 @@ def _unconstrained_beam(self, input_beam, max_new_tokens, prohibits, num_beams,
337325
new_logits = new_outputs.logits[:, -1, :]
338326
new_kv = new_outputs.past_key_values
339327
topk = torch.topk(new_logits, num_beams)
340-
list_next_token_id = topk.indices[0]
341-
list_next_score = topk.values[0]
342-
if (self.tokenizer.eos_token_id in list_next_token_id) and (new_score > best_eos[2]):
343-
best_eos = beam
344-
patience = patience_limit
345-
continue
346-
for next_token_id, next_score in zip(list_next_token_id, list_next_score):
328+
for next_token_id, next_score in zip(topk.indices[0], topk.values[0]):
347329
new_input_ids = next_token_id.unsqueeze(0).unsqueeze(0)
348330
new_output_tokens = beam_output_tokens + [next_token_id.item()]
349331
new_score = ((beam_score * (len(beam_output_tokens) + norm_factor)) + next_score.item()) / (len(new_output_tokens) + norm_factor)
350-
if all(new_output_tokens[-len(p):] != p for p in prohibits):
351-
new_beam = (new_input_ids, new_output_tokens, new_score, new_kv)
352-
new_beams.append(new_beam)
332+
if (next_token_id == self.tokenizer.eos_token_id) and (new_score > best_eos[2]):
333+
best_eos = beam
334+
patience = patience_limit
335+
elif all(new_output_tokens[-len(p):] != p for p in prohibits):
336+
new_beams.append((new_input_ids, new_output_tokens, new_score, new_kv))
353337
new_beams = sorted(new_beams, key=lambda x: x[2], reverse=True)[:num_beams]
354338
beams = new_beams
355339
result = max([best_eos] + beams, key=lambda x:x[2])
@@ -366,7 +350,6 @@ def tokenize_constraints(s):
366350
input_ids = [sub[1:] if sub and sub[0] == 29871 else sub for sub in input_ids]
367351
input_ids[len(s):] = [i[1:] for i in input_ids[len(s):]]
368352
return [list(x) for x in set(tuple(x) for x in input_ids)]
369-
370353
if len(template) == 1:
371354
if isinstance(template[0], int):
372355
return [(template[0], None)]
@@ -403,7 +386,6 @@ def generate(self, input_txt, template = (('\n```python', '\n```sh'), '\n```'),
403386
i_beam = self._unconstrained_beam(i_beam, max_new_tokens = constraint[0], prohibits=prohibits, num_beams=num_beams, cache_fn=cache_fn)
404387
torch.cuda.empty_cache()
405388
result += i_beam[1]
406-
407389
else:
408390
i_beam = self._constrained_beam(i_beam, constraint = constraint, prohibits=prohibits, num_beams=num_beams, cache_fn=cache_fn)
409391
torch.cuda.empty_cache()
@@ -414,13 +396,10 @@ class Roy:
414396
def __init__(self, config=None):
415397
if config is None:
416398
config = {}
417-
418399
self.template = config.get('template', "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n")
419-
420400
self._venv = None
421401
self._lm = None
422402
self._rm = None
423-
424403
self.execute = trace_method(config.get('execute', self.venv.execute))
425404
self.generate = trace_method(config.get('generate', self.lm.generate))
426405
self.retrieve = trace_method(config.get('retrieve', self.rm.retrieve))
@@ -453,4 +432,4 @@ def format(self, instruction, data={}):
453432
elif isinstance(data, (dict, pd.Series)):
454433
return template.format(**data)
455434
else:
456-
raise ValueError("Unsupported data type. Data must be a dict, Series, or DataFrame.")
435+
raise ValueError("Unsupported data type. Data must be a dict, Series, or DataFrame.")

0 commit comments

Comments
 (0)