Skip to content

Commit 93402f6

Browse files
committed
Update requirements and training scripts
1 parent 736e300 commit 93402f6

File tree

4 files changed

+19
-4
lines changed

4 files changed

+19
-4
lines changed

examples/grnet/task.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,20 @@ def train_and_evaluate(args):
5959
)
6060

6161
grnet = model.create_model(args.learning_rate, num_classes=AFEW_CLASSES)
62+
63+
os.makedirs(args.job_dir, exist_ok=True)
6264
checkpoint_path = os.path.join(args.job_dir, "afew-grnet.ckpt")
6365
cp_callback = tf.keras.callbacks.ModelCheckpoint(
6466
filepath=checkpoint_path, save_weights_only=True, verbose=1
6567
)
68+
log_dir = os.path.join(args.job_dir, "logs")
69+
tb_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
70+
6671
grnet.fit(
6772
train_dataset,
6873
epochs=args.num_epochs,
6974
validation_data=val_dataset,
70-
callbacks=[cp_callback],
75+
callbacks=[cp_callback, tb_callback],
7176
)
7277
_, acc = grnet.evaluate(val_dataset, verbose=2)
7378
print("Final accuracy: {}%".format(acc * 100))

examples/lienet/task.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,20 @@ def train_and_evaluate(args):
7474
)
7575

7676
lienet = model.create_model(args.learning_rate, num_classes=G3D_CLASSES)
77+
78+
os.makedirs(args.job_dir, exist_ok=True)
7779
checkpoint_path = os.path.join(args.job_dir, "g3d-lienet.ckpt")
7880
cp_callback = tf.keras.callbacks.ModelCheckpoint(
7981
filepath=checkpoint_path, save_weights_only=True, verbose=1
8082
)
83+
log_dir = os.path.join(args.job_dir, "logs")
84+
tb_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
85+
8186
lienet.fit(
8287
train_dataset,
8388
epochs=args.num_epochs,
8489
validation_data=val_dataset,
85-
callbacks=[cp_callback],
90+
callbacks=[cp_callback, tb_callback],
8691
)
8792
_, acc = lienet.evaluate(val_dataset, verbose=2)
8893
print("Final accuracy: {}%".format(acc * 100))

examples/spdnet/task.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,20 @@ def train_and_evaluate(args):
5959
)
6060

6161
spdnet = model.create_model(args.learning_rate, num_classes=AFEW_CLASSES)
62+
63+
os.makedirs(args.job_dir, exist_ok=True)
6264
checkpoint_path = os.path.join(args.job_dir, "afew-spdnet.ckpt")
6365
cp_callback = tf.keras.callbacks.ModelCheckpoint(
6466
filepath=checkpoint_path, save_weights_only=True, verbose=1
6567
)
68+
log_dir = os.path.join(args.job_dir, "logs")
69+
tb_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
70+
6671
spdnet.fit(
6772
train_dataset,
6873
epochs=args.num_epochs,
6974
validation_data=val_dataset,
70-
callbacks=[cp_callback],
75+
callbacks=[cp_callback, tb_callback],
7176
)
7277
_, acc = spdnet.evaluate(val_dataset, verbose=2)
7378
print("Final accuracy: {}%".format(acc * 100))

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
tensorflow>=2.0.0
1+
tensorflow>=2.3.0

0 commit comments

Comments
 (0)