Skip to content

Commit 53b0b45

Browse files
committed
Merge remote-tracking branch 'upstream/dev' into sccm-experimental
2 parents 5d60566 + accf8b4 commit 53b0b45

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

tests/utils/normal_simulator.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,13 @@ class NormalSimulator(Simulator):
99

1010
def sample(self, batch_shape: Shape, num_observations: int = 32) -> dict[str, Tensor]:
1111
mean = np.random.normal(0.0, 0.1, size=batch_shape + (2,))
12-
mean = np.repeat(mean[:, None], num_observations, axis=1)
13-
1412
std = np.random.lognormal(0.0, 0.1, size=batch_shape + (2,))
15-
std = np.repeat(std[:, None], num_observations, axis=1)
16-
1713
noise = np.random.standard_normal(batch_shape + (num_observations, 2))
1814

19-
x = mean + std * noise
15+
x = mean[:, None] + std[:, None] * noise
16+
# flatten observations for use without summary network
17+
x = x.reshape(x.shape[0], -1)
18+
2019
mean = mean.astype("float32")
2120
std = std.astype("float32")
2221
x = x.astype("float32")

0 commit comments

Comments
 (0)