-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathTPN.py
More file actions
29 lines (24 loc) · 893 Bytes
/
TPN.py
File metadata and controls
29 lines (24 loc) · 893 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import tensorflow as tf
def TPN_Model(input_shape, model_name="base_model"):
inputs = tf.keras.Input(shape=input_shape, name='input')
x = inputs
x = tf.keras.layers.Conv1D(
32, 24,
activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(l=1e-4)
)(x)
x = tf.keras.layers.Dropout(0.1)(x)
x = tf.keras.layers.Conv1D(
64, 16,
activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(l=1e-4),
)(x)
x = tf.keras.layers.Dropout(0.1)(x)
x = tf.keras.layers.Conv1D(
96, 8,
activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(l=1e-4),
)(x)
x = tf.keras.layers.Dropout(0.1)(x)
x = tf.keras.layers.GlobalMaxPool1D(data_format='channels_last', name='global_max_pooling1d')(x)
return tf.keras.Model(inputs, x, name=model_name)