Skip to content

Commit ef3892e

Browse files
committed
Fix NormalSimulator output shapes
1 parent 3e7cea5 commit ef3892e

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

tests/utils/normal_simulator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ def sample(self, batch_shape: Shape, num_observations: int = 32) -> dict[str, Te
1717
noise = np.random.standard_normal(batch_shape + (num_observations, 2))
1818

1919
x = mean + std * noise
20+
x = x.reshape(x.shape[0], -1)
21+
mean = mean[:, 0]
22+
std = std[:, 0]
2023
mean = mean.astype("float32")
2124
std = std.astype("float32")
2225
x = x.astype("float32")

0 commit comments

Comments
 (0)