diff --git a/src/conversation_terminator/remote/model.py b/src/conversation_terminator/remote/model.py index 4a292b5..31dd1b4 100644 --- a/src/conversation_terminator/remote/model.py +++ b/src/conversation_terminator/remote/model.py @@ -14,6 +14,10 @@ def __new__(cls, context): return cls.instance async def inference(self, request: ModelRequest): + if request.model != 'NA': + model_name = str(request.model) + self.tokenizer = BertTokenizer.from_pretrained(model_name) + self.model = TFBertForSequenceClassification.from_pretrained(model_name) inputs = self.tokenizer(request.text,return_tensors="np", padding=True) outputs = self.model(inputs.input_ids, inputs.attention_mask) probabilities = tf.nn.sigmoid(outputs.logits) diff --git a/src/conversation_terminator/remote/request.py b/src/conversation_terminator/remote/request.py index d010f95..999c5ce 100644 --- a/src/conversation_terminator/remote/request.py +++ b/src/conversation_terminator/remote/request.py @@ -2,8 +2,9 @@ class ModelRequest(): - def __init__(self, text): + def __init__(self, text, model='NA'): self.text = text + self.model = model def to_json(self): return json.dumps(self, default=lambda o: o.__dict__, diff --git a/src/text_classification/grievance_recognition/local/model.py b/src/text_classification/grievance_recognition/local/model.py index 42e542e..13a1d2f 100644 --- a/src/text_classification/grievance_recognition/local/model.py +++ b/src/text_classification/grievance_recognition/local/model.py @@ -16,6 +16,11 @@ def __new__(cls, context): async def inference(self, request: ModelRequest): + if request.model != 'NA': + model_name = request.model + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = AutoModelForSequenceClassification.from_pretrained(model_name) + inputs = self.tokenizer(request.text, return_tensors="pt") inputs = {key: value.to(self.device) for key, value in inputs.items()} with torch.no_grad(): diff --git a/src/text_classification/grievance_recognition/local/request.py b/src/text_classification/grievance_recognition/local/request.py index 918b8c2..4144f18 100644 --- a/src/text_classification/grievance_recognition/local/request.py +++ b/src/text_classification/grievance_recognition/local/request.py @@ -3,8 +3,9 @@ class ModelRequest(): - def __init__(self, text): + def __init__(self, text, model='NA'): self.text = text + self.model = model def to_json(self): return json.dumps(self, default=lambda o: o.__dict__,