Commit e1a90cf

[π˜€π—½π—Ώ] changes to main this commit is based on (#337)
We cannot make the dataclass frozen when doing this, but it enables TF to automatically find trainable variables within GraphNetworkLayer objects, which means we can get rid of the somewhat hacky _get_trainable_variables function that subclasses were supposed to override. This allows models to train successfully that would otherwise fail to converge because none of the graph layers' variables were registered as trainable. This closes #323.
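For context, a minimal standalone sketch (hypothetical names, not code from this repository) of the mechanism the change relies on: tf.Module automatically tracks tf.Variable objects assigned to instance attributes, including attributes set by a dataclass-generated __init__, and exposes them through its trainable_variables property. A frozen dataclass does not work here because it forbids attribute assignment after construction (raising FrozenInstanceError), which conflicts with tf.Module, whose tracking machinery itself needs to set attributes on the instance.

import dataclasses

import tensorflow as tf


# A tf.Module subclass declared as a (non-frozen) dataclass. The generated
# __init__ assigns `scale` via normal attribute assignment, so tf.Module
# picks the variable up without any manual bookkeeping.
@dataclasses.dataclass
class ScaledLayer(tf.Module):
  scale: tf.Variable


layer = ScaledLayer(scale=tf.Variable(2.0, name='scale'))
print(layer.trainable_variables)  # (<tf.Variable 'scale:0' ... numpy=2.0>,)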
1 parent 9ca566a commit e1a90cf

2 files changed (+3, -13 lines)

gematria/granite/python/gnn_model_base.py

Lines changed: 2 additions & 8 deletions
@@ -30,8 +30,8 @@
 import tf_keras
 
 
-@dataclasses.dataclass(frozen=True)
-class GraphNetworkLayer:
+@dataclasses.dataclass
+class GraphNetworkLayer(tf.Module):
   """Specifies one segment of the pipeline of the graph network.
 
   Each segment consists of a graph network module, i.e. a Sonnet module that
@@ -290,12 +290,6 @@ def initialize(self):
           tf_keras.layers.LayerNormalization(name=globals_layer_norm_name)
       )
 
-  def _get_trainable_variables(self):
-    trainable_variables = list(super()._get_trainable_variables())
-    for layer in self._graph_network:
-      trainable_variables.extend(layer.module.trainable_variables)
-    return trainable_variables
-
   # @Override
   def _forward(self, feed_dict):
     graph_tuple_outputs = self._execute_graph_network(feed_dict)

gematria/model/python/model_base.py

Lines changed: 1 addition & 5 deletions
@@ -1331,9 +1331,6 @@ def compute_loss_tensor(self, schedule: FeedDict):
         )
     )
 
-  def _get_trainable_variables(self):
-    return self.trainable_variables
-
   def train_batch(
       self,
       schedule: FeedDict,
@@ -1367,11 +1364,10 @@ def train_batch(
         for variable in self._variable_groups.get(variable_group)
     )
 
-    trainable_variables = self._get_trainable_variables()
     variables = (
         [variable.deref() for variable in requested_variables]
         if requested_variables
-        else trainable_variables
+        else self.trainable_variables
     )
 
     grads = tape.gradient(loss_tensor, variables)
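The removal above is safe because tf.Module flattens nested containers when collecting variables: a model that stores GraphNetworkLayer instances in a list exposes the variables of their wrapped modules through trainable_variables with no manual collection. A minimal sketch of this behavior (hypothetical Inner and Model names, not Gematria code), assuming GraphNetworkLayer is the tf.Module dataclass from the first file:

import dataclasses

import tensorflow as tf


class Inner(tf.Module):
  # Stands in for a Sonnet graph network module that owns variables.
  def __init__(self):
    super().__init__()
    self.w = tf.Variable(tf.zeros([4]), name='w')


@dataclasses.dataclass
class GraphNetworkLayer(tf.Module):
  module: tf.Module


class Model(tf.Module):
  def __init__(self):
    super().__init__()
    # A plain list of layers; tf.Module recurses into lists and nested
    # modules when collecting trainable_variables.
    self._graph_network = [GraphNetworkLayer(module=Inner())]


model = Model()
print([v.name for v in model.trainable_variables])  # ['w:0']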
