
Commit 08d53bc

Fixing and Updating Adaboost algorithm
1 parent 6233abc commit 08d53bc

File tree

1 file changed (+73, -17 lines)


machine_learning/adaboost.py

Lines changed: 73 additions & 17 deletions
@@ -12,54 +12,83 @@
 array([0, 1])
 """

-import numpy as np
 from typing import Any

+import numpy as np
+

 class AdaBoost:
     def __init__(self, n_estimators: int = 50) -> None:
-        """Initialize AdaBoost classifier.
+        """
+        Initialize AdaBoost classifier.
+
         Args:
-            n_estimators: Number of boosting rounds.
+            n_estimators: Number of boosting rounds (weak learners).
         """
         self.n_estimators: int = n_estimators
-        self.alphas: list[float] = []  # Weights for each weak learner
-        self.models: list[dict[str, Any]] = []  # List of weak learners (stumps)
+        self.alphas: list[float] = []  # Weights assigned to each weak learner
+        self.models: list[dict[str, Any]] = []  # Stores each decision stump

     def fit(self, feature_matrix: np.ndarray, target: np.ndarray) -> None:
-        """Fit AdaBoost model.
+        """
+        Train AdaBoost model using decision stumps.
+
         Args:
-            feature_matrix: (n_samples, n_features) feature matrix
-            target: (n_samples,) labels (0 or 1)
+            feature_matrix: 2D array of shape (n_samples, n_features)
+            target: 1D array of binary labels (0 or 1)
         """
         n_samples, _ = feature_matrix.shape
+
+        # Initialize uniform sample weights
         sample_weights = np.ones(n_samples) / n_samples
+
+        # Reset model state
         self.models = []
         self.alphas = []
+
+        # Convert labels to {-1, 1} for boosting
         y_signed = np.where(target == 0, -1, 1)
+
         for _ in range(self.n_estimators):
+            # Train a weighted decision stump
             stump = self._build_stump(feature_matrix, y_signed, sample_weights)
             pred = stump["pred"]
             err = stump["error"]
+
+            # Compute alpha (learner weight) with numerical stability
             alpha = 0.5 * np.log((1 - err) / (err + 1e-10))
+
+            # Update sample weights to focus on misclassified points
             sample_weights *= np.exp(-alpha * y_signed * pred)
             sample_weights /= np.sum(sample_weights)
+
+            # Store the stump and its weight
             self.models.append(stump)
             self.alphas.append(alpha)

     def predict(self, feature_matrix: np.ndarray) -> np.ndarray:
-        """Predict class labels for samples in feature_matrix.
+        """
+        Predict binary class labels for input samples.
+
         Args:
-            feature_matrix: (n_samples, n_features) feature matrix
+            feature_matrix: 2D array of shape (n_samples, n_features)
+
         Returns:
-            (n_samples,) predicted labels (0 or 1)
+            1D array of predicted labels (0 or 1)
         """
         clf_preds = np.zeros(feature_matrix.shape[0])
+
+        # Aggregate predictions from all stumps
         for alpha, stump in zip(self.alphas, self.models):
             pred = self._stump_predict(
-                feature_matrix, stump["feature"], stump["threshold"], stump["polarity"]
+                feature_matrix,
+                stump["feature"],
+                stump["threshold"],
+                stump["polarity"],
             )
             clf_preds += alpha * pred
+
+        # Final prediction: sign of weighted sum
         return np.where(clf_preds >= 0, 1, 0)

     def _build_stump(
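
Aside (illustrative, not part of the commit): the new fit body implements the standard AdaBoost update, alpha = 0.5 * ln((1 - err) / err) for the learner weight and w_i <- w_i * exp(-alpha * y_i * h(x_i)) followed by renormalization for the sample weights. A small numeric check of one boosting round, reusing the same expressions on hypothetical values:

import numpy as np

# Hypothetical round on 4 samples where the stump misclassifies only the last one.
sample_weights = np.ones(4) / 4                     # uniform start: 0.25 each
y_signed = np.array([1, -1, 1, -1])                 # labels mapped from {0, 1} to {-1, 1}
pred = np.array([1, -1, 1, 1])                      # stump output; last sample is wrong

err = np.sum(sample_weights * (pred != y_signed))   # weighted error = 0.25
alpha = 0.5 * np.log((1 - err) / (err + 1e-10))     # learner weight ~= 0.549

sample_weights *= np.exp(-alpha * y_signed * pred)  # up-weight the mistake
sample_weights /= np.sum(sample_weights)            # renormalize to sum to 1
print(sample_weights)                               # misclassified sample now carries weight 0.5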
@@ -68,16 +97,30 @@ def _build_stump(
         target_signed: np.ndarray,
         sample_weights: np.ndarray,
     ) -> dict[str, Any]:
-        """Find the best decision stump for current weights."""
+        """
+        Build the best decision stump for current sample weights.
+
+        Returns:
+            Dictionary containing stump parameters and predictions.
+        """
         _, n_features = feature_matrix.shape
         min_error = float("inf")
         best_stump: dict[str, Any] = {}
+
+        # Iterate over all features and thresholds
         for feature in range(n_features):
             thresholds = np.unique(feature_matrix[:, feature])
             for threshold in thresholds:
                 for polarity in [1, -1]:
-                    pred = self._stump_predict(feature_matrix, feature, threshold, polarity)
+                    pred = self._stump_predict(
+                        feature_matrix,
+                        feature,
+                        threshold,
+                        polarity,
+                    )
                     error = np.sum(sample_weights * (pred != target_signed))
+
+                    # Keep stump with lowest weighted error
                     if error < min_error:
                         min_error = error
                         best_stump = {
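
For intuition (again illustrative, not from the diff): the stump search tries every unique value of every feature as a threshold, with both polarities, and keeps the candidate whose weighted error is smallest. A standalone sketch of that search on a hypothetical single feature column:

import numpy as np

column = np.array([1.0, 2.0, 3.0, 4.0])  # hypothetical feature values
y_signed = np.array([-1, -1, 1, 1])      # signed labels
weights = np.ones(4) / 4                 # current sample weights

best = (float("inf"), None, None)
for threshold in np.unique(column):
    for polarity in (1, -1):
        pred = np.ones(4)
        if polarity == 1:
            pred[column < threshold] = -1
        else:
            pred[column > threshold] = -1
        error = np.sum(weights * (pred != y_signed))
        if error < best[0]:
            best = (error, threshold, polarity)

print(best)  # zero weighted error at threshold 3.0 with polarity 1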
@@ -87,15 +130,28 @@ def _build_stump(
                             "error": error,
                             "pred": pred.copy(),
                         }
+
         return best_stump

     def _stump_predict(
-        self, feature_matrix: np.ndarray, feature: int, threshold: float, polarity: int
+        self,
+        feature_matrix: np.ndarray,
+        feature: int,
+        threshold: float,
+        polarity: int,
     ) -> np.ndarray:
-        """Predict using a single decision stump."""
+        """
+        Predict using a single decision stump.
+
+        Returns:
+            1D array of predictions in {-1, 1}
+        """
         pred = np.ones(feature_matrix.shape[0])
+
+        # Apply polarity to threshold comparison
         if polarity == 1:
             pred[feature_matrix[:, feature] < threshold] = -1
         else:
             pred[feature_matrix[:, feature] > threshold] = -1
-        return pred
+
+        return pred
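
A minimal usage sketch of the updated class, assuming the repository layout shown above (machine_learning/adaboost.py importable from the project root); the expected output mirrors the array([0, 1]) doctest at the top of the file:

import numpy as np

from machine_learning.adaboost import AdaBoost  # path as listed in the file tree

# Tiny separable toy set; fit() expects 0/1 labels and maps them to {-1, 1} internally.
feature_matrix = np.array([[1.0], [2.0], [8.0], [9.0]])
target = np.array([0, 0, 1, 1])

model = AdaBoost(n_estimators=5)
model.fit(feature_matrix, target)
print(model.predict(np.array([[1.5], [8.5]])))  # expected: [0 1]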

0 commit comments
