-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerator.py
More file actions
459 lines (379 loc) · 21.1 KB
/
generator.py
File metadata and controls
459 lines (379 loc) · 21.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
import re
import os
import time
import json
import torch
import requests
import logging
import numpy as np
from tqdm import tqdm
import torch.nn as nn
# from openai import OpenAI
import multiprocessing as mp
from functools import partial
from multiprocessing import Pool
from vllm import LLM, SamplingParams
from tree_sitter import Language, Parser
from utils.eval_utils import is_identifier
from concurrent.futures import ThreadPoolExecutor
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from utils.eval_utils import postprocess_code_lines, remove_comments
def process_single_item(example, output, language_name=None, ts_lib=None):
    """
    Post-process one generated completion belonging to ``example``.

    The raw ``output`` is first cleaned by ``postprocess_code_lines`` (given the
    example's left context and language, and no tree-sitter parser), and the
    result is then passed through ``remove_comments``.

    ``language_name`` and ``ts_lib`` are accepted so callers can bind them with
    ``functools.partial``, but they are currently unused: no parser is built.
    """
    cleaned = postprocess_code_lines(
        example.left_context,
        output,
        None,  # parser: tree-sitter parsing is disabled
        example.language,
    )
    return remove_comments(cleaned)
class CustomDataset(Dataset):
    """
    Dataset turning code-completion examples into token-id tensors.

    Each prompt is assembled as
        [API context] + [cross-file context] + [file-path header] + [in-file left context]
    with every segment truncated so the prompt fits the configured budget.

    Args:
        args: Configuration object. Reads ``generator_max_crossfile_length``,
            ``generator_max_context_length`` and ``generator_max_generation_length``.
        tokenizer: Tokenizer exposing ``encode``, ``decode`` and ``pad_token_id``.
        examples: Examples exposing ``left_context`` and ``file_path``
            (and ``target_code`` when ``generation`` is False).
        retrieved_codeblocks: Per-example sequences of retrieved code blocks,
            aligned with ``examples`` by index.
        generation: If True, items are left-padded prompt ids only; otherwise
            items are ``(input_ids, labels)`` pairs for loss computation.
        api_contexts: Optional per-example API knowledge-base query results,
            aligned with ``examples`` by index.
    """

    def __init__(self, args, tokenizer, examples, retrieved_codeblocks, generation=False, api_contexts=None):
        self.args = args
        self.tokenizer = tokenizer
        self.examples = examples
        self.retrieved_codeblocks = retrieved_codeblocks
        self.generation = generation
        self.api_contexts = api_contexts

    def __len__(self):
        return len(self.examples)

    def construct_prompts(self, example, retrieved_codeblocks, api_context=None):
        """
        Build the text prompt for one example.

        Args:
            example: The example (reads ``file_path`` and ``left_context``).
            retrieved_codeblocks: Retrieved code blocks for this example.
            api_context: Optional API knowledge-base query result string.

        Returns:
            The decoded prompt string.
        """
        # Keep retrieved blocks up to (not including) the first one with an
        # empty file_path; everything after it is dropped.
        filter_codeblocks = []
        for block in retrieved_codeblocks:
            if block.file_path == "":
                break
            filter_codeblocks.append(block)

        max_crossfile = self.args.generator_max_crossfile_length

        # API context: characters are pre-truncated (10 chars per budget token)
        # before tokenizing, and the result may use at most half of the
        # cross-file token budget.
        api_context_tokens = []
        if api_context:
            api_context_tokens = self.tokenizer.encode(
                api_context[:max_crossfile * 10],
                add_special_tokens=False,
            )[:max_crossfile // 2]

        # Cross-file context gets whatever the API context left of the budget.
        crossfile_context = "\n\n".join(str(block) for block in filter_codeblocks)
        remaining_crossfile_length = max_crossfile - len(api_context_tokens)
        crossfile_context_tokens = self.tokenizer.encode(
            crossfile_context[:remaining_crossfile_length * 10],
            add_special_tokens=False,
        )[:remaining_crossfile_length]

        path_context = f"\n\n# file path: {example.file_path}\n\n"
        path_context_tokens = self.tokenizer.encode(path_context, add_special_tokens=False)

        # Remaining budget for the in-file context. A fixed 10-token margin is
        # held back (presumably for special tokens added when __getitem__
        # re-encodes the prompt -- unverified).
        allowed_prompt_length = self.args.generator_max_context_length - (
            len(api_context_tokens) + len(crossfile_context_tokens) + len(path_context_tokens) + 10
        )
        infile_context_tokens = self.tokenizer.encode(example.left_context, add_special_tokens=False)
        # Keep the rightmost ``allowed_prompt_length`` tokens. A non-positive
        # budget must yield no tokens: ``tokens[-0:]`` would return the whole
        # list and a negative bound would drop tokens from the front instead.
        if allowed_prompt_length > 0:
            infile_context_tokens = infile_context_tokens[-allowed_prompt_length:]
        else:
            infile_context_tokens = []

        # Order: API context + cross-file context + file path + in-file context.
        return self.tokenizer.decode(
            api_context_tokens + crossfile_context_tokens + path_context_tokens + infile_context_tokens
        )

    def __getitem__(self, idx):
        """
        Return the tensors for example ``idx``.

        In generation mode, returns a 1-D tensor of prompt ids left-padded to
        ``generator_max_context_length``. Otherwise returns ``(input_ids, labels)``
        left-padded to ``generator_max_context_length + generator_max_generation_length``,
        where prompt and padding positions in ``labels`` are ``-100``.
        """
        example = self.examples[idx]
        retrieved_codeblocks = self.retrieved_codeblocks[idx]

        api_context = None
        if self.api_contexts is not None and idx < len(self.api_contexts):
            api_context = self.api_contexts[idx]

        prompt = self.construct_prompts(example, retrieved_codeblocks, api_context)
        max_context = self.args.generator_max_context_length
        # Re-encode with the tokenizer's default special tokens, keeping the
        # rightmost context-length tokens.
        prompt_ids = self.tokenizer.encode(prompt)[-max_context:]
        pad_id = self.tokenizer.pad_token_id

        if self.generation:
            padding_length = max_context - len(prompt_ids)
            return torch.tensor([pad_id] * padding_length + prompt_ids)

        target_ids = self.tokenizer.encode(
            example.target_code, add_special_tokens=False
        )[:self.args.generator_max_generation_length]
        input_ids = prompt_ids + target_ids
        # Only target positions contribute to the loss.
        labels = [-100] * len(prompt_ids) + target_ids
        padding_length = max_context + self.args.generator_max_generation_length - len(input_ids)
        input_ids = [pad_id] * padding_length + input_ids
        labels = [-100] * padding_length + labels
        return torch.tensor(input_ids), torch.tensor(labels)
class Model(nn.Module):
    """
    ``nn.Module`` wrapper around a Hugging Face causal LM.

    ``forward`` either scores target tokens (when ``labels`` is given) or
    generates a continuation of ``inputs``.

    Args:
        generator_model_path: Path/id passed to ``AutoModelForCausalLM.from_pretrained``
            (loaded with ``torch_dtype=torch.float16``).
        tokenizer: Tokenizer; ``pad_token_id`` builds the attention mask and
            ``convert_ids_to_tokens`` is used for keyword weighting.
        max_generation_length: Number of new tokens allowed in generation mode.
    """

    def __init__(self, generator_model_path, tokenizer, max_generation_length=64):
        super().__init__()
        self.base_model = AutoModelForCausalLM.from_pretrained(generator_model_path, torch_dtype=torch.float16)
        self.tokenizer = tokenizer
        self.max_generation_length = max_generation_length
        self.generator_model_path = generator_model_path

    def _keyword_weights(self, labels, lang, id_weight=3, first_token_weight=5):
        """
        Return flattened per-position loss weights for ``labels`` (rows = samples).

        In every row, the first supervised (non ``-100``) position gets
        ``first_token_weight``. Later supervised positions whose token contains an
        alphanumeric/underscore character and is accepted by ``is_identifier``
        get ``id_weight``. All other positions get 1 (ignored positions have
        zero loss anyway).
        """
        rows = []
        for row in labels.tolist():
            row_weights = []
            first_pending = True
            for token_id in row:
                if token_id == -100:
                    row_weights.append(1.0)
                elif first_pending:
                    row_weights.append(float(first_token_weight))
                    first_pending = False
                else:
                    token = self.tokenizer.convert_ids_to_tokens(token_id)
                    # Cheap character check first so is_identifier only sees plausible tokens.
                    is_id = any(c.isalpha() or c.isdigit() or c == '_' for c in token) and is_identifier(token, lang)
                    row_weights.append(float(id_weight) if is_id else 1.0)
            rows.append(row_weights)
        return torch.tensor(rows, dtype=torch.float, device=labels.device).reshape(-1)

    def forward(self, inputs=None, labels=None, lang='python', weighted_keywords=False, sample_number=0):
        """
        Score ``labels`` or generate a continuation of ``inputs``.

        :param inputs: Token ids; positions equal to the tokenizer's pad id are masked out.
        :param labels: Target ids aligned with ``inputs`` (``-100`` = ignore). When given,
            the per-sample loss is returned instead of generating.
        :param lang: Language passed to ``is_identifier`` for keyword weighting.
        :param weighted_keywords: Up-weight each sample's first target token and
            identifier-like target tokens.
        :param sample_number: In generation mode, non-zero enables sampling
            (temperature 0.8, top_p 0.95); zero uses the model's default decoding.
        :return: ``None`` if ``inputs`` is None; otherwise the (optionally weighted)
            summed loss per sample divided by its number of target tokens (0 for a
            sample without target tokens), or the newly generated ids.
        """
        if inputs is None:
            return None
        attention_mask = inputs.ne(self.tokenizer.pad_token_id)

        if labels is not None:
            logits = self.base_model(inputs, attention_mask=attention_mask)[0]
            # Shift so position t predicts token t + 1.
            logits = logits[:, :-1]
            labels = labels[:, 1:]
            loss = torch.nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
                reduction='none',
            )
            if weighted_keywords:
                loss = loss * self._keyword_weights(labels, lang)
            # clamp avoids 0/0 = NaN for samples without any target token.
            target_counts = labels.ne(-100).sum(dim=1).clamp(min=1)
            return loss.reshape(labels.shape).sum(dim=1) / target_counts

        generate_kwargs = {
            "attention_mask": attention_mask,
            "max_length": inputs.size(1) + self.max_generation_length,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        if sample_number:
            generate_kwargs.update(do_sample=True, temperature=0.8, top_p=0.95)
        generated_ids = self.base_model.generate(inputs, **generate_kwargs)
        # Return only the newly generated part.
        return generated_ids[:, inputs.size(1):]
class Generator:
    """
    Code generator / scorer built on a Hugging Face causal LM (``Model``).

    Args:
        args: Configuration object. Reads ``generator_model_path``,
            ``disable_generator``, optional ``retriever_device``,
            ``generator_batch_size``, ``num_workers``, ``enable_tqdm``,
            ``weighted_keywords`` and the fields read by ``CustomDataset``.
    """

    def __init__(self, args):
        self.tokenizer = AutoTokenizer.from_pretrained(args.generator_model_path)
        # Prompts are truncated manually in CustomDataset, so lift the tokenizer limit.
        self.tokenizer.model_max_length = 1e10
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        if not args.disable_generator:
            self.model = Model(args.generator_model_path, self.tokenizer)
            # Place the generator on the same GPU as the retriever when one is configured.
            if hasattr(args, 'retriever_device'):
                available_gpus = torch.cuda.device_count()
                if args.retriever_device < available_gpus:
                    device = f"cuda:{args.retriever_device}"
                    print(f"Generator设置使用GPU {args.retriever_device}(与Retriever相同)")
                    # Use only the given GPU, no DataParallel.
                    self.model = self.model.to(device)
                else:
                    print(f"警告:GPU {args.retriever_device} 不存在,Generator使用DataParallel模式")
                    self.model = torch.nn.DataParallel(self.model).cuda()
            else:
                # Default: DataParallel over the visible GPUs.
                self.model = torch.nn.DataParallel(self.model).cuda()
            self.model.eval()
        self.args = args

    def evaluate(self, examples, retrieved_codeblocks):
        """
        Compute the loss of each example's target code.

        Args:
            examples: Examples (see ``CustomDataset``); every batch is scored with
                the language of ``examples[0]``.
            retrieved_codeblocks: Per-example retrieved code blocks.

        Returns:
            A list with one loss value per example, in input order.
        """
        losses = []
        dataset = CustomDataset(self.args, self.tokenizer, examples, retrieved_codeblocks)
        dataloader = DataLoader(
            dataset,
            sampler=SequentialSampler(dataset),
            batch_size=self.args.generator_batch_size,
            num_workers=self.args.num_workers,
        )
        # Device of the (possibly DataParallel-wrapped) model; loop invariant.
        model_device = next(self.model.parameters()).device
        pbar = tqdm(dataloader, disable=not self.args.enable_tqdm)
        with torch.no_grad():
            for batch in pbar:
                inputs, labels = [x.to(model_device) for x in batch]
                loss_per_label = self.model(inputs, labels, lang=examples[0].language, weighted_keywords=self.args.weighted_keywords)
                losses.extend(loss_per_label.tolist())
                current_ppl = np.exp(np.mean(losses))
                pbar.set_description(f"Loss/PPL: {np.mean(losses):.3f}/{current_ppl:.3f}")
        return losses

    def generate(self, examples, retrieved_codeblocks, max_generation_length, sample_number=0, deduplicated=False, api_contexts=None):
        """
        Generate code completions.

        Args:
            examples: A collection of examples.
            retrieved_codeblocks: Retrieved code blocks, aligned with ``examples``.
            max_generation_length: Maximum number of new tokens per completion.
            sample_number: If non-zero, each example is repeated this many times
                and completions are sampled instead of decoded with defaults.
            deduplicated: With ``sample_number``, drop empty and duplicate
                completions within each example's group.
            api_contexts: Optional API knowledge-base query results, aligned with ``examples``.

        Returns:
            Without ``sample_number``: decoded completions, one per example.
            With ``sample_number``: post-processed completions, ``sample_number``
            consecutive entries per example; if ``deduplicated`` also, a tuple
            ``(unique_completions, counts_per_example)``.
        """
        if sample_number:
            examples = [example for example in examples for _ in range(sample_number)]
            retrieved_codeblocks = [codeblocks for codeblocks in retrieved_codeblocks for _ in range(sample_number)]
            if api_contexts is not None:
                api_contexts = [ctx for ctx in api_contexts for _ in range(sample_number)]

        dataset = CustomDataset(self.args, self.tokenizer, examples, retrieved_codeblocks, generation=True, api_contexts=api_contexts)
        dataloader = DataLoader(
            dataset,
            sampler=SequentialSampler(dataset),
            batch_size=self.args.generator_batch_size,
            num_workers=self.args.num_workers,
        )

        # DataParallel keeps the wrapped module under ``.module``.
        core_model = self.model.module if hasattr(self.model, "module") else self.model
        core_model.max_generation_length = max_generation_length
        model_device = next(self.model.parameters()).device

        outputs = []
        pbar = tqdm(dataloader, disable=not self.args.enable_tqdm, desc="Generating")
        with torch.no_grad():
            for batch in pbar:
                generated_ids = self.model(batch.to(model_device), sample_number=sample_number)
                if generated_ids is None:
                    continue
                # Decode per batch: generate() may stop early, so batches can have
                # different widths and cannot be concatenated into one tensor.
                # NOTE(review): under DataParallel the per-replica outputs are still
                # gathered by concatenation inside DataParallel -- confirm they match.
                outputs.extend(self.tokenizer.decode(ids, skip_special_tokens=True) for ids in generated_ids)

        if not sample_number:
            return outputs
        outputs = self._postprocess(examples, outputs)
        if not deduplicated:
            return outputs
        return self._deduplicate(outputs, sample_number)

    @staticmethod
    def _postprocess(examples, outputs):
        """Apply ``process_single_item`` to each (example, output) pair in a process pool."""
        if not examples:
            return []
        language_name = examples[0].language
        if language_name == 'python':
            ts_lib = "utils/build/python-lang-parser.so"
        else:
            ts_lib = "utils/build/java-lang-parser.so"
        with mp.Pool(processes=64) as pool:
            func = partial(process_single_item, language_name=language_name, ts_lib=ts_lib)
            return pool.starmap(func, zip(examples, outputs))

    @staticmethod
    def _deduplicate(outputs, sample_number):
        """Drop empty/duplicate items within each consecutive group of ``sample_number``."""
        deduplicated_outputs, counts_per_batch = [], []
        for start in range(0, len(outputs), sample_number):
            seen, unique_batch = set(), []
            for item in outputs[start:start + sample_number]:
                if item not in seen and item.strip() != '':
                    seen.add(item)
                    unique_batch.append(item)
            deduplicated_outputs.extend(unique_batch)
            counts_per_batch.append(len(unique_batch))
        return deduplicated_outputs, counts_per_batch
class vLLM_online_Generator:
    """
    Code generator backed by an OpenAI-compatible completion server at
    ``http://localhost:8000/v1`` (e.g. a vLLM server).

    Args:
        args: Configuration object. Reads ``generator_model_path`` (used for the
            local tokenizer and as the served model name),
            ``generator_max_generation_length``, ``num_workers``, ``enable_tqdm``
            and the fields read by ``CustomDataset``.
    """

    def __init__(self, args):
        # Imported here so the openai package is only needed when this backend is used.
        from openai import OpenAI

        openai_api_key = "EMPTY"
        openai_api_base = "http://localhost:8000/v1"
        self.client = OpenAI(api_key=openai_api_key, base_url=openai_api_base)
        self.tokenizer = AutoTokenizer.from_pretrained(args.generator_model_path)
        self.tokenizer.model_max_length = 1e10
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.args = args

    def generate(self, examples, retrieved_codeblocks, temperature, top_p, sample_number=0, deduplicated=False, api_contexts=None):
        """
        Generate completions for ``examples`` through the completion server.

        Args:
            examples, retrieved_codeblocks, api_contexts: See ``CustomDataset``.
            temperature, top_p: Sampling parameters forwarded to the server.
            sample_number: Completions per example (``n``); 0 omits ``n``.
            deduplicated: With ``sample_number``, drop empty/duplicate samples per example.

        Returns:
            Without ``sample_number``: raw completion texts, one per example.
            With ``sample_number``: post-processed texts, ``sample_number``
            consecutive entries per example; if ``deduplicated`` also, a tuple
            ``(unique_texts, counts_per_example)``. Empty input yields empty results.
        """
        dataset = CustomDataset(self.args, self.tokenizer, examples, retrieved_codeblocks, generation=True, api_contexts=api_contexts)
        sampler = SequentialSampler(dataset)
        # Huge batch size: all prompts end up in one batch and one request.
        dataloader = DataLoader(dataset, sampler=sampler, batch_size=100000000, num_workers=self.args.num_workers)
        pbar = tqdm(dataloader, disable=not self.args.enable_tqdm, desc="Generating")
        # Prompts are rebuilt from the left-padded ids; skip_special_tokens is
        # presumably there to drop the padding.
        prompts = [self.tokenizer.decode(ids, skip_special_tokens=True) for batch in pbar for ids in batch]
        if not prompts:
            return ([], []) if sample_number != 0 and deduplicated else []

        request = {
            "model": self.args.generator_model_path,
            "prompt": prompts,
            "timeout": 7200,
            "seed": 123,
            "temperature": temperature,
            "max_tokens": self.args.generator_max_generation_length,
            "top_p": top_p,
        }
        if sample_number != 0:
            request["n"] = sample_number
        response = self.client.completions.create(**request)
        # Choices of a batched request are not guaranteed to arrive in prompt
        # order; ``index`` identifies them (assumed prompt-major with n > 1, as
        # vLLM assigns it -- confirm for other servers).
        texts = [choice.text for choice in sorted(response.choices, key=lambda choice: choice.index)]

        if sample_number == 0:
            return texts
        examples = [example for example in examples for _ in range(sample_number)]
        outputs = self._postprocess(examples, texts)
        if not deduplicated:
            return outputs
        return self._deduplicate(outputs, sample_number)

    @staticmethod
    def _postprocess(examples, outputs):
        """Apply ``process_single_item`` to each (example, output) pair in a process pool."""
        if not examples:
            return []
        language_name = examples[0].language
        if language_name == 'python':
            ts_lib = "utils/build/python-lang-parser.so"
        else:
            ts_lib = "utils/build/java-lang-parser.so"
        with mp.Pool(processes=64) as pool:
            func = partial(process_single_item, language_name=language_name, ts_lib=ts_lib)
            return pool.starmap(func, zip(examples, outputs))

    @staticmethod
    def _deduplicate(outputs, sample_number):
        """Drop empty/duplicate items within each consecutive group of ``sample_number``."""
        deduplicated_outputs, counts_per_batch = [], []
        for start in range(0, len(outputs), sample_number):
            seen, unique_batch = set(), []
            for item in outputs[start:start + sample_number]:
                if item not in seen and item.strip() != '':
                    seen.add(item)
                    unique_batch.append(item)
            deduplicated_outputs.extend(unique_batch)
            counts_per_batch.append(len(unique_batch))
        return deduplicated_outputs, counts_per_batch
class vLLM_offline_Generator:
    """
    Code generator running a vLLM ``LLM`` engine in-process.

    Args:
        args: Configuration object. Reads ``generator_model_path``, optional
            ``generator_device``, ``generator_max_generation_length``,
            ``num_workers``, ``enable_tqdm`` and the fields read by ``CustomDataset``.
    """

    def __init__(self, args):
        available_gpus = torch.cuda.device_count()
        print(f"检测到 {available_gpus} 个GPU设备")
        if hasattr(args, 'generator_device'):
            if args.generator_device < available_gpus:
                print(f"vLLM Generator设置使用GPU {args.generator_device}")
            else:
                print(f"警告:GPU {args.generator_device} 不存在,使用默认GPU 0")
                args.generator_device = 0
        # NOTE(review): generator_device is only validated/logged above; it is not
        # passed to LLM below, so vLLM uses its default device choice -- confirm intent.
        self.llm = LLM(
            model=args.generator_model_path,
            max_model_len=4096,
            block_size=16,  # original note: a multiple of 16 so FlashAttention can be used
            tensor_parallel_size=1,  # a single GPU, no tensor parallelism
            gpu_memory_utilization=0.8,
            dtype="bfloat16",  # NOTE(review): an earlier comment said float16 for PyTorch 2.0.1 compatibility; bfloat16 is what is passed
        )
        self.tokenizer = self.llm.get_tokenizer()
        self.tokenizer.model_max_length = 1e10
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.args = args

    def generate(self, examples, retrieved_codeblocks, temperature, top_p, sample_number=0, deduplicated=True, api_contexts=None):
        """
        Generate completions for ``examples`` with the in-process vLLM engine.

        Args:
            examples, retrieved_codeblocks, api_contexts: See ``CustomDataset``.
            temperature, top_p: vLLM sampling parameters.
            sample_number: Completions per example (``n``); 0 requests a single one.
            deduplicated: With ``sample_number``, post-process and drop
                empty/duplicate samples per example.

        Returns:
            Without ``sample_number``: the first completion text per example.
            With ``sample_number`` and not ``deduplicated``: raw completion texts,
            ``sample_number`` consecutive entries per example.
            With both: ``(unique_texts, counts_per_example)``.
            Empty input yields empty results.
        """
        dataset = CustomDataset(self.args, self.tokenizer, examples, retrieved_codeblocks, generation=True, api_contexts=api_contexts)
        sampler = SequentialSampler(dataset)
        # Huge batch size: all prompts end up in one batch.
        dataloader = DataLoader(dataset, sampler=sampler, batch_size=100000000, num_workers=self.args.num_workers)
        if sample_number != 0:
            sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=self.args.generator_max_generation_length, n=sample_number)
        else:
            sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=self.args.generator_max_generation_length)

        pbar = tqdm(dataloader, disable=not self.args.enable_tqdm, desc="Generating")
        prompts = [self.tokenizer.decode(ids, skip_special_tokens=True) for batch in pbar for ids in batch]
        if not prompts:
            return ([], []) if sample_number != 0 and deduplicated else []

        results = self.llm.generate(prompts, sampling_params)
        if sample_number == 0:
            return [result.outputs[0].text for result in results]

        texts = [completion.text for result in results for completion in result.outputs]
        if not deduplicated:
            # NOTE(review): unlike the other generators in this file, this path
            # returns completions without process_single_item post-processing.
            return texts
        examples = [example for example in examples for _ in range(sample_number)]
        outputs = self._postprocess(examples, texts)
        return self._deduplicate(outputs, sample_number)

    @staticmethod
    def _postprocess(examples, outputs):
        """Apply ``process_single_item`` to each (example, output) pair in a process pool."""
        if not examples:
            return []
        language_name = examples[0].language
        if language_name == 'python':
            ts_lib = "utils/build/python-lang-parser.so"
        else:
            ts_lib = "utils/build/java-lang-parser.so"
        with mp.Pool(processes=64) as pool:
            func = partial(process_single_item, language_name=language_name, ts_lib=ts_lib)
            return pool.starmap(func, zip(examples, outputs))

    @staticmethod
    def _deduplicate(outputs, sample_number):
        """Drop empty/duplicate items within each consecutive group of ``sample_number``."""
        deduplicated_outputs, counts_per_batch = [], []
        for start in range(0, len(outputs), sample_number):
            seen, unique_batch = set(), []
            for item in outputs[start:start + sample_number]:
                if item not in seen and item.strip() != '':
                    seen.add(item)
                    unique_batch.append(item)
            deduplicated_outputs.extend(unique_batch)
            counts_per_batch.append(len(unique_batch))
        return deduplicated_outputs, counts_per_batch