Skip to content

Commit aa919a9

Browse files
shubham-s-agarwalmart-r
authored andcommitted
Pushing update for MetaCAT (#155)
* Pushing update for metacat Includes changes to data_utils * Update data_utils.py * Update data_utils.py * Update data_utils.py Creating helper functions for checking alternative class names and undersampling data * Update data_utils.py * Update data_utils.py Changes for flake8 * Update data_utils.py * Update data_utils.py
1 parent 99b6bee commit aa919a9

File tree

2 files changed

+128
-67
lines changed

2 files changed

+128
-67
lines changed

v1/medcat/medcat/meta_cat.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data
252252
"The category name does not exist in this json file. You've provided '{}', "
253253
"while the possible options are: {}. Additionally, ensure the populate the "
254254
"'alternative_category_names' attribute to accommodate for variations.".format(
255-
category_name, " | ".join(list(data.keys()))))
255+
g_config['category_name'], " | ".join(list(data.keys()))))
256256

257257
data = data[category_name]
258258
if data_oversampled:
@@ -263,12 +263,12 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data
263263
if not category_value2id:
264264
# Encode the category values
265265
full_data, data_undersampled, category_value2id = encode_category_values(data,
266-
category_undersample=self.config.model.category_undersample,alternative_class_names=g_config['alternative_class_names'])
266+
alternative_class_names=g_config['alternative_class_names'],config=self.config)
267267
else:
268268
# We already have everything, just get the data
269269
full_data, data_undersampled, category_value2id = encode_category_values(data,
270270
existing_category_value2id=category_value2id,
271-
category_undersample=self.config.model.category_undersample,alternative_class_names=g_config['alternative_class_names'])
271+
alternative_class_names=g_config['alternative_class_names'],config=self.config)
272272
g_config['category_value2id'] = category_value2id
273273
self.config.model['nclasses'] = len(category_value2id)
274274

v1/medcat/medcat/utils/meta_cat/data_utils.py

