Skip to content

Commit 00a17fd

Browse files
Update adaptive_resonance_theory.py
1 parent acfe9f0 commit 00a17fd

File tree

1 file changed

+49
-4
lines changed

1 file changed

+49
-4
lines changed

neural_network/adaptive_resonance_theory.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,34 @@
1+
"""
2+
adaptive_resonance_theory.py
3+
4+
This module implements the Adaptive Resonance Theory 1 (ART1) model, a type
5+
of neural network designed for unsupervised learning and clustering of binary
6+
input data. The ART1 algorithm continuously learns to categorize inputs based
7+
on their similarity while preserving previously learned categories. This is
8+
achieved through a vigilance parameter that controls the strictness of
9+
category matching, allowing for flexible and adaptive clustering.
10+
11+
ART1 is particularly useful in applications where it is critical to learn new
12+
patterns without forgetting previously learned ones, making it suitable for
13+
real-time data clustering and pattern recognition tasks.
14+
15+
References:
16+
1. Carpenter, G. A., & Grossberg, S. (1987). "A Adaptive Resonance Theory."
17+
In: Neural Networks for Pattern Recognition, Oxford University Press,
18+
pp. 194–203.
19+
2. Carpenter, G. A., & Grossberg, S. (1988). "The ART of Adaptive Pattern
20+
Recognition by a Self-Organizing Neural Network." IEEE Transactions on
21+
Neural Networks, 1(2), 115-130. DOI: 10.1109/TNN.1988.82656
22+
23+
"""
24+
125
import numpy as np
226

327

428
class ART1:
529
"""
630
Adaptive Resonance Theory 1 (ART1) model for binary data clustering.
731
8-
...
9-
1032
Attributes:
1133
num_features (int): Number of features in the input data.
1234
vigilance (float): Threshold for similarity that determines whether
@@ -50,7 +72,7 @@ def _similarity(self, weight_vector: np.ndarray, input_vector: np.ndarray) -> fl
5072
or len(input_vector) != self.num_features
5173
):
5274
raise ValueError(
53-
"Both weight_vector and input_vector must have certain number."
75+
"Both weight_vector and input_vector must have the same number of features."
5476
)
5577

5678
return np.dot(weight_vector, input_vector) / self.num_features
@@ -78,6 +100,29 @@ def _learn(
78100
"""
79101
return learning_rate * x + (1 - learning_rate) * w
80102

103+
def train(self, data: np.ndarray) -> None:
104+
"""
105+
Train the ART1 model on the provided data.
106+
107+
Args:
108+
data (np.ndarray): Input data for training.
109+
110+
Returns:
111+
None
112+
"""
113+
for x in data:
114+
# Predict the cluster for the input data point
115+
cluster_index = self.predict(x)
116+
117+
if cluster_index == -1: # No existing cluster matches
118+
# Create a new cluster with the current input
119+
self.weights.append(x)
120+
else:
121+
# Update the existing cluster's weights
122+
self.weights[cluster_index] = self._learn(
123+
self.weights[cluster_index], x
124+
)
125+
81126
def predict(self, x: np.ndarray) -> int:
82127
"""
83128
Assign data to the closest cluster.
@@ -107,7 +152,7 @@ def art1_example() -> None:
107152
"""
108153
Example function demonstrating the usage of the ART1 model.
109154
110-
This function creates dataset, trains ART1 model, and prints assigned clusters.
155+
This function creates a dataset, trains the ART1 model, and prints assigned clusters.
111156
112157
Examples:
113158
>>> art1_example()

0 commit comments

Comments
 (0)