Skip to content

Commit f73d8bd

Browse files
committed
Robustify the Branch to allow branches to be None (using the logical and)
1 parent ea180ba commit f73d8bd

File tree

3 files changed

+18
-58
lines changed

3 files changed

+18
-58
lines changed

coverage-badge.svg

Lines changed: 1 addition & 1 deletion
Loading

synalinks/src/modules/core/branch.py

Lines changed: 16 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -148,38 +148,18 @@ async def call(self, inputs, training=False):
148148
)
149149
choice = decision.get("choice")
150150
outputs = []
151+
if self.inject_decision:
152+
inputs = await ops.concat(
153+
inputs,
154+
decision,
155+
name=self.name + "_inputs_with_decision",
156+
)
151157
for label, module in self.branches.items():
152158
if label == choice:
153159
if module:
154-
if self.inject_decision and self.return_decision:
160+
if self.return_decision:
155161
outputs.append(
156-
await ops.concat(
157-
decision,
158-
await module(
159-
await ops.concat(
160-
inputs,
161-
decision,
162-
name=self.name + "_inputs_with_decision",
163-
),
164-
training=training,
165-
),
166-
name=self.name + "_with_decision",
167-
)
168-
)
169-
elif self.inject_decision and not self.return_decision:
170-
outputs.append(
171-
await module(
172-
await ops.concat(
173-
inputs,
174-
decision,
175-
name=self.name + "_inputs_with_decision",
176-
),
177-
training=training,
178-
)
179-
)
180-
elif not self.inject_decision and self.return_decision:
181-
outputs.append(
182-
await ops.concat(
162+
await ops.logical_and(
183163
decision,
184164
await module(
185165
inputs,
@@ -207,36 +187,16 @@ async def compute_output_spec(self, inputs, training=False):
207187
inputs,
208188
training=training,
209189
)
190+
if self.inject_decision:
191+
inputs = await ops.concat(
192+
inputs,
193+
decision,
194+
name=self.name + "_inputs_with_decision",
195+
)
210196
for module in self.branches.values():
211-
if self.inject_decision and self.return_decision:
212-
outputs.append(
213-
await ops.concat(
214-
decision,
215-
await module(
216-
await ops.concat(
217-
inputs,
218-
decision,
219-
name=self.name + "_inputs_with_decision",
220-
),
221-
training=training,
222-
),
223-
name=self.name + "_with_decision",
224-
)
225-
)
226-
elif self.inject_decision and not self.return_decision:
227-
outputs.append(
228-
await module(
229-
await ops.concat(
230-
inputs,
231-
decision,
232-
name=self.name + "_inputs_with_decision",
233-
),
234-
training=training,
235-
)
236-
)
237-
elif not self.inject_decision and self.return_decision:
197+
if self.return_decision:
238198
outputs.append(
239-
await ops.concat(
199+
await ops.logical_and(
240200
decision,
241201
await module(
242202
inputs,

synalinks/src/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from synalinks.src.api_export import synalinks_export
44

55
# Unique source of truth for the version number.
6-
__version__ = "0.2.023"
6+
__version__ = "0.2.024"
77

88

99
@synalinks_export("synalinks.version")

0 commit comments

Comments
 (0)