Commit 5cb2cce

Model: Implement GraphGPT and LLaGA (#232)

* add llaga
* add graphgpt

1 parent 3042e6d

File tree: 14 files changed, +3445 -0 lines changed


examples/graphgpt/README.md

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
# GraphGPT: Graph Instruction Tuning for Large Language Models

* Paper link: http://arxiv.org/abs/2310.13023
* Author's code repo: https://github.com/HKUDS/GraphGPT

# How to Run

* First, follow the original repo to install all required packages.
* Then download all required datasets and pretrained checkpoints, and fill in their paths in eval.sh.

# Dataset Statistics

| Dataset   | # Nodes | # Edges   | # Classes |
| :-------: | :-----: | :-------: | :-------: |
| Cora      | 25,120  | 182,280   | 70        |
| PubMed    | 19,717  | 44,338    | 3         |
| ogb-arxiv | 169,343 | 1,166,243 | 40        |

# Files Description

* graphgpt_trainer.py: the GraphGPT trainer (inference stage)
* graphgpt_eval.py: run this to evaluate the generated answers

# Results

```bash
# run inference
TL_BACKEND="torch" nohup bash examples/graphgpt/eval.sh > log/test_graphgpt.out &
# run evaluation
python examples/graphgpt/graphgpt_eval.py --dataset cora
```

| Dataset | Paper  | Ours (torch) |
| :-----: | :----: | :----------: |
| Cora    | 0.1501 | 0.1451       |
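For reference, graphgpt_trainer.py writes its answers as JSON shards (named `arxiv_test_res_<start>_<end>.json` regardless of dataset) under `res_path`, and graphgpt_eval.py reads every `.json` file in that folder. A sketch of one record, with illustrative values rather than actual output:

```python
# Hypothetical entry of a result shard in output_stage_2_cora_nc/; values are
# illustrative. graphgpt_eval.py uses 'node_idx' to fetch the ground-truth label
# and scans 'res' for label-name (or cs.XX) mentions.
record = {
    "id": "cora_test_0",   # prefix before the first '_' names the source graph
    "node_idx": 102,       # index of the classified node in the full graph
    "res": "artificial intelligence, machine learning, neural networks",
}
```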

examples/graphgpt/eval.sh

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
```bash
export PYTHONPATH=$(dirname $(dirname $(realpath $0))):$PYTHONPATH
# fill in the following paths before running evaluation
output_model=/local/yy3/graphgpt/GraphGPT-7B-mix-all # path to the pre-trained model checkpoint
datapath=/local/yy3/graphgpt/data/eval/cora_test_instruct_std.json # path to the instruction dataset
graph_data_path=/local/yy3/graphgpt/data/graph_data_all.pt # path to the graph data
res_path=./output_stage_2_cora_nc # path to save the results
start_id=0
end_id=20000 # total number of instructions to test
num_gpus=1

export CUDA_VISIBLE_DEVICES=2 # specify the GPU id

python ./examples/graphgpt/graphgpt_trainer.py --model-name ${output_model} --prompting_file ${datapath} --graph_data_path ${graph_data_path} --output_res_path ${res_path} --start_id ${start_id} --end_id ${end_id} --num_gpus ${num_gpus}
```
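One pitfall worth noting: graphgpt_eval.py later loads the shards from `output_stage_2_<dataset>_nc` relative to the working directory, so `res_path` must keep that naming scheme and both scripts must run from the same directory. A hypothetical PubMed variant (the instruction-file path is a placeholder, not from this commit):

```bash
# hypothetical PubMed run; only the dataset-specific values change
datapath=/local/yy3/graphgpt/data/eval/pubmed_test_instruct_std.json  # placeholder path
res_path=./output_stage_2_pubmed_nc  # must match the folder graphgpt_eval.py expects
# ... run eval.sh with these values, then:
python examples/graphgpt/graphgpt_eval.py --dataset pubmed
```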

examples/graphgpt/graphgpt_eval.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
```python
import argparse
import json
import os
import re

import pandas as pd
import torch as th
from tqdm import tqdm
from sklearn.metrics import classification_report

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='arxiv')
args = parser.parse_args()

# label-name -> label-id maps for the datasets whose labels are plain text
label_to_idx = {
    "cora": {"databases, object oriented": 29, "operating systems, memory management": 59, "data structures algorithms and theory, quantum computing": 24, "artificial intelligence, planning": 13, "artificial intelligence, knowledge representation": 4, "artificial intelligence, data mining": 1, "artificial intelligence, vision and pattern recognition": 17, "artificial intelligence, machine learning, case-based": 5, "artificial intelligence, agents": 0, "artificial intelligence, machine learning, probabilistic methods": 8, "encryption and compression, security": 36, "operating systems, distributed": 57, "human computer interaction, interface design": 46, "artificial intelligence, machine learning, genetic algorithms": 6, "human computer interaction, graphics and virtual reality": 45, "artificial intelligence, machine learning, rule learning": 10, "programming, functional": 63, "programming, object oriented": 67, "encryption and compression, encryption": 35, "databases, performance": 30, "networking, protocols": 54, "data structures algorithms and theory, randomized": 25, "data structures algorithms and theory, formal languages": 20, "data structures algorithms and theory, parallel": 23, "programming, software development": 69, "programming, compiler design": 61, "artificial intelligence, machine learning, theory": 11, "artificial intelligence, machine learning, neural networks": 7, "programming, logic": 66, "databases, relational": 32, "information retrieval, retrieval": 52, "programming, debugging": 62, "networking, wireless": 56, "artificial intelligence, theorem proving": 16, "databases, temporal": 33, "encryption and compression, compression": 34, "information retrieval, filtering": 51, "data structures algorithms and theory, computational complexity": 18, "programming, garbage collection": 64, "artificial intelligence, machine learning, reinforcement learning": 9, "human computer interaction, multimedia": 47, "hardware and architecture, vlsi": 43, "artificial intelligence, nlp": 12, "hardware and architecture, microprogramming": 42, "operating systems, fault tolerance": 58, "programming, java": 65, "operating systems, realtime": 60, "human computer interaction, cooperative": 44, "artificial intelligence, speech": 15, "databases, deductive": 28, "artificial intelligence, robotics": 14, "data structures algorithms and theory, logic": 22, "networking, routing": 55, "hardware and architecture, logic design": 40, "hardware and architecture, distributed architectures": 37, "data structures algorithms and theory, hashing": 21, "programming, semantics": 68, "artificial intelligence, games and search": 3, "databases, concurrency": 27, "data structures algorithms and theory, sorting": 26, "human computer interaction, wearable computers": 48, "information retrieval, digital library": 49, "artificial intelligence, expert systems": 2, "information retrieval, extraction": 50, "data structures algorithms and theory, computational geometry": 19, "databases, query evaluation": 31, "networking, internet": 53, "hardware and architecture, memory structures": 41, "hardware and architecture, high performance computing": 38, "hardware and architecture, input output and storage": 39},
    "pubmed": {"Experimentally induced diabetes": 0, "Type 2 diabetes": 2, "Type 1 diabetes": 1}
}

# gather every result shard written by graphgpt_trainer.py
data_list = []
folder = 'output_stage_2_{}_nc'.format(args.dataset)
for filename in os.listdir(folder):
    if filename.endswith('.json'):
        file_path = os.path.join(folder, filename)
        with open(file_path, 'r') as f:
            data = json.load(f)
        data_list.extend(data)

print(data_list[1])

# ground-truth labels come from the same graph file used at inference time
graph_data = th.load('/local/yy3/graphgpt/data/graph_data_all.pt')[args.dataset]
labels = graph_data.y


def cal_map():
    # build the label-name -> label-id map; arxiv derives it from the OGB mapping file
    label_dict = {}
    if args.dataset == "arxiv":
        df = pd.read_csv(os.path.expanduser('~/datasets/OGB/ogbn_arxiv/mapping/labelidx2arxivcategeory.csv.gz'), compression='gzip')
        for index, line in df.iterrows():
            lb = line['arxiv category'].split(' ')[-1]
            lb_new = 'cs.' + lb.upper()
            label_dict[lb_new] = line['label idx']
    else:
        label_dict = label_to_idx[args.dataset]
    return label_dict


class_map = cal_map()

inverse_class_map = {}
for lb, lb_id in class_map.items():
    inverse_class_map[lb_id] = lb

pattern = r"cs\.[A-Z]{2}"

topk = 3

correct = 0
total = len(data_list)

trues = []
preds = []

for instruct_item in tqdm(data_list):
    nid = instruct_item['node_idx']
    gpt_res = instruct_item['res']

    true_y = labels[nid]

    pred_y = []
    if args.dataset == "arxiv":
        # unique cs.XX mentions, ranked by first occurrence in the response
        matches = list(set(re.findall(pattern, gpt_res)))
        sorted_matches = sorted(matches, key=lambda x: gpt_res.index(x))
        for m in sorted_matches:
            try:
                pred_y.append(class_map[m])
            except KeyError:
                pass  # skip categories outside the label set
        try:
            preds.append(pred_y[0])
        except IndexError:
            preds.append(-1)  # no valid category found in the response
    else:
        # match label names verbatim against the response text
        for lb, lb_id in class_map.items():
            if lb in gpt_res:
                pred_y.append(lb_id)
        try:
            preds.append(pred_y[0])
        except IndexError:
            preds.append(-1)
    trues.append(true_y.item())
    correct = correct + 1 if true_y in pred_y[:topk] else correct  # top-3 hit

acc = correct / total

print("Accuracy:", acc)

report = classification_report(trues, preds, digits=6)
print(report)
```
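To illustrate the arXiv branch above: the `cs\.[A-Z]{2}` pattern pulls every two-letter CS category out of the model's free-text answer, de-duplicates the matches, and ranks them by first occurrence, so the earliest mention becomes the top-1 prediction. A standalone sketch (the response string is made up):

```python
import re

pattern = r"cs\.[A-Z]{2}"
gpt_res = "The paper is most likely cs.LG, though cs.AI and cs.LG are both plausible."

matches = list(set(re.findall(pattern, gpt_res)))           # unique categories
sorted_matches = sorted(matches, key=lambda x: gpt_res.index(x))
print(sorted_matches)  # ['cs.LG', 'cs.AI'] -- the first mention wins top-1
```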
examples/graphgpt/graphgpt_trainer.py

Lines changed: 232 additions & 0 deletions
@@ -0,0 +1,232 @@
```python
import argparse
import copy
import json
import os
import os.path as osp

import ray
import torch
from torch_geometric.data import Data
from tqdm import tqdm
from transformers import AutoTokenizer

from gammagl.utils.conversation import conv_templates, SeparatorStyle
from gammagl.utils.gfm_utils import disable_torch_init, KeywordsStoppingCriteria
from gammagl.utils.gfm_utils import DEFAULT_G_END_TOKEN, DEFAULT_G_START_TOKEN, DEFAULT_GRAPH_PATCH_TOKEN, DEFAULT_GRAPH_TOKEN, GRAPH_TOKEN_INDEX
from gammagl.models.graphgpt import *

os.environ['TL_BACKEND'] = 'torch'


def load_graph(instruct_item, graph_data_path):
    # rebuild the sampled subgraph around the target node from the instruction record
    graph_data_all = torch.load(graph_data_path)
    graph_dict = instruct_item['graph']
    graph_edge_index = torch.Tensor(copy.deepcopy(graph_dict['edge_index'])).long()
    graph_node_list = copy.deepcopy(graph_dict['node_list'])
    target_node = copy.deepcopy(graph_dict['node_idx'])
    graph_type = copy.deepcopy(instruct_item['id']).split('_')[0]
    graph_node_rep = graph_data_all[graph_type].x[graph_node_list]

    cur_token_len = len(graph_node_rep)  # one graph patch token per node in the subgraph

    graph_ret = Data(graph_node=graph_node_rep, edge_index=graph_edge_index, target_node=torch.tensor([target_node]))

    return {
        'graph_data': graph_ret,
        'graph_token_len': cur_token_len
    }


def load_prompting_file(file_path):
    with open(file_path, 'r') as f:
        data = json.load(f)
    return data


def run_eval(args, num_gpus):
    # split the instruction file into num_gpus chunks, one ray worker per GPU
    prompt_file = load_prompting_file(args.prompting_file)
    args.end_id = min(args.end_id, len(prompt_file))
    prompt_file = prompt_file[args.start_id:args.end_id]
    chunk_size = len(prompt_file) // num_gpus  # assumes len(prompt_file) >= num_gpus
    ans_handles = []
    split_list = list(range(args.start_id, args.end_id, chunk_size))
    idx_list = list(range(0, len(prompt_file), chunk_size))
    if len(split_list) == num_gpus:
        split_list.append(args.end_id)
        idx_list.append(len(prompt_file))
    elif len(split_list) == num_gpus + 1:
        split_list[-1] = args.end_id
        idx_list[-1] = len(prompt_file)
    else:
        raise ValueError('error in the number of chunks')

    if not osp.exists(args.output_res_path):
        os.mkdir(args.output_res_path)

    for idx in range(len(idx_list) - 1):
        start_idx = idx_list[idx]
        end_idx = idx_list[idx + 1]

        start_split = split_list[idx]
        end_split = split_list[idx + 1]
        ans_handles.append(
            eval_model.remote(
                args, prompt_file[start_idx:end_idx], start_split, end_split
            )
        )

    ans_jsons = []
    for ans_handle in ans_handles:
        ans_jsons.extend(ray.get(ans_handle))


@ray.remote(num_gpus=1)
@torch.inference_mode()
def eval_model(args, prompt_file, start_idx, end_idx):
    # Model
    disable_torch_init()
    print('start loading tokenizer')
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    print('finish loading tokenizer')

    print('start loading model')
    model = GraphLlamaForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16, use_cache=True, low_cpu_mem_usage=True).cuda()
    print('finish loading model')

    use_graph_start_end = getattr(model.config, "use_graph_start_end", False)
    tokenizer.add_tokens([DEFAULT_GRAPH_PATCH_TOKEN], special_tokens=True)
    if use_graph_start_end:
        tokenizer.add_tokens([DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN], special_tokens=True)

    # load the pretrained graph encoder and plug it into the language model
    clip_graph, args_graph = load_model_pretrained(CLIP, model.config.pretrain_graph_model_path)
    graph_tower = graph_transformer(args_graph)
    graph_tower = transfer_param_tograph(clip_graph, graph_tower)

    model.get_model().graph_tower = graph_tower.cuda()
    graph_tower.to(device='cuda', dtype=torch.float16)
    graph_config = graph_tower.config
    graph_config.graph_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_GRAPH_PATCH_TOKEN])[0]
    graph_config.use_graph_start_end = use_graph_start_end
    if use_graph_start_end:
        graph_config.graph_start_token, graph_config.graph_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN])

    res_data = []
    print(f'total: {len(prompt_file)}')
    for idx, instruct_item in tqdm(enumerate(prompt_file)):
        graph_dict = load_graph(instruct_item, args.graph_data_path)
        graph_token_len = graph_dict['graph_token_len']
        graph_data = graph_dict['graph_data']

        qs = instruct_item["conversations"][0]["value"]

        # replace the single graph placeholder with one patch token per node,
        # wrapped in the graph start/end tokens
        replace_token = DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len
        replace_token = DEFAULT_G_START_TOKEN + replace_token + DEFAULT_G_END_TOKEN
        qs = qs.replace(DEFAULT_GRAPH_TOKEN, replace_token)

        conv_mode = "graphchat_v1"

        if args.conv_mode is not None and conv_mode != args.conv_mode:
            print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
        else:
            args.conv_mode = conv_mode

        conv = conv_templates[args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        inputs = tokenizer([prompt])

        input_ids = torch.as_tensor(inputs.input_ids).cuda()

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        graph_data.graph_node = graph_data.graph_node.to(torch.float16)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                graph_data=graph_data.cuda(),
                do_sample=True,
                temperature=0.2,
                max_new_tokens=1024,
                stopping_criteria=[stopping_criteria])

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()

        res_data.append({"id": instruct_item["id"], "node_idx": instruct_item["graph"]["node_idx"], "res": outputs}.copy())
        # rewrite the shard after every item so partial results survive interruptions
        with open(osp.join(args.output_res_path, 'arxiv_test_res_{}_{}.json'.format(start_idx, end_idx)), "w") as fout:
            json.dump(res_data, fout, indent=4)
    return res_data


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    parser.add_argument("--prompting_file", type=str, default=None)
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--graph_data_path", type=str, default=None)

    parser.add_argument("--output_res_path", type=str, default=None)
    parser.add_argument("--num_gpus", type=int, default=4)

    parser.add_argument("--start_id", type=int, default=0)
    parser.add_argument("--end_id", type=int, default=20567)

    args = parser.parse_args()

    ray.init()
    run_eval(args, args.num_gpus)

# protobuf 4.22.3
```
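For context on the input format: `load_graph` reconstructs a node-centred subgraph from each instruction record, and `eval_model` swaps the graph placeholder in the first conversation turn for `graph_token_len` patch tokens. A minimal sketch of the fields this script reads, with made-up values (and assuming `<graph>` is the literal behind DEFAULT_GRAPH_TOKEN):

```python
# Hypothetical instruction record; only the fields accessed by this script are shown.
instruct_item = {
    "id": "cora_test_0",                    # prefix before '_' picks the graph in graph_data_all
    "graph": {
        "node_idx": 42,                     # target node id in the full graph
        "node_list": [42, 7, 19, 88],       # subgraph nodes; one graph token per entry
        "edge_index": [[0, 0, 1], [1, 2, 3]],  # edges in local node_list indices
    },
    "conversations": [
        {"from": "human",
         "value": "Given a citation graph: <graph> ... Which category does the paper belong to?"},
    ],
}
```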