Lines changed: 125 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, Optional, Tuple, Iterable, List, Union
1+
from typing import Any, Dict, Optional, Tuple, Iterable, List, Union, Set
22
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBase
33
import copy
44
import logging
@@ -153,8 +153,100 @@ def prepare_for_oversampled_data(data: List,
153153
return data_sampled
154154

155155

156+
def find_alternate_classname(category_value2id: Dict, category_values: Set, alternative_class_names: List[List]) -> Dict:
157+
"""Helper function to find and map to alternative class names for the given category.
158+
Example: For Temporality category, 'Recent' is an alternative to 'Present'.
159+
160+
Args:
161+
category_value2id (Dict):
162+
The pre-defined category_value2id
163+
category_values (Set):
164+
Contains the classes (labels) found in the data
165+
alternative_class_names (List):
166+
Contains the mapping of alternative class names
167+
168+
Returns:
169+
category_value2id (Dict):
170+
Updated category_value2id with keys corresponding to alternative class names
171+
172+
Raises:
173+
Exception: If no alternatives are found for labels in category_value2id that don't match any of the labels in the data
174+
Exception: If the alternatives defined for labels in category_value2id that don't match any of the labels in the data
175+
"""
176+
177+
updated_category_value2id = {}
178+
for _class in category_value2id.keys():
179+
if _class in category_values:
180+
updated_category_value2id[_class] = category_value2id[_class]
181+
else:
182+
found_in = [sub_map for sub_map in alternative_class_names if _class in sub_map]
183+
failed_to_find = False
184+
if len(found_in) != 0:
185+
class_name_matched = [label for label in found_in[0] if label in category_values]
186+
if len(class_name_matched) != 0:
187+
updated_category_value2id[class_name_matched[0]] = category_value2id[_class]
188+
logger.info("Class name '%s' does not exist in the data; however a variation of it "
189+
"'%s' is present; updating it...", _class, class_name_matched[0])
190+
else:
191+
failed_to_find = True
192+
else:
193+
failed_to_find = True
194+
if failed_to_find:
195+
raise Exception("The classes set in the config are not the same as the one found in the data. "
196+
"The classes present in the config vs the ones found in the data - "
197+
f"{set(category_value2id.keys())}, {category_values}. Additionally, ensure the "
198+
"populate the 'alternative_class_names' attribute to accommodate for variations.")
199+
category_value2id = copy.deepcopy(updated_category_value2id)
200+
logger.info("Updated categoryvalue2id mapping - %s", category_value2id)
201+
return category_value2id
202+
203+
204+
def undersample_data(data: List, category_value2id: Dict, label_data_,config) -> List:
205+
"""Undersamples the data for 2 phase learning
206+
207+
Args:
208+
data (List):
209+
Output of `prepare_from_json`.
210+
category_value2id(Dict):
211+
Map from category_value to id.
212+
label_data_:
213+
Map that stores the number of samples for each label
214+
config:
215+
MetaCAT config
216+
217+
Returns:
218+
data_undersampled (list):
219+
Return the data created for 2 phase learning) with integers inplace of strings for category values
220+
"""
221+
222+
data_undersampled = []
223+
category_undersample = config.model.category_undersample
224+
if category_undersample is None or category_undersample == '':
225+
min_label = min(label_data_.values())
226+
227+
else:
228+
if category_undersample not in label_data_.keys() and category_undersample in category_value2id.keys():
229+
min_label = label_data_[category_value2id[category_undersample]]
230+
else:
231+
min_label = label_data_[category_undersample]
232+
233+
label_data_counter = {v: 0 for v in category_value2id.values()}
234+
235+
for sample in data:
236+
if label_data_counter[sample[-1]] < min_label:
237+
data_undersampled.append(sample)
238+
label_data_counter[sample[-1]] += 1
239+
240+
label_data = {v: 0 for v in category_value2id.values()}
241+
for i in range(len(data_undersampled)):
242+
if data_undersampled[i][2] in category_value2id.values():
243+
label_data[data_undersampled[i][2]] = label_data[data_undersampled[i][2]] + 1
244+
logger.info("Updated number of samples per label (for 2-phase learning): %s", label_data)
245+
return data_undersampled
246+
247+
156248
def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict] = None,
157-
category_undersample=None, alternative_class_names: List[List] = []) -> Tuple:
249+
alternative_class_names: List[List] = [], config=None) -> Tuple:
158250
"""Converts the category values in the data outputted by `prepare_from_json`
159251
into integer values.
160252
@@ -163,22 +255,24 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
163255
Output of `prepare_from_json`.
164256
existing_category_value2id(Optional[Dict]):
165257
Map from category_value to id (old/existing).
166-
category_undersample:
167-
Name of class that should be used to undersample the data (for 2 phase learning)
168258
alternative_class_names:
169259
Map that stores the variations of possible class names for the given category (task)
260+
config:
261+
MetaCAT config
170262
171263
Returns:
172-
dict:
264+
data (list):
173265
New data with integers inplace of strings for category values.
174-
dict:
266+
data_undersampled (list):
175267
New undersampled data (for 2 phase learning) with integers inplace of strings for category values
176-
dict:
268+
category_value2id (dict):
177269
Map from category value to ID for all categories in the data.
178270
179271
Raises:
180-
Exception: If categoryvalue2id is pre-defined and its labels do not match the labels found in the data
272+
Exception: If the number of classes in config do not match the number of classes found in the data
273+
Exception: If category_value2id is pre-defined, its labels do not match the labels found in the data and alternative_class_names is empty
181274
"""
275+
182276
data = list(data)
183277
if existing_category_value2id is not None:
184278
category_value2id = existing_category_value2id
@@ -187,43 +281,29 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
187281

188282
category_values = set([x[2] for x in data])
189283

