import torch
from request import ModelRequest
from InstructorEmbedding import INSTRUCTOR

class Model:
    # Singleton: load the INSTRUCTOR checkpoint once and reuse it across
    # requests instead of re-initialising the model on every call.
    def __new__(cls, context):
        cls.context = context
        if not hasattr(cls, 'instance'):
            cls.instance = super().__new__(cls)
            model_name = "hkunlp/instructor-large"
            cls.model = INSTRUCTOR(model_name)
        return cls.instance

    async def inference(self, request: ModelRequest):
        # Modify this function according to model requirements, keeping the
        # inputs and outputs the same.
        corpus_instruction = "Represent the Wikipedia document for retrieval:"
        query_instruction = "Represent the Wikipedia question for retrieving supporting documents: "
        query = request.query

        if query is not None:
            # Encode the query as an [instruction, text] pair, the input
            # format INSTRUCTOR models expect.
            query_embeddings = self.model.encode(
                [[query_instruction, query]],
                show_progress_bar=False,
                batch_size=32,
                device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
            )
            return query_embeddings.tolist()

        if not request.df.empty:
            data = request.df

            # Encode every document in the 'content' column with the corpus
            # instruction, then attach the embeddings as a new column.
            text_corpus = data.loc[:, 'content'].to_list()
            corpus_embeddings = self.model.encode(
                [[corpus_instruction, text] for text in text_corpus],
                show_progress_bar=False,
                batch_size=32,
                device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
            )
            data['embeddings'] = corpus_embeddings.tolist()
            # to_csv(index=False) already returns a string; returning inside
            # this branch avoids a NameError when the DataFrame is empty.
            return data.to_csv(index=False)
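
# A minimal usage sketch, assuming the surrounding service constructs a
# ModelRequest carrying either a `query` string or a `df` DataFrame with a
# 'content' column (the exact ModelRequest constructor arguments are an
# assumption based on how the fields are used above):
#
#   import asyncio
#
#   model = Model(context)  # singleton: the checkpoint loads only once
#   req = ModelRequest(query="Who drafted the Indian constitution?")
#   query_embedding = asyncio.run(model.inference(req))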