-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathkeras_profile.py
More file actions
27 lines (20 loc) · 865 Bytes
/
keras_profile.py
File metadata and controls
27 lines (20 loc) · 865 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import tensorflow as tf
import keras.backend as K
from models.firenet_tf import firenet_tf
def get_flops(model):
    """Return the floating-point operation count reported by the TF1 profiler.

    Args:
        model: A compiled Keras model built in graph mode.

    Returns:
        ``total_float_ops`` from ``tf.profiler.profile`` run with the
        ``float_operation()`` option preset.

    NOTE: the profiler counts every float op in the profiled graph, not only
    the ops reachable from ``model``. If several models (or other ops) were
    built in the same graph, their ops are counted too. Call
    ``K.clear_session()`` before building the model to get a per-model count.
    """
    run_meta = tf.RunMetadata()
    opts = tf.profiler.ProfileOptionBuilder.float_operation()
    # Profile the graph that the model's own output tensor lives in, so the
    # count refers to `model` even if it was not built in the Keras session's
    # graph. Fall back to the session graph if the tensor exposes no `graph`.
    graph = getattr(model.outputs[0], "graph", None)
    if graph is None:
        graph = K.get_session().graph
    flops = tf.profiler.profile(graph=graph, run_meta=run_meta,
                                cmd='op', options=opts)
    # The profiler may also print its own report; only the total is returned.
    return flops.total_float_ops
def _build_model():
    """Reset the Keras session, then build and compile a fresh firenet_tf model."""
    # Without clearing the session, each new model is added to the same
    # backend graph. get_flops() profiles the whole graph, so it would then
    # count every model built so far, not just this one.
    K.clear_session()
    model = firenet_tf(input_shape=(64, 64, 3))
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])
    return model


# The original script profiled the model twice. With a fresh session per build,
# both counts should now match, which doubles as a quick sanity check.
for _ in range(2):
    print(get_flops(_build_model()))