1- from typing import Any , Dict , Optional , Tuple , Iterable , List , Union
1+ from typing import Any , Dict , Optional , Tuple , Iterable , List , Union , Set
22from medcat .tokenizers .meta_cat_tokenizers import TokenizerWrapperBase
33import copy
44import logging
@@ -153,8 +153,100 @@ def prepare_for_oversampled_data(data: List,
153153 return data_sampled
154154
155155
def find_alternate_classname(category_value2id: Dict, category_values: Set, alternative_class_names: List[List]) -> Dict:
    """Map the pre-defined class names onto the label names present in the data.

    A class from `category_value2id` that is present in `category_values` is
    kept as is. Otherwise every group in `alternative_class_names` that lists
    the class is searched (in order) for a name that is present in the data,
    and the first such name takes over the id of the original class.
    Example: For Temporality category, 'Recent' is an alternative to 'Present'.

    Args:
        category_value2id (Dict):
            The pre-defined category_value2id
        category_values (Set):
            Contains the classes (labels) found in the data
        alternative_class_names (List):
            Groups of interchangeable class names; each inner list holds the
            variations of one class

    Returns:
        category_value2id (Dict):
            A new mapping whose keys are the class names as they appear in the
            data. The mapping passed in is not modified.

    Raises:
        Exception: If a class in category_value2id is absent from the data and
            either no group in alternative_class_names lists it, or none of
            the names in the groups listing it is present in the data
    """
    updated_category_value2id = {}
    for _class, class_id in category_value2id.items():
        if _class in category_values:
            updated_category_value2id[_class] = class_id
            continue

        # Consult every group that mentions the class, not only the first one,
        # so a usable variation listed in a later group is still found.
        candidates = [
            label
            for sub_map in alternative_class_names if _class in sub_map
            for label in sub_map if label in category_values
        ]
        if not candidates:
            raise Exception("The classes set in the config are not the same as the one found in the data. "
                            "The classes present in the config vs the ones found in the data - "
                            f"{set(category_value2id.keys())}, {category_values}. Additionally, ensure the "
                            "populate the 'alternative_class_names' attribute to accommodate for variations.")

        matched_name = candidates[0]
        updated_category_value2id[matched_name] = class_id
        logger.info("Class name '%s' does not exist in the data; however a variation of it "
                    "'%s' is present; updating it...", _class, matched_name)

    logger.info("Updated categoryvalue2id mapping - %s", updated_category_value2id)
    return updated_category_value2id
202+
203+
def undersample_data(data: List, category_value2id: Dict, label_data_: Dict, config: Any) -> List:
    """Undersamples the data for 2 phase learning

    For every label at most `min_label` samples are kept, in their original
    order. `min_label` is the smallest count in `label_data_` when
    `config.model.category_undersample` is None or '', otherwise it is the
    count stored in `label_data_` for that class (looked up through
    `category_value2id` when the class name itself is not a key of
    `label_data_`).

    Args:
        data (List):
            Output of `prepare_from_json`; the last element of each sample is
            used as its (integer) label.
        category_value2id(Dict):
            Map from category_value to id.
        label_data_ (Dict):
            Map that stores the number of samples for each label
        config:
            MetaCAT config; only `config.model.category_undersample` is read

    Returns:
        data_undersampled (list):
            The data created for 2 phase learning, with integers inplace of strings for category values

    Raises:
        KeyError: If `category_undersample` is set but cannot be resolved to a
            key of `label_data_`, or if a sample carries a label that is not
            a value of `category_value2id`
    """
    category_undersample = config.model.category_undersample
    if category_undersample is None or category_undersample == '':
        # default=0 keeps an empty label map from raising; nothing is sampled then
        min_label = min(label_data_.values(), default=0)
    elif category_undersample not in label_data_ and category_undersample in category_value2id:
        # label_data_ is keyed by id here, so translate the class name first
        min_label = label_data_[category_value2id[category_undersample]]
    else:
        min_label = label_data_[category_undersample]

    data_undersampled = []
    # Number of samples kept so far for each label id
    kept_per_label = {v: 0 for v in category_value2id.values()}
    for sample in data:
        label = sample[-1]
        if kept_per_label[label] < min_label:
            data_undersampled.append(sample)
            kept_per_label[label] += 1

    # The running counter already holds the per-label totals of the kept data
    logger.info("Updated number of samples per label (for 2-phase learning): %s", kept_per_label)
    return data_undersampled
246+
247+
156248def encode_category_values (data : Dict , existing_category_value2id : Optional [Dict ] = None ,
157- category_undersample = None , alternative_class_names : List [List ] = []) -> Tuple :
249+ alternative_class_names : List [List ] = [], config = None ) -> Tuple :
158250 """Converts the category values in the data outputted by `prepare_from_json`
159251 into integer values.
160252
@@ -163,22 +255,24 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
163255 Output of `prepare_from_json`.
164256 existing_category_value2id(Optional[Dict]):
165257 Map from category_value to id (old/existing).
166- category_undersample:
167- Name of class that should be used to undersample the data (for 2 phase learning)
168258 alternative_class_names:
169259 Map that stores the variations of possible class names for the given category (task)
260+ config:
261+ MetaCAT config
170262
171263 Returns:
172- dict :
264+ data (list) :
173265 New data with integers inplace of strings for category values.
174- dict :
266+ data_undersampled (list) :
175267 New undersampled data (for 2 phase learning) with integers inplace of strings for category values
176- dict:
268+ category_value2id ( dict) :
177269 Map from category value to ID for all categories in the data.
178270
179271 Raises:
180- Exception: If categoryvalue2id is pre-defined and its labels do not match the labels found in the data
272+ Exception: If the number of classes in config do not match the number of classes found in the data
273+ Exception: If category_value2id is pre-defined, its labels do not match the labels found in the data and alternative_class_names is empty
181274 """
275+
182276 data = list (data )
183277 if existing_category_value2id is not None :
184278 category_value2id = existing_category_value2id
@@ -187,43 +281,29 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
187281
188282 category_values = set ([x [2 ] for x in data ])
189283
190- # If categoryvalue2id is pre-defined, then making sure it is same as the labels found in the data
191- if len (category_value2id ) != 0 and set (category_value2id .keys ()) != category_values :
192- # if categoryvalue2id doesn't match the labels in the data, then 'alternative_class_names' has to be defined to check for variations
193- if len (alternative_class_names ) == 0 :
194- # Raise an exception since the labels don't match
284+ if config :
285+ if len (category_values ) != config .model .nclasses :
195286 raise Exception (
196- "The classes set in the config are not the same as the one found in the data. "
197- "The classes present in the config vs the ones found in the data - "
198- f"{ set (category_value2id .keys ())} , { category_values } . Additionally, ensure the populate the "
199- "'alternative_class_names' attribute to accommodate for variations." )
200- updated_category_value2id = {}
201- for _class in category_value2id .keys ():
202- if _class in category_values :
203- updated_category_value2id [_class ] = category_value2id [_class ]
204- else :
205- found_in = [sub_map for sub_map in alternative_class_names if _class in sub_map ]
206- failed_to_find = False
207- if len (found_in ) != 0 :
208- class_name_matched = [label for label in found_in [0 ] if label in category_values ]
209- if len (class_name_matched ) != 0 :
210- updated_category_value2id [class_name_matched [0 ]] = category_value2id [_class ]
211- logger .info ("Class name '%s' does not exist in the data; however a variation of it "
212- "'%s' is present; updating it..." , _class , class_name_matched [0 ])
213- else :
214- failed_to_find = True
215- else :
216- failed_to_find = True
217- if failed_to_find :
218- raise Exception ("The classes set in the config are not the same as the one found in the data. "
219- "The classes present in the config vs the ones found in the data - "
220- f"{ set (category_value2id .keys ())} , { category_values } . Additionally, ensure the "
221- "populate the 'alternative_class_names' attribute to accommodate for variations." )
222- category_value2id = copy .deepcopy (updated_category_value2id )
223- logger .info ("Updated categoryvalue2id mapping - %s" , category_value2id )
287+ "The number of classes found in the data - %s does not match the number of classes defined in the config - %s (config.model.nclasses). Please update the number of classes and initialise the model again." ,
288+ len (category_values ), config .model .nclasses )
289+
290+ # If categoryvalue2id is pre-defined or if all the classes aren't mentioned
291+ if len (category_value2id ) != 0 :
292+ # making sure it is same as the labels found in the data
293+ if set (category_value2id .keys ()) != category_values :
294+ # if categoryvalue2id doesn't match the labels in the data, then 'alternative_class_names' has to be defined to check for variations
295+ if len (alternative_class_names ) == 0 :
296+ # Raise an exception since the labels don't match
297+ raise Exception (
298+ "The classes set in the config are not the same as the one found in the data. "
299+ "The classes present in the config vs the ones found in the data - "
300+ f"{ set (category_value2id .keys ())} , { category_values } . Additionally, ensure the populate the "
301+ "'alternative_class_names' attribute to accommodate for variations." )
302+
303+ category_value2id = find_alternate_classname (category_value2id , category_values , alternative_class_names )
224304
225305 # Else create the mapping from the labels found in the data
226- else :
306+ if len ( category_value2id ) != len ( category_values ) :
227307 for c in category_values :
228308 if c not in category_value2id :
229309 category_value2id [c ] = len (category_value2id )
@@ -239,30 +319,11 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
239319 if data [i ][2 ] in category_value2id .values ():
240320 label_data_ [data [i ][2 ]] = label_data_ [data [i ][2 ]] + 1
241321
242- logger .info ("Original number of samples per label: %s" ,label_data_ )
243- # Undersampling data
244- if category_undersample is None or category_undersample == '' :
245- min_label = min (label_data_ .values ())
246-
247- else :
248- if category_undersample not in label_data_ .keys () and category_undersample in category_value2id .keys ():
249- min_label = label_data_ [category_value2id [category_undersample ]]
250- else :
251- min_label = label_data_ [category_undersample ]
322+ logger .info ("Original number of samples per label: %s" , label_data_ )
252323
253324 data_undersampled = []
254- label_data_counter = {v : 0 for v in category_value2id .values ()}
255-
256- for sample in data :
257- if label_data_counter [sample [- 1 ]] < min_label :
258- data_undersampled .append (sample )
259- label_data_counter [sample [- 1 ]] += 1
260-
261- label_data = {v : 0 for v in category_value2id .values ()}
262- for i in range (len (data_undersampled )):
263- if data_undersampled [i ][2 ] in category_value2id .values ():
264- label_data [data_undersampled [i ][2 ]] = label_data [data_undersampled [i ][2 ]] + 1
265- logger .info ("Updated number of samples per label (for 2-phase learning): %s" ,label_data )
325+ if config and config .model .phase_number != 0 :
326+ data_undersampled = undersample_data (data , category_value2id , label_data_ , config )
266327
267328 return data , data_undersampled , category_value2id
268329
0 commit comments