Skip to content

Commit aac1eab

Browse files
committed
Updated Data Loader, Model Architecture
1 parent 75b1c58 commit aac1eab

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

utils/data_loader.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,9 @@ def augment_batch(self, image, label) -> tuple:
117117
if tf.random.normal([1]) < 0:
118118
image = tf.image.random_brightness(image, 0.2)
119119
if self.NUM_CHANNELS == 3 and tf.random.normal([1]) < 0:
120-
image = tf.image.random_hue(image, 0.3)
120+
image = tf.image.random_hue(image, 0.1)
121121
if self.NUM_CHANNELS == 3 and tf.random.normal([1]) < 0:
122122
image = tf.image.random_saturation(image, 0, 15)
123-
if tf.random.normal([1]) < 0:
124-
image = tf.image.random_flip_up_down(image)
125123

126124
image = tf.image.random_flip_left_right(image)
127125
image = tf.image.random_jpeg_quality(image, 10, 100)
@@ -197,9 +195,7 @@ def visualize_batch(self, augment=True) -> None:
197195

198196

199197
if __name__ == "__main__":
200-
dataset_root_dir = (
201-
"/home/ani/Documents/pycodes/Dataset/mit-indoor-scenes/indoorCVPR_09/Images"
202-
)
198+
dataset_root_dir = "/home/ani/Documents/pycodes/Dataset/gender/Training/"
203199
data_loader = ImageClassificationDataLoader(
204200
data_dir=dataset_root_dir,
205201
image_dims=(512, 512),

utils/model.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,8 @@ def init_network(self):
180180
backbone=self.backbone, input_shape=self.input_shape, classes=self.classes
181181
)
182182
x = base.output
183+
x = tf.keras.layers.Conv2D(64, (1, 1), activation="relu")(x)
183184
x = tf.keras.layers.Dropout(0.5)(x)
184-
x = tf.keras.layers.AveragePooling2D(pool_size=(5, 5))(x)
185185
x = tf.keras.layers.Flatten()(x)
186186
x = tf.keras.layers.Dense(units=128, activation="relu")(x)
187187
x = tf.keras.layers.Dropout(0.5)(x)
@@ -191,10 +191,10 @@ def init_network(self):
191191
self.model = tf.keras.models.Model(inputs=[base.input], outputs=[x])
192192
return self.model
193193

194-
def get_model_summary(self):
194+
def print_model_summary(self):
195195
if self.model is None:
196196
return f"Please use the `init_network` Method to the model first"
197-
return self.model.summary()
197+
self.model.summary()
198198

199199
def get_model(self):
200200
if self.history is None:
@@ -248,6 +248,7 @@ def train(
248248
val_generator=None,
249249
val_steps=None,
250250
epochs=500,
251+
print_summary=False,
251252
):
252253

253254
if self.metrics is None:
@@ -264,6 +265,10 @@ def train(
264265
optimizer=self.optimizer,
265266
metrics=self.metrics,
266267
)
268+
269+
if print_summary:
270+
self.model.summary()
271+
267272
self.history = self.model.fit(
268273
x=train_generator,
269274
epochs=epochs,

0 commit comments

Comments
 (0)