|
45 | 45 | import json |
46 | 46 | import csv |
47 | 47 | import random |
| 48 | +import numpy as np |
| 49 | + |
| 50 | +try: |
| 51 | + from pandas import DataFrame |
| 52 | + PANDAS_READY = True |
| 53 | +except ImportError: |
| 54 | + PANDAS_READY = False |
48 | 55 |
|
49 | 56 |
|
50 | 57 | from bigml.util import invert_dictionary, python_map_type, find_locale |
51 | 58 | from bigml.util import DEFAULT_LOCALE |
52 | 59 | from bigml.api_handlers.resourcehandler import get_resource_type, get_fields |
53 | 60 | from bigml.constants import ( |
54 | 61 | SOURCE_PATH, DATASET_PATH, SUPERVISED_PATHS, FUSION_PATH, |
55 | | - RESOURCES_WITH_FIELDS, DEFAULT_MISSING_TOKENS, REGIONS) |
| 62 | + RESOURCES_WITH_FIELDS, DEFAULT_MISSING_TOKENS, REGIONS, CATEGORICAL) |
56 | 63 | from bigml.io import UnicodeReader, UnicodeWriter |
57 | 64 |
|
58 | 65 | LIST_LIMIT = 10 |
@@ -193,6 +200,32 @@ def get_new_fields(output_fields): |
193 | 200 | return new_fields |
194 | 201 |
|
195 | 202 |
|
| 203 | +def one_hot_code(value, field, decode=False): |
| 204 | + """Translating into codes categorical values. The codes are the index |
| 205 | + of the value in the list of categories read from the fields summary. |
| 206 | + Decode set to True will cause the code to be translated to the value""" |
| 207 | + |
| 208 | + try: |
| 209 | + categories = [cat[0] for cat in field["summary"]["categories"]] |
| 210 | + except KeyError: |
| 211 | + raise KeyError("Failed to find the categories list. Check the field" |
| 212 | + " information.") |
| 213 | + |
| 214 | + if decode: |
| 215 | + try: |
| 216 | + result = categories[int(value)] |
| 217 | + except KeyError: |
| 218 | + raise KeyError("Code not found in the categories list. %s" % |
| 219 | + categories) |
| 220 | + else: |
| 221 | + try: |
| 222 | + result = categories.index(value) |
| 223 | + except ValueError: |
| 224 | + raise ValueError("The '%s' value is not found in the categories " |
| 225 | + "list: %s" % (value, categories)) |
| 226 | + return result |
| 227 | + |
| 228 | + |
196 | 229 | class Fields(): |
197 | 230 | """A class to deal with BigML auto-generated ids. |
198 | 231 |
|
@@ -483,6 +516,77 @@ def stats(self, field_name): |
483 | 516 | summary = self.fields[field_id].get('summary', {}) |
484 | 517 | return summary |
485 | 518 |
|
| 519 | + def objective_field_info(self): |
| 520 | + """Returns the fields structure for the objective field""" |
| 521 | + if self.objective_field is None: |
| 522 | + return None |
| 523 | + objective_id = self.field_id(self.objective_field) |
| 524 | + return {objective_id: self.fields[objective_id]} |
| 525 | + |
| 526 | + def sorted_field_ids(self, objective=False): |
| 527 | + """List of field IDs ordered by column number. If objective is |
| 528 | + set to False, the objective field will be excluded. |
| 529 | + """ |
| 530 | + fields = {} |
| 531 | + fields.update(self.fields_by_column_number) |
| 532 | + if not objective and self.objective_field is not None: |
| 533 | + del(fields[self.objective_field]) |
| 534 | + field_ids = fields.values() |
| 535 | + return field_ids |
| 536 | + |
| 537 | + def to_numpy(self, input_data_list, objective=False): |
| 538 | + """Transforming input data to numpy syntax. Fields are sorted |
| 539 | + in the dataset order and categorical fields are one-hot encoded. |
| 540 | + If objective set to False, the objective field will not be included""" |
| 541 | + if PANDAS_READY and isinstance(input_data_list, DataFrame): |
| 542 | + inner_data_list = input_data_list.to_dict('records') |
| 543 | + else: |
| 544 | + inner_data_list = input_data_list |
| 545 | + field_ids = self.sorted_field_ids(objective=objective) |
| 546 | + np_input_list = np.empty(shape=(len(input_data_list), |
| 547 | + len(field_ids))) |
| 548 | + for index, input_data in enumerate(inner_data_list): |
| 549 | + np_input = np.array([]) |
| 550 | + for field_id in field_ids: |
| 551 | + field_input = input_data.get(field_id, |
| 552 | + input_data.get(self.field_name(field_id))) |
| 553 | + field = self.fields[field_id] |
| 554 | + if field["optype"] == CATEGORICAL: |
| 555 | + field_input = one_hot_code(field_input, field) |
| 556 | + np_input = np.append(np_input, field_input) |
| 557 | + np_input_list[index] = np_input |
| 558 | + return np_input_list |
| 559 | + |
| 560 | + def from_numpy(self, np_data_list, objective=False, by_name=True): |
| 561 | + """Transforming input data from numpy syntax. Fields are sorted |
| 562 | + in the dataset order and categorical fields are one-hot encoded.""" |
| 563 | + input_data_list = [] |
| 564 | + field_ids = self.sorted_field_ids(objective=objective) |
| 565 | + for np_data in np_data_list: |
| 566 | + if len(np_data) != len(field_ids): |
| 567 | + raise ValueError("Wrong number of features in data: %s" |
| 568 | + " found, %s expected" % (len(np_data), len(field_ids))) |
| 569 | + input_data = {} |
| 570 | + for index, field_id in enumerate(field_ids): |
| 571 | + field_input = None if np.isnan(np_data[index]) else \ |
| 572 | + np_data[index] |
| 573 | + field = self.fields[field_id] |
| 574 | + if field["optype"] == CATEGORICAL: |
| 575 | + field_input = one_hot_code(field_input, field, decode=True) |
| 576 | + if by_name: |
| 577 | + field_id = self.fields[field_id]["name"] |
| 578 | + input_data.update({field_id: field_input}) |
| 579 | + input_data_list.append(input_data) |
| 580 | + return input_data_list |
| 581 | + |
| 582 | + def one_hot_codes(self, field_name): |
| 583 | + """Returns the codes used for every category in a categorical field""" |
| 584 | + field = self.fields[self.field_id(field_name)] |
| 585 | + if field["optype"] != CATEGORICAL: |
| 586 | + raise ValueError("Only categorical fields are encoded") |
| 587 | + categories = [cat[0] for cat in field["summary"]["categories"]] |
| 588 | + return dict(zip(categories, range(0, len(categories)))) |
| 589 | + |
486 | 590 | def summary_csv(self, filename=None): |
487 | 591 | """Summary of the contents of the fields |
488 | 592 |
|
|
0 commit comments