Commit 85be9b5

Computing neuro cognates on GPU

1 parent 13c70bc commit 85be9b5

3 files changed: +55 -29 lines changed

lingvodoc/schema/gql_cognate.py

Lines changed: 1 addition & 1 deletion
@@ -5706,7 +5706,7 @@ def neuro_cognate_statistics(
         if not input_len or not compare_len:
             triumph = False
             message = "No input words or words to compare is received"
-        elif compare_len > 10 ** 4:
+        elif compare_len > 30000:
             triumph = False
             message = f"Too many words to compare: {compare_len}"
         else:
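
Note: with scoring moved from local CPU processes to a GPU-backed Triton server (see lingvodoc/utils/neuro_cognates/app.py below), the hard cap on the comparison set is raised from 10 ** 4 to 30000 words; the validation logic is otherwise unchanged.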

lingvodoc/utils/neuro_cognates/app.py

Lines changed: 53 additions & 28 deletions
@@ -7,6 +7,8 @@
 from time import time as now
 from lingvodoc.queue.celery import celery
 from lingvodoc.cache.caching import TaskStatus, initialize_cache
+import tritonclient.grpc as grpcclient
+import numpy as np
 
 
 # Choose model architecture
@@ -52,7 +54,7 @@ def __init__(self, vocab_size, embed_dim, max_len):
 
         self.alpha = nn.Parameter(torch.tensor(0.7))
         self.beta = nn.Parameter(torch.tensor(0.3))
-        self.match_coef = nn.Parameter(torch.tensor(0.5))
+        self.match_coef = nn.Parameter(torch.tensor(0.8), requires_grad=False)
 
         self.classifier = nn.Sequential(
             nn.Linear(4 * 128, 256),
@@ -128,6 +130,8 @@ def process_batch(args):
         base_word_tensor = self._process_text(input_word)
         base_tran_tensor = self._process_text(input_tran)
 
+        triton_client = grpcclient.InferenceServerClient(url="10.100.194.95:8001")
+
         for i, compare_list in enumerate(self.compare_lists):
             if not compare_list:
                 continue
@@ -144,11 +148,17 @@ def process_batch(args):
                 'trans2': torch.stack([self._process_text(t) for t in compare_trans])
             }
 
+            inputs = []
+
+            for field, tensor in batch.items():
+                inputs.append(grpcclient.InferInput(field, [batch_size, self.model.max_len], "INT32"))
+                inputs[-1].set_data_from_numpy(np.array(tensor, dtype=np.int32))
+
             # Prediction
             with torch.no_grad():
-                outputs = self.model(**batch)
                 #probs = torch.sigmoid(outputs).squeeze()
-                probs = torch.sigmoid(outputs).cpu().numpy().flatten()
+                outputs = triton_client.infer("neuro_cognates", inputs)
+                probs = torch.sigmoid(torch.tensor([out[0] for out in outputs.as_numpy('output')])).cpu().numpy().flatten()
 
             for idx, prob in enumerate(probs):
                 if prob.item() > self.truth_threshold:
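
The hunk above swaps the local model(**batch) forward pass for a remote call to a Triton Inference Server over gRPC: each batch field becomes an INT32 InferInput, and the raw logits come back in the 'output' tensor. A minimal, self-contained sketch of that round-trip follows; the server address, model name and datatype mirror the diff, while the helper name and shape handling are illustrative assumptions:

import numpy as np
import tritonclient.grpc as grpcclient

def triton_logits(batch, url="10.100.194.95:8001", model_name="neuro_cognates"):
    """batch: dict of int index matrices shaped [batch_size, max_len]."""
    client = grpcclient.InferenceServerClient(url=url)
    inputs = []
    for name, array in batch.items():
        arr = np.asarray(array, dtype=np.int32)
        inp = grpcclient.InferInput(name, list(arr.shape), "INT32")
        inp.set_data_from_numpy(arr)        # attach the tensor payload
        inputs.append(inp)
    result = client.infer(model_name, inputs)  # one gRPC round-trip per batch
    return result.as_numpy("output")        # raw logits; sigmoid stays client-side

Creating the client inside process_batch gives every worker its own connection; the commented-out self.triton_client in __init__ (next hunks) suggests a shared client was tried and dropped, plausibly because a gRPC channel cannot be pickled across the 'spawn' start method.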
@@ -198,15 +208,15 @@ def __init__(self,
         script_dir = os.path.dirname(script_path)
 
         # Load model
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        checkpoint = torch.load(os.path.join(script_dir, 'best_model.pth'), map_location=self.device)
+        #self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        checkpoint = torch.load(os.path.join(script_dir, 'best_model.pth'))  # map_location=self.device)
         config = checkpoint.get('config', {})
 
         self.model = DualPathSiamese(
             vocab_size=len(checkpoint['char_to_index']),
             embed_dim=config.get('embed_dim', 128),
             max_len=config.get('max_len', 43)
-        ).to(self.device)
+        )  #.to(self.device)
 
         self.model.load_state_dict(checkpoint['model_state_dict'])
         self.model.char_to_index = checkpoint['char_to_index']
@@ -218,12 +228,14 @@ def __init__(self,
         if 'match_coef' in config:
             self.model.match_coef.data.fill_(config['match_coef'])
 
+        #self.triton_client = grpcclient.InferenceServerClient(url="10.100.194.95:8001")
+
         self.model.eval()
 
     def _process_text(self, text):
         indices = [self.model.char_to_index.get(c, 1) for c in text.lower()[:self.model.max_len]]
         indices += [0] * (self.model.max_len - len(indices))
-        return torch.tensor(indices, dtype=torch.long, device=self.device)
+        return torch.tensor(indices, dtype=torch.long)  # device=self.device)
 
     @staticmethod
     def split_items(items, input_links=None):
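
With device placement now handled server-side, _process_text (kept as context above) returns a plain CPU tensor. A toy sketch of its fixed-length encoding; the vocabulary and max_len here are made up, while the padding index 0 and unknown-character index 1 come straight from the code:

import torch

char_to_index = {'a': 2, 'b': 3, 'c': 4}   # hypothetical vocabulary; 0 = pad, 1 = OOV
max_len = 8

def encode(text):
    # Lowercase, truncate to max_len, map unknown chars to 1, pad with 0
    indices = [char_to_index.get(c, 1) for c in text.lower()[:max_len]]
    indices += [0] * (max_len - len(indices))
    return torch.tensor(indices, dtype=torch.long)

print(encode("abXc"))  # tensor([2, 3, 1, 4, 0, 0, 0, 0])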
@@ -244,6 +256,9 @@ def predict_cognates(self, word_pairs, task):
         results = []
         group_count = 0
         current_stage = 0
+        progress = 0
+        status = ""
+        retry = 0
         result_link = ""
         input_len = len(word_pairs)
         compare_len = sum(map(len, self.compare_lists))
@@ -252,7 +267,7 @@ def predict_cognates(self, word_pairs, task):
         stamp_file = os.path.join(self.storage['path'], 'lingvodoc_stamps', str(task.id))
 
         def add_result(res):
-            nonlocal current_stage, result_link, group_count
+            nonlocal current_stage, result_link, group_count, progress, status, retry
             if res is None:
                 return
 
@@ -272,6 +287,8 @@ def add_result(res):
 
             progress = 100 if finished else int(current_stage / input_len * 100)
             status = "Finished" if finished else f"~ {days}d:{hours}h:{minutes}m left ~"
+            if retry:
+                status += f" ({retry} failed)"
 
             # Save results
             if current_stage % 10 == 0 or finished:
@@ -298,51 +315,59 @@ def add_result(res):
         args_list = zip([self] * input_len, input_words, input_trans, input_lex_ids, input_linked_groups)
 
         def f(proc):
-            task.set(None, 0, f"Using {proc} processes...")
+            nonlocal retry
+            task.set(None, 0, f"Using {proc} process(es)...")
             pool = Pool(proc)
             jobs = pool.imap_unordered(process_batch, args_list)
             pool.close()
 
-            try:
-                for _ in range(input_len):
-
+            for _ in range(input_len):
+                try:
                     if os.path.exists(stamp_file):
                         os.remove(stamp_file)
                         raise InterruptedError("Task stopped manually")
 
                     else:
-                        result = jobs.next(timeout=600)
+                        result = jobs.next(timeout=60)
                         add_result(result)
 
-            except RuntimeError:
-                msg = "No enough memory for the task"
+                except RuntimeError:
+                    msg = "No enough memory for the task"
+
+                    if proc > 1:
+                        task.set(None, -1, msg)
+                        pool.terminate()
+                        f(proc - 1)
+                        return
 
-                if proc > 1:
-                    task.set(None, -1, msg)
+                    else:
+                        raise InterruptedError(msg)
+
+                except InterruptedError as e:
+                    task.set(None, -1, str(e), result_link)
                     pool.terminate()
-                    f(proc - 1)
                     return
 
-                else:
-                    raise InterruptedError(msg)
-
-            except InterruptedError as e:
-                task.set(None, -1, str(e), result_link)
-                pool.terminate()
-                return
+                except Exception:
+                    if retry < 5:
+                        retry += 1
+                        continue
+                    else:
+                        task.set(None, -1, "Server is busy now. Try again later.", result_link)
+                        pool.terminate()
+                        return
 
         try:
             set_start_method('spawn')
-            f(os.cpu_count() // 2)
-            print("Completed pool")
+            f(1)
 
         except Exception as e:
            print(e)
 
         return results
 
     def index(self, word_pairs, task):
-        return NeuroCognates.predict_cognates.delay(
+        return self.predict_cognates.delay(
             self,
             word_pairs,
             task)
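
Worker orchestration changes to match: the pool now starts with a single process (f(1) instead of os.cpu_count() // 2, since the heavy lifting moved to the GPU server), the per-result timeout drops from 600 s to 60 s, and up to five arbitrary failures are tolerated before the task is abandoned, with the running count surfaced in the task status. A condensed sketch of that drain-with-retries pattern; the function and parameter names are illustrative, not the module's API:

def drain_with_retries(jobs, n_items, on_result, max_retries=5, timeout=60):
    """Pull up to n_items results from a Pool.imap_unordered iterator."""
    retry = 0
    for _ in range(n_items):
        try:
            on_result(jobs.next(timeout=timeout))
        except Exception:
            if retry < max_retries:
                retry += 1      # count the failure, keep draining
                continue
            raise RuntimeError("Server is busy now. Try again later.")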

server-requirements-1.txt

Lines changed: 1 addition & 0 deletions
@@ -104,6 +104,7 @@ torch==2.6.0
 tqdm==4.51.0
 transaction==1.6.1
 translationstring==1.3
+tritonclient[grpc]==2.51
 typepy==0.6.6
 typing==3.6.2
 uniparser-erzya<1.2
