 
 
 class DIETClassifierWrapper:
+    """Wrapper for DIETClassifier."""
     def __init__(self, config: Union[Dict[str, Dict[str, Any]], str]):
+        """
+        Create wrapper with configuration.
+
+        :param config: config in dictionary format or path to a config file (.yml)
+        """
         if isinstance(config, str):
             try:
                 f = open(config, "r")
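The constructor accepts either a ready-made dict or a path to a .yml file; the file-parsing branch is elided from this hunk. A minimal sketch of the usual pattern, assuming the config is parsed with PyYAML (the yaml import and safe_load call are assumptions, not shown in this diff):

    import yaml

    def load_config(config):
        # If given a path, read and parse the YAML file; a dict passes through unchanged.
        if isinstance(config, str):
            with open(config, "r") as f:
                config = yaml.safe_load(f)
        return config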
@@ -58,6 +64,11 @@ def __init__(self, config: Union[Dict[str, Dict[str, Any]], str]):
         self.softmax = torch.nn.Softmax(dim=-1)
 
     def tokenize(self, sentences) -> Tuple[Dict[str, Any], List[List[Tuple[int, int]]]]:
+        """
+        Tokenize sentences using the tokenizer.
+
+        :param sentences: list of sentences
+        :return: tuple of (tokenized inputs, offset mapping for each sentence)
+        """
         inputs = self.tokenizer(sentences, return_tensors="pt", return_attention_mask=True, return_token_type_ids=True,
                                 return_offsets_mapping=True,
                                 padding=True, truncation=True)
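tokenize leans on the Hugging Face fast-tokenizer option return_offsets_mapping, which yields a (start, end) character span per token; convert_entities_logits later uses these spans to map token-level predictions back onto the raw text. A standalone illustration (the checkpoint name is only an example):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # example checkpoint
    inputs = tokenizer(["book a flight to Hanoi"], return_tensors="pt",
                       return_offsets_mapping=True, padding=True, truncation=True)
    # Each entry is the (start, end) character span of a token in the input string;
    # special tokens such as [CLS] and [SEP] get the empty span (0, 0).
    print(inputs["offset_mapping"][0])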
@@ -68,6 +79,12 @@ def tokenize(self, sentences) -> Tuple[Dict[str, Any], List[List[Tuple[int, int]
         return inputs, offset_mapping
 
     def convert_intent_logits(self, intent_logits: torch.Tensor) -> List[Dict[str, float]]:
+        """
+        Convert logits from the model to predicted intents.
+
+        :param intent_logits: intent logits output by the model
+        :return: list of dictionaries with the predicted intent and its intent ranking
+        """
         softmax_intents = self.softmax(intent_logits)
 
         predicted_intents = []
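Softmax(dim=-1), set up in __init__, normalizes over the last axis, so each row of logits becomes a probability distribution over the known intents. For instance:

    import torch

    logits = torch.tensor([[2.0, 1.0, 0.1]])   # one sentence, three intent logits
    probs = torch.nn.Softmax(dim=-1)(logits)   # each row now sums to 1
    print(probs.sum(dim=-1))                   # sums to 1, up to floating-point error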
@@ -87,14 +104,21 @@ def convert_intent_logits(self, intent_logits: torch.tensor) -> List[Dict[str, f
             predicted_intents.append({
                 "intent": None if max_probability == -1 else self.intents[max_probability],
                 "intent_ranking": {
-                    intent_name: probability for intent_name, probability in zip(self.intents, sentence)
+                    intent_name: probability.item() for intent_name, probability in zip(self.intents, sentence)
                 }
             })
 
         return predicted_intents
 
     def convert_entities_logits(self, entities_logits: torch.Tensor, offset_mapping: torch.Tensor) -> List[
             List[Dict[str, Any]]]:
+        """
+        Convert entity logits to predicted entities.
+
+        :param entities_logits: entity logits from the model
+        :param offset_mapping: offset mapping for the sentences
+        :return: list of predicted entities for each sentence
+        """
         softmax_entities = self.softmax(entities_logits)
 
         predicted_entities = []
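The -/+ pair above is the substantive change in this hunk: indexing a 1-D tensor returns a zero-dimensional torch.Tensor, and .item() unwraps it into a plain Python float, which keeps intent_ranking serializable. A quick illustration:

    import json
    import torch

    probs = torch.tensor([0.75, 0.25])
    ranking = {"greet": probs[0], "bye": probs[1]}   # values are 0-dim tensors
    # json.dumps(ranking) raises "TypeError: Object of type Tensor is not JSON serializable"
    ranking = {name: p.item() for name, p in ranking.items()}
    print(json.dumps(ranking))                       # {"greet": 0.75, "bye": 0.25}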
@@ -108,17 +132,23 @@ def convert_entities_logits(self, entities_logits: torch.tensor, offset_mapping:
                     if self.entities[max_probability] != latest_entity:
                         predicted_entities[-1].append({
                             "entity_name": self.entities[max_probability],
-                            "start": token_offset[0],
-                            "end": token_offset[1]
+                            "start": token_offset[0].item(),
+                            "end": token_offset[1].item()
                         })
                     else:
                         predicted_entities[-1][-1]["end"] = token_offset[1].item()
                 else:
                     latest_entity = None
 
         return predicted_entities
 
     def predict(self, sentences: List[str]) -> List[Dict[str, Any]]:
+        """
+        Predict intents and entities for a list of sentences.
+
+        :param sentences: list of sentences
+        :return: list of predictions
+        """
         inputs, offset_mapping = self.tokenize(sentences=sentences)
         outputs = self.model(**inputs)
         logits = outputs["logits"]
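The entity loop shown above implements a simple span-merging rule: open a new entity whenever the predicted label differs from the previous token's, and otherwise extend the current span's end offset. The same logic in isolation, with made-up labels and character offsets:

    # Hypothetical per-token labels and offsets for the text "fly to new york"
    labels = ["O", "O", "location", "location"]
    offsets = [(0, 3), (4, 6), (7, 10), (11, 15)]

    spans, latest = [], None
    for label, (start, end) in zip(labels, offsets):
        if label != "O":
            if label != latest:    # a new entity starts at this token
                spans.append({"entity_name": label, "start": start, "end": end})
            else:                  # same entity continues: stretch its end offset
                spans[-1]["end"] = end
            latest = label
        else:
            latest = None
    print(spans)  # [{'entity_name': 'location', 'start': 7, 'end': 15}]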
@@ -133,10 +163,21 @@ def predict(self, sentences: List[str]) -> List[Dict[str, Any]]:
         return predicted_outputs
 
     def save_pretrained(self, directory: str):
+        """
+        Save the model and tokenizer to a directory.
+
+        :param directory: path to the save folder
+        :return: None
+        """
         self.model.save_pretrained(directory)
         self.tokenizer.save_pretrained(directory)
 
     def train_model(self, save_folder: str = "latest_model"):
+        """
+        Create a trainer, train, and save the best model to save_folder.
+
+        :param save_folder: path to the save folder
+        :return: None
+        """
         dataset_folder = self.dataset_config["dataset_folder"]
         if not path.exists(dataset_folder):
             raise ValueError(f"Folder {dataset_folder} does not exist")
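Taken together, a hedged usage sketch of the wrapper (the config path is hypothetical, and the exact keys of each prediction dict come from the convert_* methods above plus code elided from this diff):

    # "diet_config.yml" is a hypothetical path; a dict with the same structure also works.
    wrapper = DIETClassifierWrapper("diet_config.yml")

    # predict() tokenizes, runs the model, and converts intent and entity logits.
    for prediction in wrapper.predict(["book a flight to Hanoi"]):
        print(prediction)

    wrapper.train_model(save_folder="latest_model")  # fine-tune on the configured dataset
    wrapper.save_pretrained("exported_model")        # write model and tokenizer to disk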