Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
"justMyCode": false,
"env": {
"PYTHONPATH": "${workspaceFolder}"
}
},
"args": "${command:pickArgs}"
},
{
"name": "ConvAssist",
Expand All @@ -41,7 +42,9 @@
"name": "Python Debugger: Python File",
"type": "debugpy",
"request": "launch",
"program": "${file}"
"program": "${file}",
"justMyCode": false,

},
{
"name": "Python: Debug Tests",
Expand Down
1 change: 0 additions & 1 deletion 3rd_party_resources/third_party_programs.md

This file was deleted.

8 changes: 7 additions & 1 deletion 3rd_party_resources/utils/database_generator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import os
from typing import List

from tqdm import tqdm
Expand Down Expand Up @@ -51,7 +52,6 @@ def configure():
#flag to clean the database

parser.add_argument(
"-c",
"--clean",
action="store_true",
help="Whether to clean the database"
Expand Down Expand Up @@ -89,14 +89,20 @@ def main(argv=None):
if response.lower() != 'y':
print("Exiting...")
return
else:
print("Cleaning database...")
os.remove(args.database)

phrases = []

with open(args.input_file) as f:
for line in f:
phrases.append(line.strip())


with NGramUtil(args.database, args.cardinality, args.lowercase, args.normalize) as ngramutil:
ngramutil.create_update_ngram_tables()

threads = []
for i in range(args.cardinality):
p = Thread(target=insertngrambycardinality, args=(ngramutil, phrases, i + 1))
Expand Down
5 changes: 5 additions & 0 deletions 3rd_party_resources/utils/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import pandas as pd

df = pd.read_parquet('train-00000-of-00001-aaf72b9960b78228.parquet')
df.to_csv('data.csv', index=False)

2 changes: 1 addition & 1 deletion convassist/ConvAssist.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def initialize(
self.set_predictors()

self.predictor_activator = PredictorActivator(
self.config, self.predictor_registry, self.context_tracker, self.logger
self.name, self.config, self.predictor_registry, self.context_tracker #, self.logger
)
self.predictor_activator.combination_policy = "meritocracy"

Expand Down
4 changes: 2 additions & 2 deletions convassist/predictor/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def __init__(
self._sentence_transformer_model: str = "" # Path
self._sentences_db: str = "" # Path
self._spellingdatabase: str = "" # Path
self._startsents: str = "" # Path
self._startwords: str = "" # Path
self._startsents: str = "start_sentences.txt" # Filename
self._startwords: str = "start_words.txt" # Filename
self._static_resources_path: str = ""
self._stopwords: str = "" # Path
self._test_generalsentenceprediction: bool = False
Expand Down
117 changes: 15 additions & 102 deletions convassist/predictor/sentence_completion_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,60 +12,20 @@
import numpy
import torch
import transformers
import tqdm
from nltk import word_tokenize
from nltk.stem.porter import PorterStemmer
from sentence_transformers import SentenceTransformer

from convassist.predictor.predictor import Predictor
from convassist.predictor.utilities.nlp import NLP
from convassist.predictor.utilities.prediction import Prediction, Suggestion
from convassist.utilities.databaseutils.sqllite_dbconnector import (
SQLiteDatabaseConnector,
)
from convassist.predictor.utilities.svo_util import SVOUtil


class SentenceCompletionPredictor(Predictor):
"""
SentenceCompletionPredictor is a class that provides functionality for sentence completion prediction using a pre-trained language model and a corpus of sentences.

Methods:
configure(self):
Configures the predictor by loading the necessary models, embeddings, and indexes.
load_model(self) -> None:
Loads the pre-trained language model for sentence generation.
retrieve(self):
Property to get the retrieve attribute.
retrieve(self, value):
Property to set the retrieve attribute.
_set_seed(self, seed):
Sets the random seed for reproducibility.
_read_personalized_toxic_words(self):
Reads personalized allowed toxic words from a file.
_extract_svo(self, sent):
Extracts subject-verb-object (SVO) from a given sentence.
_ngram_to_string(self, ngram):
Converts an n-gram to a string.
_filter_text(self, text):
Filters the text to check for blacklisted words.
_textInCorpus(self, text):
Checks if the given text is in the corpus and returns the similarity score.
_retrieve_fromDataset(self, context):
Retrieves sentences from the dataset that match the given context.
_checkRepetition(self, text):
Checks for repetitive bigrams and trigrams in the text.
_generate(self, context: str, num_gen: int) -> Prediction:
Generates sentence completions for the given context.
model_loaded(self):
Property to check if the model is loaded.
predict(self, max_partial_prediction_size: int, filter: Optional[str] = None):
Predicts sentence completions based on the given context.
learn(self, change_tokens):
Learns from the given change tokens by adding them to the database.
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._retrieveaac = None
self._model_loaded = False
"""

def __init__(self, *args, **kwargs):
import os
Expand All @@ -89,7 +49,6 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def configure(self):
self.nlp = NLP().get_nlp()

# check if saved torch model exists
self.load_model()
Expand Down Expand Up @@ -117,25 +76,7 @@ def configure(self):

self.personalized_allowed_toxicwords = self._read_personalized_toxic_words()

self.OBJECT_DEPS = {
"dobj",
"pobj",
"dative",
"attr",
"oprd",
"npadvmod",
"amod",
"acomp",
"advmod",
}
self.SUBJECT_DEPS = {"nsubj", "nsubjpass", "csubj", "agent", "expl"}

# tags that define wether the word is wh-
self.WH_WORDS = {"WP", "WP$", "WRB"}
self.stopwords = []
stoplist = open(self.stopwordsFile).readlines()
for s in stoplist:
self.stopwords.append(s.strip())
self.svo_util = SVOUtil(self.stopwordsFile)

if not Path.is_file(Path(self.embedding_cache_path)):
self.corpus_embeddings = self.embedder.encode(
Expand All @@ -162,10 +103,15 @@ def configure(self):

# Create the HNSWLIB index
self.logger.debug("Start creating HNSWLIB index")
self.index.init_index(max_elements=20000, ef_construction=400, M=64)
self.logger.debug(f"len(corpus_sentences) = {len(self.corpus_sentences)}")
max_elements = 20000 if len(self.corpus_sentences) > 20000 else len(self.corpus_sentences)
self.index.init_index(max_elements=max_elements, ef_construction=400, M=64)

# Then we train the index to find a suitable clustering
self.index.add_items(self.corpus_embeddings, list(range(len(self.corpus_embeddings))))
with tqdm.tqdm(total=max_elements) as pbar:
for idx, emb in enumerate(self.corpus_embeddings[:max_elements]):
self.index.add_items(emb, idx)
pbar.update(1)

self.logger.debug(f"Saving index to: {self.index_path}")
self.index.save_index(self.index_path)
Expand Down Expand Up @@ -237,41 +183,10 @@ def _read_personalized_toxic_words(self):
self.logger.debug(f"UPDATED TOXIC WORDS = {self.personalized_allowed_toxicwords}")
return self.personalized_allowed_toxicwords

def _extract_svo(self, sent):
doc = self.nlp(sent)
sub = []
at = []
ve = []
imp_tokens = []
for token in doc:
# is this a verb?
if token.pos_ == "VERB":
ve.append(token.text)
if (
token.text.lower() not in self.stopwords
and token.text.lower() not in imp_tokens
):
imp_tokens.append(token.text.lower())
# is this the object?
if token.dep_ in self.OBJECT_DEPS or token.head.dep_ in self.OBJECT_DEPS:
at.append(token.text)
if (
token.text.lower() not in self.stopwords
and token.text.lower() not in imp_tokens
):
imp_tokens.append(token.text.lower())
# is this the subject?
if token.dep_ in self.SUBJECT_DEPS or token.head.dep_ in self.SUBJECT_DEPS:
sub.append(token.text)
if (
token.text.lower() not in self.stopwords
and token.text.lower() not in imp_tokens
):
imp_tokens.append(token.text.lower())
return imp_tokens

def _ngram_to_string(self, ngram):
"|".join(ngram)


# def _ngram_to_string(self, ngram):
# "|".join(ngram)

def _filter_text(self, text):
res = False
Expand Down Expand Up @@ -448,7 +363,7 @@ def _generate(self, context: str, num_gen: int, predictions:Prediction) -> Predi

if remainderTextForFilter != "":
if clean_sentence not in allsent:
imp_tokens = self._extract_svo(clean_sentence)
imp_tokens = self.svo_util.extract_svo(clean_sentence)
imp_tokens_reminder = []
# get important tokens only of the generated completion
for imp in imp_tokens:
Expand Down Expand Up @@ -654,8 +569,6 @@ def learn(self, change_tokens):
{"sentences": self.corpus_sentences, "embeddings": self.corpus_embeddings},
self.embedding_cache_path,
)
# with open(self.embedding_cache_path, "wb") as fOut:
# pickle.dump({'sentences': self.corpus_sentences, 'embeddings': self.corpus_embeddings}, fOut)

# Then we train the index to find a suitable clustering
self.logger.debug(
Expand Down
105 changes: 4 additions & 101 deletions convassist/predictor/smoothed_ngram_predictor/canned_word_predictor.py
Original file line number Diff line number Diff line change
@@ -1,98 +1,28 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: GPL-3.0-or-later

import json
import os

from tqdm import tqdm

from convassist.predictor.utilities.nlp import NLP
from convassist.predictor.utilities.prediction import Prediction, Suggestion
from convassist.utilities.ngram.ngram_map import NgramMap
from convassist.utilities.ngram.ngramutil import NGramUtil

from convassist.predictor.smoothed_ngram_predictor.smoothed_ngram_predictor import SmoothedNgramPredictor
from convassist.predictor.utilities.svo_util import SVOUtil


class CannedWordPredictor(SmoothedNgramPredictor):
"""
CannedWordPredictor is a specialized predictor that extends the SmoothedNgramPredictor.
It is designed to handle canned responses using natural language processing (NLP) techniques.

Methods:
configure():
Configures the predictor by loading the NLP model and initializing constants and stopwords.

extract_svo(sent: str) -> str:
Extracts significant subject-verb-object (SVO) tokens from a given sentence.

recreate_database():
Recreates the sentence and n-gram databases by adding new phrases and removing outdated ones.

"""

def configure(self):
# load the natural language processing model
self.nlp = NLP().get_nlp()

# object and subject constants
self.OBJECT_DEPS = {
"dobj",
"pobj",
"dative",
"attr",
"oprd",
"npadvmod",
"amod",
"acomp",
"advmod",
}
self.SUBJECT_DEPS = {"nsubj", "nsubjpass", "csubj", "agent", "expl"}
# tags that define wether the word is wh-
self.WH_WORDS = {"WP", "WP$", "WRB"}
self.stopwordsList = []

with open(self.stopwordsFile) as f:
self.stopwordsList = f.read().splitlines()

# strip each word in stopwordsList
self.stopwordsList = [word.strip() for word in self.stopwordsList]
self.svo_utils = SVOUtil(self.stopwordsFile)

super().configure()

def extract_svo(self, sent) -> str:
return " ".join(self.svo_utils.extract_svo(sent))

def extract_svo(self, sent):
doc = self.nlp(sent)
sub = []
at = []
ve = []
imp_tokens = []
for token in doc:
# is this a verb?
if token.pos_ == "VERB":
ve.append(token.text)
if (
token.text.lower() not in self.stopwordsList
and token.text.lower() not in imp_tokens
):
imp_tokens.append(token.text.lower())
# is this the object?
if token.dep_ in self.OBJECT_DEPS or token.head.dep_ in self.OBJECT_DEPS:
at.append(token.text)
if (
token.text.lower() not in self.stopwordsList
and token.text.lower() not in imp_tokens
):
imp_tokens.append(token.text.lower())
# is this the subject?
if token.dep_ in self.SUBJECT_DEPS or token.head.dep_ in self.SUBJECT_DEPS:
sub.append(token.text)
if (
token.text.lower() not in self.stopwordsList
and token.text.lower() not in imp_tokens
):
imp_tokens.append(token.text.lower())
return " ".join(imp_tokens).strip().lower()

def recreate_database(self):
"""
Expand Down Expand Up @@ -126,31 +56,4 @@ def recreate_database(self):
def startwords(self):
return os.path.join(self._personalized_resources_path, self._startwords)

# TODO: Refactor this class and general_word since this is the same code.
def predict(self, max_partial_prediction_size: int, filter):
"""
Predicts the next word based on the context tracker and the n-gram model.
"""
sentence_predictions = Prediction() # Not used in this predictor
word_predictions = Prediction()

actual_tokens, _ = self.context_tracker.get_tokens(self.cardinality)

if actual_tokens == 0:
self.logger.warning(
f"No tokens in the context tracker. Getting {max_partial_prediction_size} most frequent start words..."
)

with open(self.startwords) as f:
self.precomputed_StartWords = json.load(f)

for w, prob in list(self.precomputed_StartWords.items())[:max_partial_prediction_size]:
word_predictions.add_suggestion(Suggestion(w, prob, self.predictor_name))

if len(word_predictions) == 0:
self.logger.error("Error getting most frequent start words.")

return sentence_predictions, word_predictions

else:
return super().predict(max_partial_prediction_size, filter)
Loading
Loading