Skip to content

Commit 7538cd7

Browse files
committed
Fix CNN Algo
1 parent 8c95027 commit 7538cd7

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

machine_learning/cnn.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@
1212
"""
1313

1414
import numpy as np
15-
from typing import Tuple
1615

1716

1817
class SimpleCNN:
19-
def __init__(self, input_shape: Tuple[int, int, int], num_classes: int) -> None:
18+
def __init__(self, input_shape: tuple[int, int, int], num_classes: int) -> None:
2019
"""
2120
Initialize a simple CNN model.
2221
@@ -26,42 +25,43 @@ def __init__(self, input_shape: Tuple[int, int, int], num_classes: int) -> None:
2625
"""
2726
self.input_shape = input_shape
2827
self.num_classes = num_classes
29-
self.filters = np.random.randn(8, input_shape[0], 3, 3) * 0.1 # 8 filters
30-
self.fc_weights = np.random.randn(8 * 26 * 26, num_classes) * 0.1
28+
rng = np.random.default_rng()
29+
self.filters = rng.normal(0, 0.1, size=(8, input_shape[0], 3, 3)) # 8 filters
30+
self.fc_weights = rng.normal(0, 0.1, size=(8 * 26 * 26, num_classes))
3131

32-
def relu(self, x: np.ndarray) -> np.ndarray:
33-
"""Apply ReLU activation."""
34-
return np.maximum(0, x)
32+
def relu(self, feature_map: np.ndarray) -> np.ndarray:
33+
"""Apply ReLU activation to the feature map."""
34+
return np.maximum(0, feature_map)
3535

36-
def convolve(self, x: np.ndarray, filters: np.ndarray) -> np.ndarray:
37-
"""Apply convolution operation."""
38-
batch, height, width = x.shape
36+
def convolve(self, input_tensor: np.ndarray, filters: np.ndarray) -> np.ndarray:
37+
"""Apply convolution operation to the input tensor."""
38+
_, height, width = input_tensor.shape
3939
num_filters, _, fh, fw = filters.shape
4040
output = np.zeros((num_filters, height - fh + 1, width - fw + 1))
4141

4242
for f in range(num_filters):
4343
for i in range(height - fh + 1):
4444
for j in range(width - fw + 1):
45-
region = x[:, i:i + fh, j:j + fw]
45+
region = input_tensor[:, i:i + fh, j:j + fw]
4646
output[f, i, j] = np.sum(region * filters[f])
4747
return output
4848

49-
def flatten(self, x: np.ndarray) -> np.ndarray:
50-
"""Flatten the feature map."""
51-
return x.reshape(-1)
49+
def flatten(self, feature_map: np.ndarray) -> np.ndarray:
50+
"""Flatten the feature map into a 1D array."""
51+
return feature_map.reshape(-1)
5252

53-
def forward(self, x: np.ndarray) -> np.ndarray:
53+
def forward(self, input_tensor: np.ndarray) -> np.ndarray:
5454
"""
5555
Forward pass through the CNN.
5656
5757
Args:
58-
x: Input image of shape (channels, height, width)
58+
input_tensor: Input image of shape (channels, height, width)
5959
6060
Returns:
6161
Output logits of shape (num_classes,)
6262
"""
63-
conv_out = self.convolve(x, self.filters)
63+
conv_out = self.convolve(input_tensor, self.filters)
6464
activated = self.relu(conv_out)
6565
flattened = self.flatten(activated)
6666
logits = flattened @ self.fc_weights
67-
return logits
67+
return logits

0 commit comments

Comments
 (0)