In the example for time series classification with a transformer, the function build_model() is defined as:
```python
def build_model(
    input_shape,
    head_size,
    num_heads,
    ff_dim,
    num_transformer_blocks,
    mlp_units,
    dropout=0,
    mlp_dropout=0,
):
    inputs = keras.Input(shape=input_shape)
    x = inputs
    for _ in range(num_transformer_blocks):
        x = transformer_encoder(x, head_size, num_heads, ff_dim, dropout)
    x = layers.GlobalAveragePooling1D(data_format="channels_first")(x)  # <- This line appears to be wrong.
    for dim in mlp_units:
        x = layers.Dense(dim, activation="relu")(x)
        x = layers.Dropout(mlp_dropout)(x)
    outputs = layers.Dense(n_classes, activation="softmax")(x)
    return keras.Model(inputs, outputs)
```
The pooling layer is initialised as `x = layers.GlobalAveragePooling1D(data_format="channels_first")(x)`. Shouldn't the data format be `channels_last`, since the encoder output has shape `(batch, timesteps, features)`?
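For context, the two settings reduce different axes. With `channels_last`, Keras reads the input as `(batch, steps, features)` and averages over the steps axis; with `channels_first`, it reads the input as `(batch, features, steps)` and averages over the last axis. A minimal NumPy sketch of the resulting shapes (the toy tensor shape is an assumption for illustration):

```python
import numpy as np

# Toy tensor shaped like the encoder output: (batch=2, timesteps=4, features=3)
x = np.arange(2 * 4 * 3, dtype=float).reshape(2, 4, 3)

# channels_last: input read as (batch, steps, features);
# averaging over the steps axis gives (batch, features)
pooled_last = x.mean(axis=1)   # shape (2, 3)

# channels_first: input read as (batch, features, steps);
# averaging over the last axis here collapses the feature axis instead,
# giving (batch, timesteps)
pooled_first = x.mean(axis=2)  # shape (2, 4)

print(pooled_last.shape, pooled_first.shape)
```

So on a `(batch, timesteps, features)` tensor, `channels_first` averages across the feature dimension and hands the MLP one value per timestep, which changes what the model actually learns from.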