|
8 | 8 | from .toolbox import profile |
9 | 9 | from .wave_optics import build_grid, RS_propagation_jit, build_CZT_grid, CZT_jit, CZT_for_high_NA_jit |
10 | 10 |
|
11 | | -# Comment this line if float32 is enough precision for you. |
12 | | -config.update("jax_enable_x64", True) |
| 11 | +# Set this to False if f64 is enough precision for you. |
| 12 | +enable_float64 = True |
| 13 | +d_type = jnp.complex64 |
| 14 | +if enable_float64: |
| 15 | + config.update("jax_enable_x64", True) |
| 16 | + d_type = jnp.complex128 |
13 | 17 |
|
14 | 18 | """ |
15 | 19 | Module for vectorized optical fields: |
@@ -39,9 +43,9 @@ def __init__(self, x=None, y=None, wavelength=None): |
39 | 43 | self.k = 2 * jnp.pi / wavelength |
40 | 44 | self.n = 1 |
41 | 45 | shape = (jnp.shape(x)[0], jnp.shape(y)[0]) |
42 | | - self.Ex = jnp.zeros(shape, dtype=jnp.complex128) |
43 | | - self.Ey = jnp.zeros(shape, dtype=jnp.complex128) |
44 | | - self.Ez = jnp.zeros(shape, dtype=jnp.complex128) |
| 46 | + self.Ex = jnp.zeros(shape, dtype=d_type) |
| 47 | + self.Ey = jnp.zeros(shape, dtype=d_type) |
| 48 | + self.Ez = jnp.zeros(shape, dtype=d_type) |
45 | 49 | self.info = 'Vectorized light' |
46 | 50 |
|
47 | 51 | def draw(self, xlim='', ylim='', kind='', extra_title='', save_file=False, filename=''): |
@@ -276,7 +280,7 @@ def VRS_propagation(self, z): |
276 | 280 | light_out.Ey = E_out[:, :, 1] |
277 | 281 | light_out.Ez = E_out[:, :, 2] |
278 | 282 |
|
279 | | - print("Time taken to perform one VRS propagation (in seconds):", time.perf_counter() - tic) |
| 283 | + print(f"Time taken to perform one VRS propagation (in seconds): {(time.perf_counter() - tic):.4f}") |
280 | 284 | return light_out, quality_factor |
281 | 285 |
|
282 | 286 | def get_VRS_minimum_z(self, n=1, quality_factor=1): |
@@ -353,7 +357,7 @@ def VCZT(self, z, xout, yout): |
353 | 357 | light_out.Ey = E_out[:, :, 1] |
354 | 358 | light_out.Ez = E_out[:, :, 2] |
355 | 359 |
|
356 | | - print("Time taken to perform one VCZT propagation (in seconds):", time.perf_counter() - tic) |
| 360 | + print(f"Time taken to perform one VCZT propagation (in seconds): {(time.perf_counter() - tic):.4f}") |
357 | 361 | return light_out |
358 | 362 |
|
359 | 363 |
|
|
0 commit comments