Skip to content

Commit 8bb0a3c

Browse files
committed
Fix trainer auto build when program is unbuilt
1 parent f73d8bd commit 8bb0a3c

File tree

5 files changed

+31
-20
lines changed

5 files changed

+31
-20
lines changed

coverage-badge.svg

Lines changed: 1 addition & 1 deletion
Loading

synalinks/src/backend/common/stateless_scope.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def __init__(
4747

4848
self.collect_rewards = collect_rewards
4949
self.initialize_variables = initialize_variables
50-
self.evals = []
50+
self.rewards = []
5151
self.state_mapping = {}
5252
state_mapping = state_mapping or {}
5353
for k, v in state_mapping:
@@ -77,7 +77,7 @@ def __enter__(self):
7777
return self
7878

7979
def add_eval(self, eval):
80-
self.evals.append(eval)
80+
self.rewards.append(eval)
8181

8282
def add_update(self, update):
8383
variable, value = update

synalinks/src/modules/module.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,19 @@ def input_spec(self):
191191
def input_spec(self, value):
192192
self._input_spec = value
193193

194+
@classmethod
195+
def get_arity(cls):
196+
# Inspect the call method to get the number of parameters
197+
sig = inspect.signature(cls.call)
198+
# Exclude 'self' and 'training' from the parameter count
199+
return len(
200+
[
201+
param
202+
for param in sig.parameters.values()
203+
if param.name not in ["self", "training"]
204+
]
205+
)
206+
194207
@python_utils.default
195208
async def build(self, input_schema):
196209
self._check_super_called()

synalinks/src/trainers/trainer.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from synalinks.src import callbacks as callbacks_module
1313
from synalinks.src import metrics as metrics_module
1414
from synalinks.src import optimizers as optimizers_module
15-
from synalinks.src import tree
1615
from synalinks.src.backend.common import numpy
1716
from synalinks.src.saving import serialization_lib
1817
from synalinks.src.trainers.compile_utils import CompileMetrics
@@ -474,7 +473,10 @@ async def fit(
474473
# Build the model on one batch of data.
475474
for _, data in epoch_iterator:
476475
data_batch = data[0]
477-
self._symbolic_build(data_batch)
476+
self._auto_build(
477+
iterator=epoch_iterator,
478+
data_batch=data_batch,
479+
)
478480
break
479481
epoch_iterator.reset()
480482

@@ -630,7 +632,10 @@ async def evaluate(
630632
# Build the model on one batch of data.
631633
for _, data in epoch_iterator:
632634
data_batch = data[0]
633-
self._symbolic_build(data_batch)
635+
self._auto_build(
636+
iterator=epoch_iterator,
637+
data_batch=data_batch,
638+
)
634639
break
635640
epoch_iterator.reset()
636641

@@ -965,7 +970,7 @@ def _assert_compile_called(self, method_name=None):
965970
msg += f"calling `{method_name}()`."
966971
raise ValueError(msg)
967972

968-
def _symbolic_build(self, iterator=None, data_batch=None):
973+
def _auto_build(self, iterator=None, data_batch=None):
969974
program_unbuilt = not all(module.built for module in self._flatten_modules())
970975
compile_metrics_unbuilt = (
971976
self._compile_metrics is not None and not self._compile_metrics.built
@@ -975,26 +980,17 @@ def _symbolic_build(self, iterator=None, data_batch=None):
975980
)
976981
optimizer_unbuilt = self.optimizer is not None and not self.optimizer.built
977982
if program_unbuilt or compile_metrics_unbuilt or compile_reward_unbuilt:
978-
# Create symbolic data_models matching an input batch.
979-
980-
def to_symbolic_input(v):
981-
if v is None:
982-
return None
983-
return backend.SymbolicDataModel(schema=v.get_schema())
984-
985983
if data_batch is None:
986984
for _, data_or_iterator in iterator:
987985
if isinstance(data_or_iterator, (list, tuple)):
988986
data_batch = data_or_iterator[0]
989987
else:
990988
data_batch = next(data_or_iterator)
991989
break
992-
data_batch = tree.map_structure(to_symbolic_input, data_batch)
993990
(x, y) = data_batch
994-
# Build all program state with `backend.compute_output_spec`.
995991
try:
996992
y_pred = asyncio.get_event_loop().run_until_complete(
997-
backend.compute_output_spec(self, x, training=False)
993+
self.predict_on_batch(x, training=False)
998994
)
999995
except Exception as e:
1000996
raise RuntimeError(
@@ -1022,7 +1018,7 @@ def to_symbolic_input(v):
10221018
# Build `CompileReward` state with `backend.compute_output_spec`.
10231019
asyncio.get_event_loop().run_until_complete(
10241020
backend.compute_output_spec(
1025-
self._compute_reward,
1021+
self.compute_reward,
10261022
x,
10271023
y,
10281024
y_pred,
@@ -1031,7 +1027,9 @@ def to_symbolic_input(v):
10311027
)
10321028
if optimizer_unbuilt:
10331029
# Build optimizer
1034-
self.optimizer.build(self.trainable_variables)
1030+
asyncio.get_event_loop().run_until_complete(
1031+
self.optimizer.build(self.trainable_variables)
1032+
)
10351033
self._post_build()
10361034

10371035
def _assert_compile_called(self, method_name=None):

synalinks/src/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from synalinks.src.api_export import synalinks_export
44

55
# Unique source of truth for the version number.
6-
__version__ = "0.2.024"
6+
__version__ = "0.2.025"
77

88

99
@synalinks_export("synalinks.version")

0 commit comments

Comments
 (0)