20 changes: 10 additions & 10 deletions Train/config_trainer.py
@@ -24,7 +24,7 @@
 from Layers import DistanceWeightedMessagePassing
 from Layers import LLFillSpace
 from Layers import LLExtendedObjectCondensation
-from Layers import DictModel,RaggedDictModel
+from Layers import DictModel
 from Layers import RaggedGlobalExchange
 from Layers import SphereActivation
 from Layers import Multi
@@ -163,7 +163,7 @@ def config_model(Inputs, td, debug_outdir=None, plot_debug_every=2000):
     ### Loop over GravNet Layers ##############################################
     ###########################################################################
 
-    gravnet_regs = [0.01, 0.01, 0.01]
+    gravnet_reg = 0.01
 
     for i in range(GRAVNET_ITERATIONS):
 
@@ -189,14 +189,14 @@ def config_model(Inputs, td, debug_outdir=None, plot_debug_every=2000):
             )([x, rs])
 
         gndist = LLRegulariseGravNetSpace(
-            scale=gravnet_regs[i],
+            scale=gravnet_reg,
             record_metrics=False,
             name=f'regularise_gravnet_{i}')([gndist, prime_coords, gnnidx])
 
-        x_rand = random_sampling_block(
-            xgn, rs, gncoords, gnnidx, gndist, is_track,
-            reduction=6, layer_norm=True, name=f"RSU_{i}")
-        x_rand = ScaledGooeyBatchNorm2(**BATCHNORM_OPTIONS)(x_rand)
+        #x_rand = random_sampling_block(
+        #    xgn, rs, gncoords, gnnidx, gndist, is_track,
+        #    reduction=6, layer_norm=True, name=f"RSU_{i}")
+        #x_rand = ScaledGooeyBatchNorm2(**BATCHNORM_OPTIONS)(x_rand)
 
         gndist = AverageDistanceRegularizer(
             strength=1e-3,
@@ -214,7 +214,7 @@ def config_model(Inputs, td, debug_outdir=None, plot_debug_every=2000):
         # x_rand = ScalarMultiply(0.1)(x_rand)
         # gndist = ScalarMultiply(0.01)(gndist)
         # gncoords = ScalarMultiply(0.01)(gncoords)
-        x = Concatenate()([x_pre, xgn, x_rand, gndist, gncoords])
+        x = Concatenate()([x_pre, xgn, gndist, gncoords])
         x = Dense(d_shape,
                   name=f"dense_post_gravnet_1_iteration_{i}",
                   activation=DENSE_ACTIVATION,
@@ -270,7 +270,7 @@ def config_model(Inputs, td, debug_outdir=None, plot_debug_every=2000):
 
     pred_beta = LLExtendedObjectCondensation(scale=1.,
                                              use_energy_weights=True,
-                                             record_metrics=False,
+                                             record_metrics=True,
                                              print_loss=True,
                                              name="ExtendedOCLoss",
                                              implementation = loss_implementation,
@@ -304,7 +304,7 @@ def config_model(Inputs, td, debug_outdir=None, plot_debug_every=2000):
         # 'no_noise_rs': pre_processed['no_noise_rs'],
     }
 
-    return RaggedDictModel(inputs=Inputs, outputs=model_outputs)
+    return DictModel(inputs=Inputs, outputs=model_outputs)
     #return DictModel(inputs=Inputs, outputs=model_outputs)
 
 
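A side note on the `gravnet_regs` → `gravnet_reg` change: the old per-iteration list hard-coded three entries, silently coupling the list length to `GRAVNET_ITERATIONS`. A minimal sketch (loop body elided, `GRAVNET_ITERATIONS = 4` is a hypothetical value) of why the scalar is the safer default:

```python
GRAVNET_ITERATIONS = 4  # hypothetical value, larger than the old list

# Old pattern: per-iteration scales, but only three entries,
# so i == 3 would raise an IndexError.
gravnet_regs = [0.01, 0.01, 0.01]

# New pattern: one scalar reused in every iteration.
gravnet_reg = 0.01
for i in range(GRAVNET_ITERATIONS):
    scale = gravnet_reg  # stays valid for any iteration count
```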
3 changes: 2 additions & 1 deletion modules/GravNetLayersRagged.py
@@ -3321,8 +3321,9 @@ def priv_call(self, inputs, training=None):
         row_splits = inputs[1]
         tf.assert_equal(x.shape.ndims, 2)
         tf.assert_equal(row_splits.shape.ndims, 1)
+        #print(row_splits, row_splits[-1], tf.shape(x)[0])
         if row_splits.shape[0] is not None:
-            tf.assert_equal(row_splits[-1], x.shape[0])
+            tf.assert_equal(row_splits[-1], tf.shape(x)[0])
 
         x_coord = x
         if len(inputs) == 3:
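The `GravNetLayersRagged.py` fix swaps the static shape `x.shape[0]` (which is `None` when the model is traced with an unknown leading dimension) for the dynamic shape `tf.shape(x)[0]`, which is always defined at run time. A self-contained sketch of the difference, assuming only stock TensorFlow (the function name and dtypes are chosen for the example):

```python
import tensorflow as tf

@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, 4], dtype=tf.float32),  # first dim unknown at trace time
    tf.TensorSpec(shape=[None], dtype=tf.int32),
])
def check(x, row_splits):
    # x.shape[0] is the *static* shape: None inside this traced function,
    # so tf.assert_equal(row_splits[-1], x.shape[0]) cannot even be built.
    # tf.shape(x)[0] is the *dynamic* shape: an int32 tensor known at run time.
    tf.assert_equal(row_splits[-1], tf.shape(x)[0])
    return x

check(tf.zeros([5, 4]), tf.constant([0, 2, 5], dtype=tf.int32))  # passes: 5 == 5
```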
34 changes: 1 addition & 33 deletions modules/Layers.py
@@ -953,36 +953,4 @@ def __init__(self,
 
         super(DictModel, self).__init__(inputs,outputs=outputs, *args, **kwargs)
 
-
-class RaggedDictModel(tf.keras.Model):
-    def __init__(self,
-                 inputs,
-                 outputs: dict, #force to be dict
-                 *args, **kwargs):
-        """
-        Just forces dictionary output
-        """
-
-        super(RaggedDictModel, self).__init__(inputs,outputs=outputs, *args, **kwargs)
-
-    def call(self, inputs, *args, **kwargs):
-        return super(RaggedDictModel, self).call(self.unpack_ragged(inputs), *args, **kwargs)
-
-    def train_step(self, inputs, *args, **kwargs):
-        return super(RaggedDictModel, self).train_step(self.unpack_ragged(inputs), *args, **kwargs)
-        #super(RaggedDictModel, self).train_step(inputs, *args, **kwargs)
-
-    def unpack_ragged(self, inputs):
-        output = []
-        for i in inputs:
-            print("Type of i is", type(i))
-            print("Hasattr", hasattr(i, "row_splits"))
-            if type(i) == tf.RaggedTensor:
-                print("Inside")
-                output.append((i.values, i.row_splts))
-            else:
-                output.append(i)
-
-        return output
-
-global_layers_list['DictModel']=DictModel
+global_layers_list['DictModel']=DictModel
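For reference, the deleted `RaggedDictModel.unpack_ragged` contained a typo (`i.row_splts`) and a non-idiomatic type check (`type(i) == tf.RaggedTensor`). A corrected standalone sketch of the `(values, row_splits)` unpacking it attempted, with the debug prints dropped; this is an illustration, not code from this PR:

```python
import tensorflow as tf

def unpack_ragged(inputs):
    """Expand every tf.RaggedTensor into a (values, row_splits) pair."""
    output = []
    for i in inputs:
        if isinstance(i, tf.RaggedTensor):           # idiomatic type check
            output.append((i.values, i.row_splits))  # correct attribute name
        else:
            output.append(i)
    return output

rt = tf.ragged.constant([[1.0, 2.0], [3.0]])
print(unpack_ragged([rt, tf.constant([0.5])]))
# -> [(values [1. 2. 3.], row_splits [0 2 3]), dense tensor [0.5]]
```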