190-
# If categoryvalue2id is pre-defined, then making sure it is same as the labels found in the data
191-
if len(category_value2id) != 0 and set(category_value2id.keys()) != category_values:
192-
# if categoryvalue2id doesn't match the labels in the data, then 'alternative_class_names' has to be defined to check for variations
193-
if len(alternative_class_names) == 0:
194-
# Raise an exception since the labels don't match
284+
if config:
285+
if len(category_values) != config.model.nclasses:
195286
raise Exception(
196-
"The classes set in the config are not the same as the one found in the data. "
197-
"The classes present in the config vs the ones found in the data - "
198-
f"{set(category_value2id.keys())}, {category_values}. Additionally, ensure the populate the "
199-
"'alternative_class_names' attribute to accommodate for variations.")
200-
updated_category_value2id = {}
201-
for _class in category_value2id.keys():
202-
if _class in category_values:
203-
updated_category_value2id[_class] = category_value2id[_class]
204-
else:
205-
found_in = [sub_map for sub_map in alternative_class_names if _class in sub_map]
206-
failed_to_find = False
207-
if len(found_in) != 0:
208-
class_name_matched = [label for label in found_in[0] if label in category_values]
209-
if len(class_name_matched) != 0:
210-
updated_category_value2id[class_name_matched[0]] = category_value2id[_class]
211-
logger.info("Class name '%s' does not exist in the data; however a variation of it "
212-
"'%s' is present; updating it...", _class, class_name_matched[0])
213-
else:
214-
failed_to_find = True
215-
else:
216-
failed_to_find = True
217-
if failed_to_find:
218-
raise Exception("The classes set in the config are not the same as the one found in the data. "
219-
"The classes present in the config vs the ones found in the data - "
220-
f"{set(category_value2id.keys())}, {category_values}. Additionally, ensure the "
221-
"populate the 'alternative_class_names' attribute to accommodate for variations.")
222-
category_value2id = copy.deepcopy(updated_category_value2id)
223-
logger.info("Updated categoryvalue2id mapping - %s", category_value2id)
287+
"The number of classes found in the data - %s does not match the number of classes defined in the config - %s (config.model.nclasses). Please update the number of classes and initialise the model again.",
288+
len(category_values), config.model.nclasses)
289+
290+
# If categoryvalue2id is pre-defined or if all the classes aren't mentioned
291+
if len(category_value2id) != 0:
292+
# making sure it is same as the labels found in the data
293+
if set(category_value2id.keys()) != category_values:
294+
# if categoryvalue2id doesn't match the labels in the data, then 'alternative_class_names' has to be defined to check for variations
295+
if len(alternative_class_names) == 0:
296+
# Raise an exception since the labels don't match
297+
raise Exception(
298+
"The classes set in the config are not the same as the one found in the data. "
299+
"The classes present in the config vs the ones found in the data - "
300+
f"{set(category_value2id.keys())}, {category_values}. Additionally, ensure the populate the "
301+
"'alternative_class_names' attribute to accommodate for variations.")
302+
303+
category_value2id = find_alternate_classname(category_value2id, category_values, alternative_class_names)
224304

225305
# Else create the mapping from the labels found in the data
226-
else:
306+
if len(category_value2id) != len(category_values):
227307
for c in category_values:
228308
if c not in category_value2id:
229309
category_value2id[c] = len(category_value2id)
@@ -239,30 +319,11 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
239319
if data[i][2] in category_value2id.values():
240320
label_data_[data[i][2]] = label_data_[data[i][2]] + 1
241321

242-
logger.info("Original number of samples per label: %s",label_data_)
243-
# Undersampling data
244-
if category_undersample is None or category_undersample == '':
245-
min_label = min(label_data_.values())
246-
247-
else:
248-
if category_undersample not in label_data_.keys() and category_undersample in category_value2id.keys():
249-
min_label = label_data_[category_value2id[category_undersample]]
250-
else:
251-
min_label = label_data_[category_undersample]
322+
logger.info("Original number of samples per label: %s", label_data_)
252323

253324
data_undersampled = []
254-
label_data_counter = {v: 0 for v in category_value2id.values()}
255-
256-
for sample in data:
257-
if label_data_counter[sample[-1]] < min_label:
258-
data_undersampled.append(sample)
259-
label_data_counter[sample[-1]] += 1
260-
261-
label_data = {v: 0 for v in category_value2id.values()}
262-
for i in range(len(data_undersampled)):
263-
if data_undersampled[i][2] in category_value2id.values():
264-
label_data[data_undersampled[i][2]] = label_data[data_undersampled[i][2]] + 1
265-
logger.info("Updated number of samples per label (for 2-phase learning): %s",label_data)
325+
if config and config.model.phase_number != 0:
326+
data_undersampled = undersample_data(data, category_value2id, label_data_, config)
266327

267328
return data, data_undersampled, category_value2id
268329

0 commit comments

Comments
 (0)