From 04680d892afbb494d47f464ff7b5016297a2d44a Mon Sep 17 00:00:00 2001
From: Manas Sivakumar
Date: Fri, 28 Jul 2023 10:30:18 +0530
Subject: [PATCH] Dynamic Model Loading

---
 src/conversation_terminator/remote/model.py                  | 4 ++++
 src/conversation_terminator/remote/request.py                | 3 ++-
 src/text_classification/grievance_recognition/local/model.py | 5 +++++
 .../grievance_recognition/local/request.py                   | 3 ++-
 4 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/src/conversation_terminator/remote/model.py b/src/conversation_terminator/remote/model.py
index 4a292b5..31dd1b4 100644
--- a/src/conversation_terminator/remote/model.py
+++ b/src/conversation_terminator/remote/model.py
@@ -14,6 +14,10 @@ def __new__(cls, context):
         return cls.instance
 
     async def inference(self, request: ModelRequest):
+        if request.model != 'NA':
+            model_name = str(request.model)
+            self.tokenizer = BertTokenizer.from_pretrained(model_name)
+            self.model = TFBertForSequenceClassification.from_pretrained(model_name)
         inputs = self.tokenizer(request.text,return_tensors="np", padding=True)
         outputs = self.model(inputs.input_ids, inputs.attention_mask)
         probabilities = tf.nn.sigmoid(outputs.logits)
diff --git a/src/conversation_terminator/remote/request.py b/src/conversation_terminator/remote/request.py
index d010f95..999c5ce 100644
--- a/src/conversation_terminator/remote/request.py
+++ b/src/conversation_terminator/remote/request.py
@@ -2,8 +2,9 @@
 
 
 class ModelRequest():
-    def __init__(self, text):
+    def __init__(self, text, model='NA'):
         self.text = text
+        self.model = model
 
     def to_json(self):
         return json.dumps(self, default=lambda o: o.__dict__,
diff --git a/src/text_classification/grievance_recognition/local/model.py b/src/text_classification/grievance_recognition/local/model.py
index 42e542e..13a1d2f 100644
--- a/src/text_classification/grievance_recognition/local/model.py
+++ b/src/text_classification/grievance_recognition/local/model.py
@@ -16,6 +16,11 @@ def __new__(cls, context):
 
 
     async def inference(self, request: ModelRequest):
+        if request.model != 'NA':
+            model_name = request.model
+            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+            self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
+
         inputs = self.tokenizer(request.text, return_tensors="pt")
         inputs = {key: value.to(self.device) for key, value in inputs.items()}
         with torch.no_grad():
diff --git a/src/text_classification/grievance_recognition/local/request.py b/src/text_classification/grievance_recognition/local/request.py
index 918b8c2..4144f18 100644
--- a/src/text_classification/grievance_recognition/local/request.py
+++ b/src/text_classification/grievance_recognition/local/request.py
@@ -3,8 +3,9 @@
 
 
 class ModelRequest():
-    def __init__(self, text):
+    def __init__(self, text, model='NA'):
         self.text = text
+        self.model = model
 
     def to_json(self):
         return json.dumps(self, default=lambda o: o.__dict__,
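
Usage sketch (kept outside the diff above): a minimal, self-contained example of how a caller would exercise the new optional `model` field on ModelRequest. The class body below mirrors the patched constructor only; the model id string is a hypothetical placeholder, and the comments describe the behaviour the patched inference() methods implement.

    import json

    # Mirrors the patched constructor signature: an optional `model`
    # argument defaulting to the sentinel value 'NA'.
    class ModelRequest():
        def __init__(self, text, model='NA'):
            self.text = text
            self.model = model

    # Default request: model stays 'NA', so inference() skips reloading and
    # keeps using the tokenizer/model loaded when the service started.
    default_req = ModelRequest(text="I want to raise a grievance")

    # Dynamic request: any value other than 'NA' is passed to from_pretrained(),
    # so the tokenizer and model are re-loaded from that name before inference.
    dynamic_req = ModelRequest(text="I want to raise a grievance",
                               model="bert-base-multilingual-cased")  # hypothetical model id

    print(json.dumps(default_req.__dict__, indent=2))
    print(json.dumps(dynamic_req.__dict__, indent=2))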