Skip to content

Commit 39f17ee

Browse files
committed
Spelling, inline _forward and comments
1 parent 7955cbd commit 39f17ee

File tree

1 file changed

+4
-9
lines changed

1 file changed

+4
-9
lines changed

bayesflow/networks/point_inference_network.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
6969
# is resolved, a flat version of the heads dictionary is kept.
7070
# This allows to save head weights properly, see for reference
7171
# https://github.com/keras-team/keras/blob/v3.3.3/keras/src/saving/saving_lib.py#L481.
72-
# A nested heads dict is still prefered over this flat dict,
72+
# A nested heads dict is still preferred over this flat dict,
7373
# because it avoids string operation based filtering in `self._forward()`.
7474
flat_key = f"{score_key}___{head_key}"
7575
self.heads_flat[flat_key] = head
@@ -96,17 +96,12 @@ def call(
9696
if xz is None and not self.built:
9797
raise ValueError("Cannot build inference network without inference variables.")
9898
if conditions is None: # unconditional estimation uses a fixed input vector
99-
conditions = keras.ops.convert_to_tensor([[1.0]], dtype="float32")
100-
return self._forward(conditions=conditions, training=training, **kwargs)
99+
conditions = keras.ops.convert_to_tensor([[1.0]], dtype=keras.ops.dtype(xz))
101100

102-
def _forward(
103-
self,
104-
conditions: Tensor = None,
105-
training: bool = False,
106-
**kwargs,
107-
) -> dict[str, Tensor]:
101+
# pass conditions to the shared subnet
108102
output = self.subnet(conditions, training=training)
109103

104+
# pass along to calculate individual head outputs
110105
output = {
111106
score_key: {head_key: head(output, training=training) for head_key, head in self.heads[score_key].items()}
112107
for score_key in self.heads.keys()

0 commit comments

Comments
 (0)