from time import time as now
from lingvodoc.queue.celery import celery
from lingvodoc.cache.caching import TaskStatus, initialize_cache
+import tritonclient.grpc as grpcclient
+import numpy as np


# Choose model architecture
@@ -52,7 +54,7 @@ def __init__(self, vocab_size, embed_dim, max_len):

        self.alpha = nn.Parameter(torch.tensor(0.7))
        self.beta = nn.Parameter(torch.tensor(0.3))
-        self.match_coef = nn.Parameter(torch.tensor(0.5))
+        self.match_coef = nn.Parameter(torch.tensor(0.8), requires_grad=False)

        self.classifier = nn.Sequential(
            nn.Linear(4 * 128, 256),
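
Side note on the `match_coef` change above: constructing the parameter with `requires_grad=False` keeps it in the module's `state_dict()` (so existing checkpoints still load, and `match_coef.data.fill_(...)` still works) while excluding it from gradient updates. A minimal sketch of that behavior; the `Mixer` wrapper is hypothetical:

```python
import torch
import torch.nn as nn

class Mixer(nn.Module):  # hypothetical wrapper, for illustration only
    def __init__(self):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(0.7))                            # trainable
        self.match_coef = nn.Parameter(torch.tensor(0.8), requires_grad=False)  # frozen

m = Mixer()
print([n for n, p in m.named_parameters() if p.requires_grad])  # ['alpha']
# The frozen parameter still lives in state_dict(), so checkpoints keep it
# and match_coef.data.fill_(...) can still overwrite it after loading.
print('match_coef' in m.state_dict())  # True
```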
@@ -128,6 +130,8 @@ def process_batch(args):
    base_word_tensor = self._process_text(input_word)
    base_tran_tensor = self._process_text(input_tran)

+    triton_client = grpcclient.InferenceServerClient(url="10.100.194.95:8001")
+
    for i, compare_list in enumerate(self.compare_lists):
        if not compare_list:
            continue
@@ -144,11 +148,17 @@ def process_batch(args):
            'trans2': torch.stack([self._process_text(t) for t in compare_trans])
        }

+        inputs = []
+
+        for field, tensor in batch.items():
+            inputs.append(grpcclient.InferInput(field, [batch_size, self.model.max_len], "INT32"))
+            inputs[-1].set_data_from_numpy(np.array(tensor, dtype=np.int32))
+
        # Prediction
        with torch.no_grad():
-            outputs = self.model(**batch)
            #probs = torch.sigmoid(outputs).squeeze()
-            probs = torch.sigmoid(outputs).cpu().numpy().flatten()
+            outputs = triton_client.infer("neuro_cognates", inputs)
+            probs = torch.sigmoid(torch.tensor([out[0] for out in outputs.as_numpy('output')])).cpu().numpy().flatten()

        for idx, prob in enumerate(probs):
            if prob.item() > self.truth_threshold:
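
The hunk above replaces the local `self.model(**batch)` forward pass with a remote call to a Triton Inference Server; note the gRPC client is created inside `process_batch` rather than shared from `__init__`, since a client opened in the parent process cannot be pickled into spawned workers. Below is a minimal, self-contained sketch of that round trip, assuming a deployed model named `neuro_cognates` with four INT32 inputs of shape `[batch, max_len]` and one output tensor named `output`; the input names and dummy data are illustrative:

```python
import numpy as np
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient(url="10.100.194.95:8001")

batch_size, max_len = 8, 43

# One InferInput per model input; names must match the Triton model config.
inputs = []
for name in ("words1", "trans1", "words2", "trans2"):  # assumed field names
    inp = grpcclient.InferInput(name, [batch_size, max_len], "INT32")
    inp.set_data_from_numpy(np.zeros((batch_size, max_len), dtype=np.int32))
    inputs.append(inp)

# Blocking inference call; raises InferenceServerException on server errors.
result = client.infer("neuro_cognates", inputs)
logits = result.as_numpy("output")  # raw scores to be passed through sigmoid
print(logits.shape)
```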
@@ -198,15 +208,15 @@ def __init__(self,
        script_dir = os.path.dirname(script_path)

        # Load model
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        checkpoint = torch.load(os.path.join(script_dir, 'best_model.pth'), map_location=self.device)
+        # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        checkpoint = torch.load(os.path.join(script_dir, 'best_model.pth'))  # map_location=self.device)
        config = checkpoint.get('config', {})

        self.model = DualPathSiamese(
            vocab_size=len(checkpoint['char_to_index']),
            embed_dim=config.get('embed_dim', 128),
            max_len=config.get('max_len', 43)
-        ).to(self.device)
+        )  # .to(self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.char_to_index = checkpoint['char_to_index']
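
One caveat with dropping `map_location` above: by default `torch.load` restores tensors to the device they were saved from, so a checkpoint written on a GPU machine will fail to load on a CPU-only host. Since the local model now only supplies `char_to_index` and the config while inference runs on Triton, pinning the load to CPU would be the defensive choice; a one-line sketch:

```python
import torch

# Without map_location, CUDA-saved tensors try to deserialize onto CUDA;
# forcing CPU keeps the load working on hosts without a GPU.
checkpoint = torch.load('best_model.pth', map_location='cpu')
```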
@@ -218,12 +228,14 @@ def __init__(self,
        if 'match_coef' in config:
            self.model.match_coef.data.fill_(config['match_coef'])

+        # self.triton_client = grpcclient.InferenceServerClient(url="10.100.194.95:8001")
+
        self.model.eval()

    def _process_text(self, text):
        indices = [self.model.char_to_index.get(c, 1) for c in text.lower()[:self.model.max_len]]
        indices += [0] * (self.model.max_len - len(indices))
-        return torch.tensor(indices, dtype=torch.long, device=self.device)
+        return torch.tensor(indices, dtype=torch.long)  # device=self.device)

    @staticmethod
    def split_items(items, input_links=None):
@@ -244,6 +256,9 @@ def predict_cognates(self, word_pairs, task):
        results = []
        group_count = 0
        current_stage = 0
+        progress = 0
+        status = ""
+        retry = 0
        result_link = ""
        input_len = len(word_pairs)
        compare_len = sum(map(len, self.compare_lists))
@@ -252,7 +267,7 @@ def predict_cognates(self, word_pairs, task):
        stamp_file = os.path.join(self.storage['path'], 'lingvodoc_stamps', str(task.id))

        def add_result(res):
-            nonlocal current_stage, result_link, group_count
+            nonlocal current_stage, result_link, group_count, progress, status, retry
            if res is None:
                return
258273
@@ -272,6 +287,8 @@ def add_result(res):
272287
273288 progress = 100 if finished else int (current_stage / input_len * 100 )
274289 status = "Finished" if finished else f"~ { days } d:{ hours } h:{ minutes } m left ~"
290+ if retry :
291+ status += f" ({ retry } failed)"
275292
276293 # Save results
277294 if current_stage % 10 == 0 or finished :
@@ -298,51 +315,59 @@ def add_result(res):
298315 args_list = zip ([self ] * input_len , input_words , input_trans , input_lex_ids , input_linked_groups )
299316
300317 def f (proc ):
301- task .set (None , 0 , f"Using { proc } processes..." )
318+ nonlocal retry
319+ task .set (None , 0 , f"Using { proc } process(es)..." )
302320 pool = Pool (proc )
303321 jobs = pool .imap_unordered (process_batch , args_list )
304322 pool .close ()
305323
306- try :
307- for _ in range (input_len ):
308-
324+ for _ in range (input_len ):
325+ try :
309326 if os .path .exists (stamp_file ):
310327 os .remove (stamp_file )
311328 raise InterruptedError ("Task stopped manually" )
312329
313330 else :
314- result = jobs .next (timeout = 600 )
331+ result = jobs .next (timeout = 60 )
315332 add_result (result )
316333
317- except RuntimeError :
318- msg = "No enough memory for the task"
334+ except RuntimeError :
335+ msg = "No enough memory for the task"
336+
337+ if proc > 1 :
338+ task .set (None , - 1 , msg )
339+ pool .terminate ()
340+ f (proc - 1 )
341+ return
319342
320- if proc > 1 :
321- task .set (None , - 1 , msg )
343+ else :
344+ raise InterruptedError (msg )
345+
346+ except InterruptedError as e :
347+ task .set (None , - 1 , str (e ), result_link )
322348 pool .terminate ()
323- f (proc - 1 )
324349 return
325350
326- else :
327- raise InterruptedError (msg )
328-
329- except InterruptedError as e :
330- task .set (None , - 1 , str (e ), result_link )
331- pool .terminate ()
332- return
351+ except Exception :
352+ if retry < 5 :
353+ retry += 1
354+ continue
355+ else :
356+ task .set (None , - 1 , "Server is busy now. Try again later." , result_link )
357+ pool .terminate ()
358+ return
333359
334360 try :
335361 set_start_method ('spawn' )
336- f (os .cpu_count () // 2 )
337- print ("Completed pool" )
362+ f (1 )
338363
339364 except Exception as e :
340365 print (e )
341366
342367 return results
343368
344369 def index (self , word_pairs , task ):
345- return NeuroCognates .predict_cognates .delay (
370+ return self .predict_cognates .delay (
346371 self ,
347372 word_pairs ,
348373 task )
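
The restructured `f()` moves the `try` inside the fetch loop, so a single failed batch no longer aborts the whole task: `RuntimeError` still falls back to a smaller pool, a manual stop still cancels, and any other exception (e.g. a gRPC timeout from the Triton call) is retried up to five times before the task gives up. A stripped-down sketch of the pattern with a stand-in worker; note that, as in the diff, each retry consumes one loop iteration:

```python
from multiprocessing import Pool

def worker(x):  # stand-in for process_batch
    return x * x

def run(items, max_retries=5):
    pool = Pool(1)
    jobs = pool.imap_unordered(worker, items)
    pool.close()

    retry, results = 0, []
    for _ in range(len(items)):
        try:
            # next() keeps waiting for the same pending result, so a timeout
            # gives that batch another chance rather than skipping it
            results.append(jobs.next(timeout=60))
        except Exception:
            if retry < max_retries:
                retry += 1
                continue
            pool.terminate()
            break
    return results

if __name__ == '__main__':
    print(run(range(4)))
```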