Skip to content
29 changes: 16 additions & 13 deletions ctlearn/default_models/single_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ def single_cnn_model(data, model_params):

# Load neural network model
network_input_img = tf.keras.Input(shape=data.img_shape, name=f"images")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if we want to support MLP only without the CNNs here or in a separate file. There is no need for an additional file I guess. @nietootein What are your thoughts about this?
I ran a quick check. It is possible to request only the parameter list without the images from dl1dh. @sahilyadav27 So I guess you can just check here if images are selected or not with data.img_shape != None

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought having a separate MLP model would make sense because we might want to add the MLP model to other models as well like CNN_RNN and ResNet. So it would be better to have a common MLP model independent of the image model rather than having it inside the Single CNN model?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should work now

network_input_param = tf.keras.Input(shape=data.prm_shape, name=f"parameters")
flag_prm = 0
if data.prm_shape != None:
flag_prm = 1
network_input_param = tf.keras.Input(shape=data.prm_shape, name=f"parameters")
backbone_name = model_params.get("name", "CNN") + "_block"
trainable_backbone = model_params.get("trainable_backbone", True)
pretrained_weights = model_params.get("pretrained_weights", None)
Expand All @@ -21,10 +24,12 @@ def single_cnn_model(data, model_params):
else:
sys.path.append(model_params["model_directory"])
engine_cnn_module = importlib.import_module(model_params["engine_cnn"]["module"])
engine_mlp_module = importlib.import_module(model_params["engine_mlp"]["module"])
engine_cnn = getattr(engine_cnn_module, model_params["engine_cnn"]["function"])
engine_mlp = getattr(engine_mlp_module, model_params["engine_mlp"]["function"])

if flag_prm == 1:
engine_mlp_module = importlib.import_module(model_params["engine_mlp"]["module"])
engine_mlp = getattr(engine_mlp_module, model_params["engine_mlp"]["function"])
engine_output_mlp = engine_mlp(network_input_param, params=model_params, name=backbone_name)
output_mlp = tf.keras.layers.Flatten()(engine_output_mlp)
# The original ResNet implementation use this padding, but we pad the images in the ImageMapper.
# x = tf.pad(telescope_data, tf.constant([[3, 3], [3, 3]]), name='conv1_pad')
init_layer = model_params.get("init_layer", False)
Expand All @@ -45,15 +50,13 @@ def single_cnn_model(data, model_params):
)(network_input_img)

engine_output_cnn = engine_cnn(network_input_img, params=model_params, name=backbone_name)
engine_output_mlp = engine_mlp(network_input_param, params=model_params, name=backbone_name)

output_cnn = tf.keras.layers.GlobalAveragePooling2D(
name=backbone_name + "_global_avgpool"
)(engine_output_cnn)
output_mlp = tf.keras.layers.Flatten()(engine_output_mlp)
concat = tf.keras.layers.Concatenate()([output_cnn, output_mlp])

singlecnn_model = tf.keras.Model(inputs=[network_input_img, network_input_param], outputs = [concat], name=backbone_name)


return singlecnn_model, [network_input_img, network_input_param]
if flag_prm == 1:
concat = tf.keras.layers.Concatenate()([output_cnn, output_mlp])
singlecnn_model = tf.keras.Model(inputs=[network_input_img, network_input_param], outputs = [concat], name=backbone_name)
return singlecnn_model, [network_input_img, network_input_param]
else:
singlecnn_model = tf.keras.Model(network_input_img, output_cnn, name=backbone_name)
return singlecnn_model, [network_input_img]