Skip to content

Commit ef1f47c

Browse files
committed
fix input shape
1 parent 90aa74d commit ef1f47c

File tree

3 files changed

+15
-12
lines changed

3 files changed

+15
-12
lines changed

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,12 @@ def build(self, xz_shape, conditions_shape=None):
168168

169169
input_shape = list(xz_shape)
170170

171-
# time vector
172-
input_shape[-1] += 1
171+
if self._concatenate_subnet_input:
172+
# time vector
173+
input_shape[-1] += 1
173174

174-
if conditions_shape is not None:
175-
input_shape[-1] += conditions_shape[-1]
175+
if conditions_shape is not None:
176+
input_shape[-1] += conditions_shape[-1]
176177

177178
input_shape = tuple(input_shape)
178179

bayesflow/networks/diffusion_model/diffusion_model.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,11 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
135135
self.output_projector.units = xz_shape[-1]
136136
input_shape = list(xz_shape)
137137

138-
# construct time vector
139-
input_shape[-1] += 1
140-
if conditions_shape is not None:
141-
input_shape[-1] += conditions_shape[-1]
138+
if self._concatenate_subnet_input:
139+
# construct time vector
140+
input_shape[-1] += 1
141+
if conditions_shape is not None:
142+
input_shape[-1] += conditions_shape[-1]
142143

143144
input_shape = tuple(input_shape)
144145

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,12 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
128128

129129
self.output_projector.units = xz_shape[-1]
130130

131-
# account for concatenating the time and conditions
132131
input_shape = list(xz_shape)
133-
input_shape[-1] += 1
134-
if conditions_shape is not None:
135-
input_shape[-1] += conditions_shape[-1]
132+
if self._concatenate_subnet_input:
133+
# account for concatenating the time and conditions
134+
input_shape[-1] += 1
135+
if conditions_shape is not None:
136+
input_shape[-1] += conditions_shape[-1]
136137
input_shape = tuple(input_shape)
137138

138139
self.subnet.build(input_shape)

0 commit comments

Comments
 (0)