diff --git a/deeppavlov/core/common/requirements_registry.json b/deeppavlov/core/common/requirements_registry.json
index d65eba771e..530d0de9fc 100644
--- a/deeppavlov/core/common/requirements_registry.json
+++ b/deeppavlov/core/common/requirements_registry.json
@@ -5,7 +5,9 @@
     ],
     "entity_linker": [
         "{DEEPPAVLOV_PATH}/requirements/hdt.txt",
-        "{DEEPPAVLOV_PATH}/requirements/rapidfuzz.txt"
+        "{DEEPPAVLOV_PATH}/requirements/rapidfuzz.txt",
+        "{DEEPPAVLOV_PATH}/requirements/en_core_web_sm.txt",
+        "{DEEPPAVLOV_PATH}/requirements/ru_core_news_sm.txt"
     ],
     "fasttext": [
         "{DEEPPAVLOV_PATH}/requirements/fasttext.txt"
@@ -58,6 +60,7 @@
         "{DEEPPAVLOV_PATH}/requirements/transformers.txt"
     ],
     "ru_adj_to_noun": [
+        "{DEEPPAVLOV_PATH}/requirements/ru_core_news_sm.txt",
         "{DEEPPAVLOV_PATH}/requirements/udapi.txt"
     ],
     "russian_words_vocab": [
@@ -147,7 +150,9 @@
         "{DEEPPAVLOV_PATH}/requirements/transformers.txt"
     ],
     "tree_to_sparql": [
-        "{DEEPPAVLOV_PATH}/requirements/udapi.txt"
+        "{DEEPPAVLOV_PATH}/requirements/udapi.txt",
+        "{DEEPPAVLOV_PATH}/requirements/en_core_web_sm.txt",
+        "{DEEPPAVLOV_PATH}/requirements/ru_core_news_sm.txt"
     ],
     "typos_custom_reader": [
         "{DEEPPAVLOV_PATH}/requirements/lxml.txt"
diff --git a/deeppavlov/core/models/torch_model.py b/deeppavlov/core/models/torch_model.py
index af6edb409e..234abf1a17 100644
--- a/deeppavlov/core/models/torch_model.py
+++ b/deeppavlov/core/models/torch_model.py
@@ -101,6 +101,27 @@ def __init__(self, device: str = "gpu",
         self.model.eval()
         log.debug(f"Model was successfully initialized! Model summary:\n {self.model}")
 
+    def get_optimizer(self):
+        """
+        Initialize an optimizer from bitsandbytes, falling back to torch.optim
+        if bitsandbytes is unavailable.
+        """
+        try:
+            import bitsandbytes as bnb
+            if 'AdamW' in self.optimizer_name:
+                log.info('No weight decay supported in bitsandbytes yet, using Adam instead of AdamW')
+                opt_name = self.optimizer_name.replace('AdamW', 'Adam')
+            else:
+                opt_name = self.optimizer_name
+            log.info(f'Using bitsandbytes optimizer {opt_name}')
+            optimizer = getattr(bnb.optim, opt_name)(
+                self.model.parameters(), **self.optimizer_parameters)
+        except (ImportError, AttributeError) as e:
+            log.info(f'Could not initialize bitsandbytes optimizer ({e}) - resorting to torch optimizer')
+            optimizer = getattr(torch.optim, self.optimizer_name)(
+                self.model.parameters(), **self.optimizer_parameters)
+        return optimizer
+
     def init_from_opt(self, model_func: str) -> None:
         """Initialize from scratch `self.model` with the architecture built in `model_func` method of this class
         along with `self.optimizer` as `self.optimizer_name` from `torch.optim` and parameters
@@ -115,8 +136,7 @@ def init_from_opt(self, model_func: str) -> None:
         """
         if callable(model_func):
             self.model = model_func(**self.opt).to(self.device)
-            self.optimizer = getattr(torch.optim, self.optimizer_name)(
-                self.model.parameters(), **self.optimizer_parameters)
+            self.optimizer = self.get_optimizer()
             if self.lr_scheduler_name:
                 self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_name)(
                     self.optimizer, **self.lr_scheduler_parameters)
diff --git a/deeppavlov/models/classifiers/torch_nets.py b/deeppavlov/models/classifiers/torch_nets.py
index 1f912f26f3..9133c79b46 100644
--- a/deeppavlov/models/classifiers/torch_nets.py
+++ b/deeppavlov/models/classifiers/torch_nets.py
@@ -13,10 +13,13 @@
 # limitations under the License.
 
 from typing import List, Union, Optional
+from logging import getLogger
 
 import torch
 import torch.nn as nn
 
+log = getLogger(__name__)
+
 
 class ShallowAndWideCnn(nn.Module):
     def __init__(self, n_classes: int, embedding_size: int, kernel_sizes_cnn: List[int],
@@ -27,7 +30,12 @@ def __init__(self, n_classes: int, embedding_size: int, kernel_sizes_cnn: List[i
         self.kernel_sizes_cnn = kernel_sizes_cnn
 
         if not embedded_tokens and vocab_size:
-            self.embedding = nn.Embedding(vocab_size, embedding_size)
+            try:
+                import bitsandbytes as bnb
+                self.embedding = bnb.nn.StableEmbedding(vocab_size, embedding_size)
+            except ImportError:
+                log.info('bitsandbytes is not available - resorting to torch embedding')
+                self.embedding = nn.Embedding(vocab_size, embedding_size)
 
         if isinstance(filters_cnn, int):
             filters_cnn = len(kernel_sizes_cnn) * [filters_cnn]
diff --git a/deeppavlov/models/entity_extraction/entity_linking.py b/deeppavlov/models/entity_extraction/entity_linking.py
index 4d60fa1470..b91e1ea412 100644
--- a/deeppavlov/models/entity_extraction/entity_linking.py
+++ b/deeppavlov/models/entity_extraction/entity_linking.py
@@ -14,19 +14,19 @@
 
 import re
 import sqlite3
+from collections import defaultdict
 from logging import getLogger
 from typing import List, Dict, Tuple, Union, Any
-from collections import defaultdict
 
-import pymorphy2
+import spacy
 from hdt import HDTDocument
 from nltk.corpus import stopwords
 from rapidfuzz import fuzz
 
+from deeppavlov.core.commands.utils import expand_path
 from deeppavlov.core.common.registry import register
 from deeppavlov.core.models.component import Component
 from deeppavlov.core.models.serializable import Serializable
-from deeppavlov.core.commands.utils import expand_path
 
 log = getLogger(__name__)
 
@@ -75,7 +75,6 @@ def __init__(
             **kwargs:
         """
         super().__init__(save_path=None, load_path=load_path)
-        self.morph = pymorphy2.MorphAnalyzer()
         self.lemmatize = lemmatize
         self.entities_database_filename = entities_database_filename
         self.num_entities_for_bert_ranking = num_entities_for_bert_ranking
@@ -86,8 +85,10 @@ def __init__(
         self.lang = f"@{lang}"
         if self.lang == "@en":
             self.stopwords = set(stopwords.words("english"))
+            self.nlp = spacy.load("en_core_web_sm")
         elif self.lang == "@ru":
             self.stopwords = set(stopwords.words("russian"))
+            self.nlp = spacy.load("ru_core_news_sm")
         self.use_descriptions = use_descriptions
         self.use_connections = use_connections
         self.max_paragraph_len = max_paragraph_len
@@ -198,7 +199,7 @@ def link_entities(
         ):
             cand_ent_scores = []
             if len(entity_substr) > 1:
-                entity_substr_split_lemm = [self.morph.parse(tok)[0].normal_form for tok in entity_substr_split]
+                entity_substr_split_lemm = [self.nlp(tok)[0].lemma_ for tok in entity_substr_split]
                 cand_ent_init = self.find_exact_match(entity_substr, tag)
                 if not cand_ent_init or entity_substr_split != entity_substr_split_lemm:
                     cand_ent_init = self.find_fuzzy_match(entity_substr_split, tag)
@@ -297,28 +298,23 @@ def find_exact_match(self, entity_substr, tag):
                 entity_substr_split = entity_substr_split[1:]
                 entities_and_ids = self.find_title(entity_substr)
                 cand_ent_init = self.process_cand_ent(cand_ent_init, entities_and_ids, entity_substr_split, tag)
-        if self.lang == "@ru":
-            entity_substr_split_lemm = [self.morph.parse(tok)[0].normal_form for tok in entity_substr_split]
-            entity_substr_lemm = " ".join(entity_substr_split_lemm)
-            if entity_substr_lemm != entity_substr:
-                entities_and_ids = self.find_title(entity_substr_lemm)
-                if entities_and_ids:
-                    cand_ent_init = self.process_cand_ent(
-                        cand_ent_init, entities_and_ids, entity_substr_split_lemm, tag
-                    )
+
+        entity_substr_split_lemm = [self.nlp(tok)[0].lemma_ for tok in entity_substr_split]
+        entity_substr_lemm = " ".join(entity_substr_split_lemm)
+        if entity_substr_lemm != entity_substr:
+            entities_and_ids = self.find_title(entity_substr_lemm)
+            if entities_and_ids:
+                cand_ent_init = self.process_cand_ent(cand_ent_init, entities_and_ids, entity_substr_split_lemm, tag)
         return cand_ent_init
 
     def find_fuzzy_match(self, entity_substr_split, tag):
-        if self.lang == "@ru":
-            entity_substr_split_lemm = [self.morph.parse(tok)[0].normal_form for tok in entity_substr_split]
-        else:
-            entity_substr_split_lemm = entity_substr_split
+        entity_substr_split_lemm = [self.nlp(tok)[0].lemma_ for tok in entity_substr_split]
         cand_ent_init = defaultdict(set)
         for word in entity_substr_split:
             part_entities_and_ids = self.find_title(word)
             cand_ent_init = self.process_cand_ent(cand_ent_init, part_entities_and_ids, entity_substr_split, tag)
             if self.lang == "@ru":
-                word_lemm = self.morph.parse(word)[0].normal_form
+                word_lemm = self.nlp(word)[0].lemma_
                 if word != word_lemm:
                     part_entities_and_ids = self.find_title(word_lemm)
                     cand_ent_init = self.process_cand_ent(
@@ -329,11 +325,6 @@ def find_fuzzy_match(self, entity_substr_split, tag):
                     )
         return cand_ent_init
 
-    def morph_parse(self, word):
-        morph_parse_tok = self.morph.parse(word)[0]
-        normal_form = morph_parse_tok.normal_form
-        return normal_form
-
     def calc_substr_score(self, cand_entity_title, entity_substr_split):
         label_tokens = cand_entity_title.split()
         cnt = 0.0
diff --git a/deeppavlov/models/kbqa/tree_to_sparql.py b/deeppavlov/models/kbqa/tree_to_sparql.py
index b5ff26c44b..d406ce7368 100644
--- a/deeppavlov/models/kbqa/tree_to_sparql.py
+++ b/deeppavlov/models/kbqa/tree_to_sparql.py
@@ -19,7 +19,7 @@
 from typing import Any, List, Tuple, Dict, Union
 
 import numpy as np
-import pymorphy2
+import spacy
 from navec import Navec
 from scipy.sparse import csr_matrix
 from slovnet import Syntax
@@ -66,11 +66,10 @@ def __init__(self, freq_dict_filename: str, candidate_nouns: int = 10, **kwargs)
         self.adj_set = set([word for word, freq in pos_freq_dict["a"]])
         self.nouns = [noun[0] for noun in self.nouns_with_freq]
         self.matrix = self.make_sparse_matrix(self.nouns).transpose()
-        self.morph = pymorphy2.MorphAnalyzer()
+        self.nlp = spacy.load("ru_core_news_sm")
 
     def search(self, word: str):
-        word = self.morph.parse(word)[0]
-        word = word.normal_form
+        word = self.nlp(word)[0].lemma_
         if word in self.adj_set:
             q_matrix = self.make_sparse_matrix([word])
             scores = q_matrix * self.matrix
@@ -190,6 +189,7 @@ def __init__(self, sparql_queries_filename: str, lang: str = "rus", adj_to_noun:
             self.begin_tokens = {"начинать", "начать"}
             self.end_tokens = {"завершить", "завершать", "закончить"}
             self.ranking_tokens = {"самый"}
+            self.nlp = spacy.load("ru_core_news_sm")
         elif self.lang == "eng":
             self.q_pronouns = {"what", "who", "how", "when", "where", "which"}
             self.how_many = "how many"
@@ -199,12 +199,12 @@ def __init__(self, sparql_queries_filename: str, lang: str = "rus", adj_to_noun:
             self.begin_tokens = set()
             self.end_tokens = set()
             self.ranking_tokens = set()
+            self.nlp = spacy.load("en_core_web_sm")
         else:
             raise ValueError(f"unsupported language {lang}")
         self.sparql_queries_filename = expand_path(sparql_queries_filename)
         self.template_queries = read_json(self.sparql_queries_filename)
         self.adj_to_noun = adj_to_noun
-        self.morph = pymorphy2.MorphAnalyzer()
 
     def __call__(self, syntax_tree_batch: List[str],
                  positions_batch: List[List[List[int]]]) -> Tuple[
@@ -274,7 +274,7 @@ def __call__(self, syntax_tree_batch: List[str],
                     self.root_entity = True
 
         temporal_order = self.find_first_last(new_root)
-        new_root_nf = self.morph.parse(new_root.form)[0].normal_form
+        new_root_nf = self.nlp(new_root.form)[0].lemma_
         if new_root_nf in self.begin_tokens or new_root_nf in self.end_tokens:
             temporal_order = new_root_nf
         ranking_tokens = self.find_ranking_tokens(new_root)
@@ -288,7 +288,7 @@ def __call__(self, syntax_tree_batch: List[str],
             question = []
             for node in tree.descendants:
                 if node.ord in ranking_tokens or node.form.lower() in self.q_pronouns:
-                    question.append(self.morph.parse(node.form)[0].normal_form)
+                    question.append(self.nlp(node.form)[0].lemma_)
                 else:
                     question.append(node.form)
             question = ' '.join(question)
@@ -496,9 +496,9 @@ def find_first_last(self, node: Node) -> str:
         for node in nodes:
             node_desc = defaultdict(set)
             for elem in node.children:
-                parsed_elem = self.morph.parse(elem.form.lower())[0].inflect({"masc", "sing", "nomn"})
-                if parsed_elem is not None:
-                    node_desc[elem.deprel].add(parsed_elem.word)
+                parsed_elem = self.nlp(elem.form.lower())[0].lemma_
+                if parsed_elem:
+                    node_desc[elem.deprel].add(parsed_elem)
                 else:
                     node_desc[elem.deprel].add(elem.form)
             if "amod" in node_desc.keys() and "nmod" in node_desc.keys() and \
@@ -511,7 +511,7 @@ def find_first_last(self, node: Node) -> str:
     def find_ranking_tokens(self, node: Node) -> list:
         ranking_tokens = []
         for elem in node.descendants:
-            if self.morph.parse(elem.form)[0].normal_form in self.ranking_tokens:
+            if self.nlp(elem.form)[0].lemma_ in self.ranking_tokens:
                 ranking_tokens.append(elem.ord)
                 ranking_tokens.append(elem.parent.ord)
         return ranking_tokens
diff --git a/deeppavlov/models/kbqa/type_define.py b/deeppavlov/models/kbqa/type_define.py
index 7e9ab41be5..1ccdd9b388 100644
--- a/deeppavlov/models/kbqa/type_define.py
+++ b/deeppavlov/models/kbqa/type_define.py
@@ -15,7 +15,6 @@
 import pickle
 from typing import List
 
-import pymorphy2
 import spacy
 from nltk.corpus import stopwords
 
@@ -43,7 +42,6 @@ def __init__(self, lang: str, types_filename: str, types_sets_filename: str,
         self.types_filename = str(expand_path(types_filename))
         self.types_sets_filename = str(expand_path(types_sets_filename))
         self.num_types_to_return = num_types_to_return
-        self.morph = pymorphy2.MorphAnalyzer()
         if self.lang == "@en":
             self.stopwords = set(stopwords.words("english"))
             self.nlp = spacy.load("en_core_web_sm")
@@ -102,7 +100,7 @@ def __call__(self, questions_batch: List[str], entity_substr_batch: List[List[st
                 types_substr_tokens = types_substr.split()
                 types_substr_tokens = [tok for tok in types_substr_tokens if tok not in self.stopwords]
                 if self.lang == "@ru":
-                    types_substr_tokens = [self.morph.parse(tok)[0].normal_form for tok in types_substr_tokens]
+                    types_substr_tokens = [self.nlp(tok)[0].lemma_ for tok in types_substr_tokens]
                 types_substr_tokens = set(types_substr_tokens)
                 types_scores = []
                 for entity in self.types_dict:
diff --git a/deeppavlov/models/torch_bert/torch_bert_ranker.py b/deeppavlov/models/torch_bert/torch_bert_ranker.py
index 261e4bd03e..a72215c608 100644
--- a/deeppavlov/models/torch_bert/torch_bert_ranker.py
+++ b/deeppavlov/models/torch_bert/torch_bert_ranker.py
@@ -202,8 +202,7 @@ def load(self, fname=None):
 
         self.model.to(self.device)
 
-        self.optimizer = getattr(torch.optim, self.optimizer_name)(
-            self.model.parameters(), **self.optimizer_parameters)
+        self.optimizer = self.get_optimizer()
         if self.lr_scheduler_name is not None:
             self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_name)(
                 self.optimizer, **self.lr_scheduler_parameters)
diff --git a/deeppavlov/models/torch_bert/torch_transformers_classifier.py b/deeppavlov/models/torch_bert/torch_transformers_classifier.py
index d2449dafc2..8218ac168c 100644
--- a/deeppavlov/models/torch_bert/torch_transformers_classifier.py
+++ b/deeppavlov/models/torch_bert/torch_transformers_classifier.py
@@ -255,8 +255,7 @@ def load(self, fname=None):
 
         self.model.to(self.device)
 
-        self.optimizer = getattr(torch.optim, self.optimizer_name)(
-            self.model.parameters(), **self.optimizer_parameters)
+        self.optimizer = self.get_optimizer()
         if self.lr_scheduler_name is not None:
             self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_name)(
                 self.optimizer, **self.lr_scheduler_parameters)
diff --git a/deeppavlov/models/torch_bert/torch_transformers_multiplechoice.py b/deeppavlov/models/torch_bert/torch_transformers_multiplechoice.py
index c989715d10..45132429e5 100644
--- a/deeppavlov/models/torch_bert/torch_transformers_multiplechoice.py
+++ b/deeppavlov/models/torch_bert/torch_transformers_multiplechoice.py
@@ -182,8 +182,7 @@ def load(self, fname = None):
 
         self.model.to(self.device)
 
-        self.optimizer = getattr(torch.optim, self.optimizer_name)(
-            self.model.parameters(), **self.optimizer_parameters)
+        self.optimizer = self.get_optimizer()
         if self.lr_scheduler_name is not None:
             self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_name)(
                 self.optimizer, **self.lr_scheduler_parameters)
diff --git a/deeppavlov/models/torch_bert/torch_transformers_sequence_tagger.py b/deeppavlov/models/torch_bert/torch_transformers_sequence_tagger.py
index fc78e4f32e..f30e968f95 100644
--- a/deeppavlov/models/torch_bert/torch_transformers_sequence_tagger.py
+++ b/deeppavlov/models/torch_bert/torch_transformers_sequence_tagger.py
@@ -294,8 +294,7 @@ def load(self, fname=None):
         if self.use_crf:
             self.crf = CRF(self.n_classes).to(self.device)
 
-        self.optimizer = getattr(torch.optim, self.optimizer_name)(
-            self.model.parameters(), **self.optimizer_parameters)
+        self.optimizer = self.get_optimizer()
         if self.lr_scheduler_name is not None:
             self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_name)(
                 self.optimizer, **self.lr_scheduler_parameters)
diff --git a/deeppavlov/models/torch_bert/torch_transformers_squad.py b/deeppavlov/models/torch_bert/torch_transformers_squad.py
index 83aee7bc7e..31b20843e9 100644
--- a/deeppavlov/models/torch_bert/torch_transformers_squad.py
+++ b/deeppavlov/models/torch_bert/torch_transformers_squad.py
@@ -292,8 +292,7 @@ def load(self, fname=None):
             self.model = torch.nn.DataParallel(self.model)
 
         self.model.to(self.device)
 
-        self.optimizer = getattr(torch.optim, self.optimizer_name)(
-            self.model.parameters(), **self.optimizer_parameters)
+        self.optimizer = self.get_optimizer()
         if self.lr_scheduler_name is not None:
             self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_name)(
                 self.optimizer, **self.lr_scheduler_parameters)
diff --git a/requirements.txt b/requirements.txt
index 92488368eb..c8721026bf 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,3 +14,4 @@ scikit-learn>=0.24,<1.1.0
 scipy<1.9.0
 tqdm>=4.42.0,<4.65.0
 uvicorn>=0.13.0,<0.19.0
+bitsandbytes-cuda113==0.25.0
diff --git a/tests/test_quick_start.py b/tests/test_quick_start.py
index bd4f73ef41..a78682f2bb 100644
--- a/tests/test_quick_start.py
+++ b/tests/test_quick_start.py
@@ -256,7 +256,7 @@
         ("kbqa/kbqa_cq_ru.json", "kbqa", ('IP',)):
             [
                 ("Кто такой Оксимирон?", ("российский рэп-исполнитель",)),
-                ("Чем питаются коалы?", ("Лист",)),
+                ("Кто написал «Евгений Онегин»?", ("Александр Сергеевич Пушкин",)),
                 ("абв", ("Not Found",))
             ]
     },
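
Reviewer note: below is a minimal standalone sketch, not part of the patch, of the optimizer selection that the new TorchModel.get_optimizer performs; the model and hyperparameter values are hypothetical placeholders.

    # Sketch of the bitsandbytes-with-torch-fallback selection added in torch_model.py.
    import torch
    import torch.nn as nn

    model = nn.Linear(10, 2)             # stand-in for self.model
    optimizer_name = 'AdamW'             # as read from a DeepPavlov config
    optimizer_parameters = {'lr': 1e-3}  # stand-in for self.optimizer_parameters

    try:
        import bitsandbytes as bnb
        # bitsandbytes has no AdamW yet, so the patch downgrades it to Adam
        opt_name = optimizer_name.replace('AdamW', 'Adam')
        optimizer = getattr(bnb.optim, opt_name)(model.parameters(), **optimizer_parameters)
    except (ImportError, AttributeError):
        # bitsandbytes missing or lacking the requested optimizer: use plain torch
        optimizer = getattr(torch.optim, optimizer_name)(model.parameters(), **optimizer_parameters)

    print(type(optimizer).__name__)  # 'Adam' from bnb.optim if installed, else torch 'AdamW'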
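The pymorphy2-to-spaCy switch may change lemmas in edge cases, and the patched code runs the full pipeline on each token separately. A quick parity check, assuming ru_core_news_sm is installed (python -m spacy download ru_core_news_sm):

    # Per-token lemmatization as the patched code does it. Disabling unused
    # components (as below), or lemmatizing the joined string in one nlp() call,
    # would be faster than running the full pipeline on every token.
    import spacy

    nlp = spacy.load("ru_core_news_sm", disable=["parser", "ner"])
    tokens = ["коалы", "написал", "самый"]
    print([nlp(tok)[0].lemma_ for tok in tokens])
    # expected roughly: ['коала', 'написать', 'самый']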