Skip to content

Commit cdf8b1f

Browse files
Override gnn_model_base trainable_variables
This patch provides a custom implementation of trainable_variables in gnn_model_base. Theoretically this should have been made unnecessary by \#337, but the interanl version of Tensorflow refuses to recurse into the modules inside of the GraphNetworkLayer classes. This patch fixes that by just returning the values regardless. Reviewers: orodley, virajbshah, ondrasej Reviewed By: ondrasej Pull Request: #341
1 parent 0e93371 commit cdf8b1f

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

gematria/granite/python/gnn_model_base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,14 @@ def __init__(
210210
self._graph_module_residual_connections = graph_module_residual_connections
211211
self._graph_module_layer_normalization = graph_module_layer_normalization
212212

213+
@property
214+
def trainable_variables(self):
215+
trainable_vars = set(var.ref() for var in super().trainable_variables)
216+
for layer in self._graph_network:
217+
layer_vars = [var.ref() for var in layer.module.trainable_variables]
218+
trainable_vars.update(layer_vars)
219+
return tuple(var.deref() for var in trainable_vars)
220+
213221
def initialize(self):
214222
super().initialize()
215223
self._graph_network = self._create_graph_network_modules()

0 commit comments

Comments
 (0)