|
228 | 228 | " delta)\n", |
229 | 229 | " return accumulate, None\n", |
230 | 230 | "\n", |
231 | | - " accumulate = (\n", |
| 231 | + " init_accumulate = (\n", |
232 | 232 | " jnp.zeros(l.shape, jnp.float32), jnp.zeros(l.shape, jnp.float32),\n", |
233 | | - " jnp.zeros(1, jnp.float32), jnp.zeros(1, jnp.float32),\n", |
234 | | - " jnp.zeros(1, jnp.float32), jnp.zeros(1, jnp.float32),\n", |
235 | | - " jnp.zeros(1, jnp.float32), jnp.zeros(1, jnp.float32)\n", |
| 233 | + " jnp.zeros((), jnp.float32), jnp.zeros((), jnp.float32),\n", |
| 234 | + " jnp.zeros((), jnp.float32), jnp.zeros((), jnp.float32),\n", |
| 235 | + " jnp.zeros((), jnp.float32), jnp.zeros((), jnp.float32)\n", |
236 | 236 | " )\n", |
237 | | - " accuulate, _ = jax.lax.scan(\n", |
| 237 | + " accumulate, _ = jax.lax.scan(\n", |
238 | 238 | " accumulate_over_freq,\n", |
239 | | - " accumulate,\n", |
| 239 | + " init_accumulate,\n", |
240 | 240 | " (freqs, jax.random.split(key, len(freqs)))\n", |
241 | 241 | " )\n", |
242 | 242 | "\n", |
|
248 | 248 | " ) = accumulate\n", |
249 | 249 | "\n", |
250 | 250 | " # Compute RMS and image normal stats\n", |
251 | | - " rms_no_noise = jnp.sqrt(jnp.sum((image - zero_point) ** 2))\n", |
| 251 | + " rms_no_noise = jnp.sqrt(jnp.mean((image - zero_point) ** 2))\n", |
252 | 252 | " max_no_noise = jnp.max(image)\n", |
253 | 253 | " min_no_noise = jnp.min(image)\n", |
254 | 254 | " mean_no_noise = jnp.mean(image)\n", |
255 | 255 | " std_no_noise = jnp.std(image)\n", |
256 | 256 | "\n", |
257 | | - " rms_noise = jnp.sqrt(jnp.sum((image_noise - zero_point) ** 2))\n", |
| 257 | + " rms_noise = jnp.sqrt(jnp.mean((image_noise - zero_point) ** 2))\n", |
258 | 258 | " max_noise = jnp.max(image_noise)\n", |
259 | 259 | " min_noise = jnp.min(image_noise)\n", |
260 | 260 | " mean_noise = jnp.mean(image_noise)\n", |
|
0 commit comments