4 | 4 | from matplotlib import pyplot as plt |
5 | 5 | from abc import ABCMeta, abstractmethod |
6 | 6 |
| 7 | +from sklearn.utils.extmath import softmax |
7 | 8 | from sklearn.preprocessing import LabelBinarizer |
8 | 9 | from sklearn.utils import check_X_y, column_or_1d |
9 | 10 | from sklearn.utils.validation import check_is_fitted |
@@ -475,41 +476,53 @@ def decision_path_indice(self, x, node_id): |
475 | 476 |
476 | 477 | def decision_path(self, x): |
477 | 478 |
| 479 | + check_is_fitted(self, "tree") |
| 480 | + |
478 | 481 | n_samples = x.shape[0] |
479 | 482 | path_all = np.zeros((n_samples, self.node_count)) |
480 | | - for idx, row in enumerate(x): |
| 483 | + for node_id in self.leaf_idx_list: |
481 | 484 | path = [] |
482 | | - node = self.tree[1] |
483 | | - while not node['is_leaf']: |
484 | | - path.append(node["node_id"] - 1) |
485 | | - if row[node['feature']] <= node['threshold']: |
486 | | - node = self.tree[node['left_child_id']] |
| 485 | + idx = node_id |
| 486 | + sample_indice = np.ones((x.shape[0], )).astype(bool) |
| 487 | + while True: |
| 488 | + path.append(idx - 1) |
| 489 | + current_node = self.tree[idx] |
| 490 | + if current_node["parent_id"] is None: |
| 491 | + break |
487 | 492 | else: |
488 | | - node = self.tree[node['right_child_id']] |
489 | | - path.append(node["node_id"] - 1) |
490 | | - path_all[idx][path] = 1 |
| 493 | + parent_node = self.tree[current_node["parent_id"]] |
| 494 | + if current_node["is_left"]: |
| 495 | + sample_indice = np.logical_and(sample_indice, x[:, parent_node["feature"]] <= parent_node["threshold"]) |
| 496 | + else: |
| 497 | + sample_indice = np.logical_and(sample_indice, x[:, parent_node["feature"]] > parent_node["threshold"]) |
| 498 | + idx = current_node["parent_id"] |
| 499 | + if sample_indice.sum() > 0: |
| 500 | + path_all[np.ix_(np.where(sample_indice)[0], path)] = 1 |
491 | 501 | return path_all |
492 | 502 |
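The new decision_path builds the (n_samples, node_count) indicator matrix leaf by leaf: for every leaf in self.leaf_idx_list it climbs to the root through parent_id, AND-ing each parent's split condition into a boolean sample mask, then marks the collected path columns for the matching rows with np.ix_. A minimal standalone sketch of that leaf-upward traversal on a hand-built one-split tree; the tree dict, leaf_idx_list and node_count below are illustrative stand-ins for the estimator's attributes, not the project's API:

import numpy as np

# Hypothetical stump: root (id 1) splits feature 0 at 0.5; ids 2 and 3 are its left/right leaves.
# The node keys mirror the dict fields used in the diff (parent_id, is_left, feature, threshold).
tree = {
    1: {"parent_id": None, "is_left": None, "feature": 0, "threshold": 0.5, "is_leaf": False},
    2: {"parent_id": 1, "is_left": True, "is_leaf": True},
    3: {"parent_id": 1, "is_left": False, "is_leaf": True},
}
leaf_idx_list = [2, 3]
node_count = 3

x = np.array([[0.2], [0.9]])
path_all = np.zeros((x.shape[0], node_count))
for node_id in leaf_idx_list:
    path, idx = [], node_id
    sample_indice = np.ones(x.shape[0], dtype=bool)
    while True:
        path.append(idx - 1)                      # node ids are 1-based, matrix columns 0-based
        current_node = tree[idx]
        if current_node["parent_id"] is None:     # reached the root
            break
        parent_node = tree[current_node["parent_id"]]
        if current_node["is_left"]:               # this branch holds samples with feature <= threshold
            sample_indice &= x[:, parent_node["feature"]] <= parent_node["threshold"]
        else:                                     # and the other one samples with feature > threshold
            sample_indice &= x[:, parent_node["feature"]] > parent_node["threshold"]
        idx = current_node["parent_id"]
    if sample_indice.sum() > 0:
        path_all[np.ix_(np.where(sample_indice)[0], path)] = 1

print(path_all)   # [[1. 1. 0.]   x=0.2 visits root and left leaf
                  #  [1. 0. 1.]]  x=0.9 visits root and right leaf

Compared with the removed per-sample loop, the Python-level iteration now scales with the number of leaves times the tree depth instead of the number of samples times the depth; the per-sample work is done by vectorized comparisons.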
493 | 503 | def decision_function(self, x): |
494 | 504 |
495 | 505 | check_is_fitted(self, "tree") |
496 | 506 |
497 | | - leaf_idx = [] |
498 | 507 | x = np.array(x) |
499 | | - for row in x: |
500 | | - node = self.tree[1] |
501 | | - while not node['is_leaf']: |
502 | | - if row[node['feature']] <= node['threshold']: |
503 | | - node = self.tree[node['left_child_id']] |
504 | | - else: |
505 | | - node = self.tree[node['right_child_id']] |
506 | | - leaf_idx.append(node['node_id']) |
507 | | - |
508 | 508 | n_samples = x.shape[0] |
509 | 509 | pred = np.zeros((n_samples)) |
510 | | - for node_id in np.unique(leaf_idx): |
511 | | - sample_indice = np.array(leaf_idx) == node_id |
512 | | - pred[sample_indice] = self.tree[node_id]['predict_func'](x[sample_indice, :]).ravel() |
| 510 | + for node_id in self.leaf_idx_list: |
| 511 | + idx = node_id |
| 512 | + sample_indice = np.ones((x.shape[0], )).astype(bool) |
| 513 | + while True: |
| 514 | + current_node = self.tree[idx] |
| 515 | + if current_node["parent_id"] is None: |
| 516 | + break |
| 517 | + else: |
| 518 | + parent_node = self.tree[current_node["parent_id"]] |
| 519 | + if current_node["is_left"]: |
| 520 | + sample_indice = np.logical_and(sample_indice, x[:, parent_node["feature"]] <= parent_node["threshold"]) |
| 521 | + else: |
| 522 | + sample_indice = np.logical_and(sample_indice, x[:, parent_node["feature"]] > parent_node["threshold"]) |
| 523 | + idx = current_node["parent_id"] |
| 524 | + if sample_indice.sum() > 0: |
| 525 | + pred[sample_indice] = self.tree[node_id]['predict_func'](x[sample_indice, :]).ravel() |
513 | 526 | return pred |
514 | 527 |
515 | 528 |
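decision_function reuses the same climb-to-root masking, but instead of filling path columns it hands each leaf's samples to that leaf's predict_func and writes the returned scores into pred. A hedged sketch with two hypothetical constant-output leaf models; the lambda predict_funcs are placeholders for whatever per-leaf estimator the tree actually stores:

import numpy as np

tree = {
    1: {"parent_id": None, "feature": 0, "threshold": 0.5, "is_leaf": False},
    2: {"parent_id": 1, "is_left": True, "is_leaf": True,
        "predict_func": lambda x: np.full((x.shape[0], 1), -1.3)},
    3: {"parent_id": 1, "is_left": False, "is_leaf": True,
        "predict_func": lambda x: np.full((x.shape[0], 1), 0.8)},
}
leaf_idx_list = [2, 3]

x = np.array([[0.2], [0.9], [0.4]])
pred = np.zeros(x.shape[0])
for node_id in leaf_idx_list:
    idx = node_id
    sample_indice = np.ones(x.shape[0], dtype=bool)
    while tree[idx]["parent_id"] is not None:     # climb to the root, AND-ing split conditions
        parent = tree[tree[idx]["parent_id"]]
        if tree[idx]["is_left"]:
            sample_indice &= x[:, parent["feature"]] <= parent["threshold"]
        else:
            sample_indice &= x[:, parent["feature"]] > parent["threshold"]
        idx = tree[idx]["parent_id"]
    if sample_indice.sum() > 0:
        # score only the samples routed to this leaf with the leaf's own model
        pred[sample_indice] = tree[node_id]["predict_func"](x[sample_indice, :]).ravel()

print(pred)   # [-1.3  0.8 -1.3]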
@@ -647,9 +660,10 @@ def evaluate_estimator(self, estimator, x, y): |
647 | 660 | return loss |
648 | 661 |
649 | 662 | def predict_proba(self, x): |
650 | | - proba = self.decision_function(x).reshape(-1, 1) |
651 | | - return np.hstack([1 - proba, proba]) |
| 663 | + pred = self.decision_function(x).reshape(-1, 1) |
| 664 | + pred_proba = softmax(np.hstack([-pred, pred]) / 2, copy=False) |
| 665 | + return pred_proba |
652 | 666 |
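Because the two stacked columns are just -pred and pred, softmax(np.hstack([-pred, pred]) / 2) collapses to the logistic sigmoid of the decision value: the positive-class column equals exp(d/2) / (exp(-d/2) + exp(d/2)) = 1 / (1 + exp(-d)), which, unlike the previous hstack([1 - proba, proba]) of raw scores, is guaranteed to lie in [0, 1]. A quick standalone check of that equivalence:

import numpy as np
from sklearn.utils.extmath import softmax

d = np.array([-2.0, 0.0, 3.5]).reshape(-1, 1)                      # example decision_function outputs
via_softmax = softmax(np.hstack([-d, d]) / 2, copy=False)[:, 1]    # column 1 is the positive class
via_sigmoid = 1.0 / (1.0 + np.exp(-d.ravel()))                     # logistic sigmoid of the same scores
print(np.allclose(via_softmax, via_sigmoid))                       # True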
653 | 667 | def predict(self, x): |
654 | | - pred_proba = self.decision_function(x) |
| 668 | + pred_proba = self.predict_proba(x) |
655 | 669 | return self._label_binarizer.inverse_transform(pred_proba) |