Skip to content

Commit 430a22c

Browse files
committed
docs: Add detailed comments to model training notebook
Added comprehensive comments explaining the CNN architecture, model compilation, training process, evaluation, and visualization steps to improve code documentation.
1 parent cc491ad commit 430a22c

File tree

1 file changed

+42
-52
lines changed

1 file changed

+42
-52
lines changed

Model/Training.ipynb

Lines changed: 42 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"metadata": {},
1616
"outputs": [],
1717
"source": [
18+
"# Import pickle for loading preprocessed data\n",
1819
"import pickle"
1920
]
2021
},
@@ -25,8 +26,9 @@
2526
"metadata": {},
2627
"outputs": [],
2728
"source": [
28-
"X = pickle.load(open('X.pkl', 'rb'))\n",
29-
"y = pickle.load(open('y.pkl', 'rb'))"
29+
"# Load the preprocessed features and labels from pickle files\n",
30+
"X = pickle.load(open('X.pkl', 'rb')) # Load image features\n",
31+
"y = pickle.load(open('y.pkl', 'rb')) # Load corresponding labels"
3032
]
3133
},
3234
{
@@ -36,6 +38,7 @@
3638
"metadata": {},
3739
"outputs": [],
3840
"source": [
41+
"# Display the features array\n",
3942
"X"
4043
]
4144
},
@@ -46,6 +49,7 @@
4649
"metadata": {},
4750
"outputs": [],
4851
"source": [
52+
"# Display the labels array\n",
4953
"y"
5054
]
5155
},
@@ -56,6 +60,7 @@
5660
"metadata": {},
5761
"outputs": [],
5862
"source": [
63+
"# Normalize pixel values to range [0, 1] for better model training\n",
5964
"X = X/255"
6065
]
6166
},
@@ -66,6 +71,7 @@
6671
"metadata": {},
6772
"outputs": [],
6873
"source": [
74+
"# Display normalized features\n",
6975
"X"
7076
]
7177
},
@@ -76,6 +82,7 @@
7682
"metadata": {},
7783
"outputs": [],
7884
"source": [
85+
"# Check the current shape of the features array\n",
7986
"X.shape"
8087
]
8188
},
@@ -88,6 +95,8 @@
8895
},
8996
"outputs": [],
9097
"source": [
98+
"# Reshape data to include channel dimension for CNN input\n",
99+
"# Shape: (samples, height, width, channels) where 1 = grayscale\n",
91100
"X = X.reshape(-1, 224, 224, 1)"
92101
]
93102
},
@@ -98,6 +107,7 @@
98107
"metadata": {},
99108
"outputs": [],
100109
"source": [
110+
"# Verify the new shape after reshaping\n",
101111
"X.shape"
102112
]
103113
},
@@ -108,8 +118,9 @@
108118
"metadata": {},
109119
"outputs": [],
110120
"source": [
111-
"from keras.models import Sequential\n",
112-
"from keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout"
121+
"# Import Keras libraries for building the CNN model\n",
122+
"from keras.models import Sequential # For creating sequential model\n",
123+
"from keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout # Layer types"
113124
]
114125
},
115126
{
@@ -119,7 +130,7 @@
119130
"metadata": {},
120131
"outputs": [],
121132
"source": [
122-
"# Initialising the CNN\n",
133+
"# Initialize the Convolutional Neural Network\n",
123134
"model = Sequential()"
124135
]
125136
},
@@ -130,27 +141,38 @@
130141
"metadata": {},
131142
"outputs": [],
132143
"source": [
133-
"# Step 1 - Convolution\n",
144+
"# Step 1 - First Convolutional Layer\n",
145+
"# 64 filters of size 3x3, ReLU activation function\n",
134146
"model.add(Conv2D(64,(3,3), activation='relu'))\n",
135-
"# Step 2 - Pooling\n",
147+
"\n",
148+
"# Step 2 - First Pooling Layer\n",
149+
"# 2x2 max pooling to reduce spatial dimensions\n",
136150
"model.add(MaxPooling2D((2,2)))\n",
137151
"\n",
138-
"# Adding a second convolutional layer\n",
152+
"# Adding a second convolutional block\n",
153+
"# Another 64 filters of 3x3 for deeper feature extraction\n",
139154
"model.add(Conv2D(64,(3,3), activation='relu'))\n",
140155
"model.add(MaxPooling2D((2,2)))\n",
141156
"\n",
142-
"# Adding a third convolutional layer\n",
157+
"# Adding a third convolutional block\n",
158+
"# Final convolutional layer for complex pattern recognition\n",
143159
"model.add(Conv2D(64,(3,3), activation='relu'))\n",
144160
"model.add(MaxPooling2D((2,2)))\n",
161+
"\n",
162+
"# Dropout layer to prevent overfitting by randomly setting 40% of input units to zero during training\n",
145163
"model.add(Dropout(0.4))\n",
146164
"\n",
147165
"# Step 3 - Flattening\n",
166+
"# Convert 2D feature maps to 1D feature vector\n",
148167
"model.add(Flatten())\n",
149168
"\n",
150-
"# Step 4 - Full Connection\n",
169+
"# Step 4 - Full Connection (Hidden Layer)\n",
170+
"# Dense layer with 128 neurons for learning complex patterns\n",
151171
"model.add(Dense(128, input_shape = X.shape[1:], activation = 'relu'))\n",
152172
"\n",
153173
"# Step 5 - Output Layer\n",
174+
"# 3 neurons for 3 classes (Mammootty, Mohanlal, Unknown)\n",
175+
"# Softmax activation for probability distribution across classes\n",
154176
"model.add(Dense(3, activation= 'softmax'))"
155177
]
156178
},
@@ -160,10 +182,7 @@
160182
"id": "2f239525",
161183
"metadata": {},
162184
"outputs": [],
163-
"source": [
164-
"# Compiling the CNN\n",
165-
"model.compile(optimizer = 'adam', loss='sparse_categorical_crossentropy', metrics = ['accuracy'])"
166-
]
185+
"source": "# Compile the CNN model\n# Adam optimizer: adaptive learning rate optimization algorithm\n# sparse_categorical_crossentropy: loss function for multi-class classification with integer labels\n# accuracy: metric to monitor during training\nmodel.compile(optimizer = 'adam', loss='sparse_categorical_crossentropy', metrics = ['accuracy'])"
167186
},
168187
{
169188
"cell_type": "code",
@@ -173,92 +192,63 @@
173192
"scrolled": true
174193
},
175194
"outputs": [],
176-
"source": [
177-
"hist = model.fit(X, y, epochs=10, validation_split=0.2)"
178-
]
195+
"source": "# Train the model\n# epochs=10: train for 10 complete passes through the dataset\n# validation_split=0.2: use 20% of data for validation during training\nhist = model.fit(X, y, epochs=10, validation_split=0.2)"
179196
},
180197
{
181198
"cell_type": "code",
182199
"execution_count": null,
183200
"id": "cc34347a",
184201
"metadata": {},
185202
"outputs": [],
186-
"source": [
187-
"model.summary()"
188-
]
203+
"source": "# Display model architecture summary\n# Shows layers, output shapes, and trainable parameters\nmodel.summary()"
189204
},
190205
{
191206
"cell_type": "code",
192207
"execution_count": null,
193208
"id": "1f65bfca",
194209
"metadata": {},
195210
"outputs": [],
196-
"source": [
197-
"#to know accuracy of model\n",
198-
"scores = model.evaluate(X,y,verbose=0)\n",
199-
"print(\"Accuracy: %.2f%%\" % (scores[1]*100))"
200-
]
211+
"source": "# Evaluate model accuracy on the training data\nscores = model.evaluate(X,y,verbose=0)\nprint(\"Accuracy: %.2f%%\" % (scores[1]*100))"
201212
},
202213
{
203214
"cell_type": "code",
204215
"execution_count": null,
205216
"id": "582c9ac3",
206217
"metadata": {},
207218
"outputs": [],
208-
"source": [
209-
"X.shape"
210-
]
219+
"source": "# Verify the shape of features array\nX.shape"
211220
},
212221
{
213222
"cell_type": "code",
214223
"execution_count": null,
215224
"id": "466b56d3",
216225
"metadata": {},
217226
"outputs": [],
218-
"source": [
219-
"import matplotlib.pyplot as plt"
220-
]
227+
"source": "# Import matplotlib for visualizing training metrics\nimport matplotlib.pyplot as plt"
221228
},
222229
{
223230
"cell_type": "code",
224231
"execution_count": null,
225232
"id": "fb0ebddb",
226233
"metadata": {},
227234
"outputs": [],
228-
"source": [
229-
"fig = plt.figure()\n",
230-
"plt.plot(hist.history['loss'],color='teal',label='loss')\n",
231-
"plt.plot(hist.history['val_loss'],color='orange',label='val_loss')\n",
232-
"plt.suptitle('Loss',fontsize=20)\n",
233-
"plt.legend(loc=\"upper left\")\n",
234-
"plt.show"
235-
]
235+
"source": "# Plot training and validation loss over epochs\nfig = plt.figure()\nplt.plot(hist.history['loss'],color='teal',label='loss') # Training loss\nplt.plot(hist.history['val_loss'],color='orange',label='val_loss') # Validation loss\nplt.suptitle('Loss',fontsize=20)\nplt.legend(loc=\"upper left\")\nplt.show"
236236
},
237237
{
238238
"cell_type": "code",
239239
"execution_count": null,
240240
"id": "36208bdd",
241241
"metadata": {},
242242
"outputs": [],
243-
"source": [
244-
"fig = plt.figure()\n",
245-
"plt.plot(hist.history['accuracy'],color='teal',label='accuracy')\n",
246-
"plt.plot(hist.history['val_accuracy'],color='orange',label='val_accuracy')\n",
247-
"plt.suptitle('Accuracy',fontsize=20)\n",
248-
"plt.legend(loc=\"upper left\")\n",
249-
"plt.show"
250-
]
243+
"source": "# Plot training and validation accuracy over epochs\nfig = plt.figure()\nplt.plot(hist.history['accuracy'],color='teal',label='accuracy') # Training accuracy\nplt.plot(hist.history['val_accuracy'],color='orange',label='val_accuracy') # Validation accuracy\nplt.suptitle('Accuracy',fontsize=20)\nplt.legend(loc=\"upper left\")\nplt.show"
251244
},
252245
{
253246
"cell_type": "code",
254247
"execution_count": null,
255248
"id": "9418fe74",
256249
"metadata": {},
257250
"outputs": [],
258-
"source": [
259-
"# save the model\n",
260-
"model.save('3-class-improved.h5')"
261-
]
251+
"source": "# Save the trained model to disk\n# Saves model architecture, weights, and optimizer state\nmodel.save('3-class-improved.h5')"
262252
}
263253
],
264254
"metadata": {
@@ -282,4 +272,4 @@
282272
},
283273
"nbformat": 4,
284274
"nbformat_minor": 5
285-
}
275+
}

0 commit comments

Comments
 (0)