import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os

# Select the tensorlayerx backend before importing any gammagl module.
os.environ['TL_BACKEND'] = 'torch'

from gammagl.utils.conversation import conv_templates, SeparatorStyle
from gammagl.utils.gfm_utils import disable_torch_init, KeywordsStoppingCriteria
from gammagl.utils.gfm_utils import DEFAULT_G_END_TOKEN, DEFAULT_G_START_TOKEN, DEFAULT_GRAPH_PATCH_TOKEN, DEFAULT_GRAPH_TOKEN, GRAPH_TOKEN_INDEX
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from gammagl.models.graphgpt import *

from torch_geometric.data import Data
import json
import copy
from tqdm import tqdm
import os.path as osp
import ray


def load_graph(instruct_item, graph_data_path):
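    """Build a PyG ``Data`` object for one instruction item.

    Node features are looked up in the pre-computed graph stored at
    ``graph_data_path``; the returned token length equals the number of
    nodes in the sampled subgraph.
    """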
    graph_data_all = torch.load(graph_data_path)
    graph_dict = instruct_item['graph']
    graph_edge_index = torch.Tensor(copy.deepcopy(graph_dict['edge_index'])).long()
    graph_node_list = copy.deepcopy(graph_dict['node_list'])
    target_node = copy.deepcopy(graph_dict['node_idx'])
    graph_type = copy.deepcopy(instruct_item['id']).split('_')[0]
    graph_node_rep = graph_data_all[graph_type].x[graph_node_list]

    # one graph token per node in the sampled subgraph
    cur_token_len = len(graph_node_rep)

    graph_ret = Data(graph_node=graph_node_rep, edge_index=graph_edge_index, target_node=torch.tensor([target_node]))

    return {
        'graph_data': graph_ret,
        'graph_token_len': cur_token_len
    }


def load_prompting_file(file_path):
    with open(file_path, 'r') as f:
        data = json.load(f)
    return data

# def prepare_query(instruct_item):


def run_eval(args, num_gpus):
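    """Split the prompting file into per-GPU chunks and dispatch one ray task each.

    Every chunk is evaluated by ``eval_model`` on its own GPU; the per-chunk
    results are gathered at the end.
    """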
    # split the prompting file into num_gpus chunks
    prompt_file = load_prompting_file(args.prompting_file)
    args.end_id = min(args.end_id, len(prompt_file))
    prompt_file = prompt_file[args.start_id:args.end_id]
    chunk_size = len(prompt_file) // num_gpus
    ans_handles = []
    split_list = list(range(args.start_id, args.end_id, chunk_size))
    idx_list = list(range(0, len(prompt_file), chunk_size))
    if len(split_list) == num_gpus:
        split_list.append(args.end_id)
        idx_list.append(len(prompt_file))
    elif len(split_list) == num_gpus + 1:
        split_list[-1] = args.end_id
        idx_list[-1] = len(prompt_file)
    else:
        raise ValueError('unexpected number of chunks when splitting the prompting file')

    if not osp.exists(args.output_res_path):
        os.mkdir(args.output_res_path)

    for idx in range(len(idx_list) - 1):
        start_idx = idx_list[idx]
        end_idx = idx_list[idx + 1]

        start_split = split_list[idx]
        end_split = split_list[idx + 1]
        ans_handles.append(
            eval_model.remote(
                args, prompt_file[start_idx:end_idx], start_split, end_split
            )
        )

    ans_jsons = []
    for ans_handle in ans_handles:
        ans_jsons.extend(ray.get(ans_handle))

    # with open(args.output_res_path, "w") as ans_file:
    #     for line in ans_jsons:
    #         ans_file.write(json.dumps(line) + "\n")


@ray.remote(num_gpus=1)
@torch.inference_mode()
def eval_model(args, prompt_file, start_idx, end_idx):
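    """Run generation for one chunk of the prompting file on a single GPU.

    Loads the tokenizer, the GraphLlama model and its graph tower, builds the
    conversation prompt for every instruction item, generates a response, and
    writes the accumulated results under ``args.output_res_path``.
    """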
    # load prompting file
    # prompt_file = load_prompting_file(args.prompting_file)

    # Model
    disable_torch_init()
    # model_name = os.path.expanduser(args.model_name)
    print('start loading tokenizer')
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    print('finish loading tokenizer')

    print('start loading model')
    model = GraphLlamaForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16, use_cache=True, low_cpu_mem_usage=True).cuda()
    print('finish loading model')

    use_graph_start_end = getattr(model.config, "use_graph_start_end", False)
    tokenizer.add_tokens([DEFAULT_GRAPH_PATCH_TOKEN], special_tokens=True)
    if use_graph_start_end:
        tokenizer.add_tokens([DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN], special_tokens=True)

    graph_tower = model.get_model().graph_tower

    # rebuild the graph tower from the pretrained graph encoder and move it to GPU
    # if graph_tower.device.type == 'meta':
    #     print('meta')
    clip_graph, args_graph = load_model_pretrained(CLIP, model.config.pretrain_graph_model_path)
    graph_tower = graph_transformer(args_graph)
    graph_tower = transfer_param_tograph(clip_graph, graph_tower)

    model.get_model().graph_tower = graph_tower.cuda()
    # else:
    #     print('other')
    # print(next(graph_tower.parameters()).dtype)
    graph_tower.to(device='cuda', dtype=torch.float16)
    graph_config = graph_tower.config
    graph_config.graph_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_GRAPH_PATCH_TOKEN])[0]
    graph_config.use_graph_start_end = use_graph_start_end
    if use_graph_start_end:
        graph_config.graph_start_token, graph_config.graph_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_G_START_TOKEN, DEFAULT_G_END_TOKEN])
    # TODO: add graph token len
    res_data = []
    print(f'total: {len(prompt_file)}')
    for idx, instruct_item in tqdm(enumerate(prompt_file), total=len(prompt_file)):
        # instruct_item = prompt_file[0]
        # if idx >= 3:
        #     break
        graph_dict = load_graph(instruct_item, args.graph_data_path)
        graph_token_len = graph_dict['graph_token_len']
        graph_data = graph_dict['graph_data']

        qs = instruct_item["conversations"][0]["value"]
        # if use_graph_start_end:
        #     qs = qs + '\n' + DEFAULT_G_START_TOKEN + DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len + DEFAULT_G_END_TOKEN
        # else:
        #     qs = qs + '\n' + DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len

        # expand the single graph placeholder into one patch token per node
        replace_token = DEFAULT_GRAPH_PATCH_TOKEN * graph_token_len
        replace_token = DEFAULT_G_START_TOKEN + replace_token + DEFAULT_G_END_TOKEN
        qs = qs.replace(DEFAULT_GRAPH_TOKEN, replace_token)

        # if "v1" in args.model_name.lower():
        #     conv_mode = "graphchat_v1"
        # else:
        #     raise ValueError('Don\'t support this model')
        conv_mode = "graphchat_v1"

        if args.conv_mode is not None and conv_mode != args.conv_mode:
            print('[WARNING] the auto-inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
        else:
            args.conv_mode = conv_mode

        conv = conv_templates[args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        inputs = tokenizer([prompt])

        input_ids = torch.as_tensor(inputs.input_ids).cuda()

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        graph_data.graph_node = graph_data.graph_node.to(torch.float16)
        # graph_data.edge_index = graph_data.edge_index.to(torch.float16)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                graph_data=graph_data.cuda(),
                do_sample=True,
                temperature=0.2,
                max_new_tokens=1024,
                stopping_criteria=[stopping_criteria])

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()
        # print(outputs)

        res_data.append({"id": instruct_item["id"], "node_idx": instruct_item["graph"]["node_idx"], "res": outputs}.copy())
        with open(osp.join(args.output_res_path, 'arxiv_test_res_{}_{}.json'.format(start_idx, end_idx)), "w") as fout:
            json.dump(res_data, fout, indent=4)
    return res_data
    # with open(args.output_res_path, "w") as fout:
    #     json.dump(res_data, fout, indent=4)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    # parser.add_argument("--image-file", type=str, required=True)
    # parser.add_argument("--query", type=str, required=True)
    parser.add_argument("--prompting_file", type=str, default=None)
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--graph_data_path", type=str, default=None)

    parser.add_argument("--output_res_path", type=str, default=None)
    parser.add_argument("--num_gpus", type=int, default=4)

    parser.add_argument("--start_id", type=int, default=0)
    parser.add_argument("--end_id", type=int, default=20567)

    args = parser.parse_args()

    # eval_model(args)

    ray.init()
    run_eval(args, args.num_gpus)


# protobuf 4.22.3
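# Example launch (the script name and all paths below are illustrative, not fixed by this repo):
#   python eval_graphgpt.py \
#       --model-name /path/to/graphgpt-checkpoint \
#       --prompting_file ./data/eval/arxiv_test_instruct.json \
#       --graph_data_path ./data/graph_data_all.pt \
#       --output_res_path ./output_res \
#       --num_gpus 4 --start_id 0 --end_id 20567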