Skip to content

Commit fc4cca8

Browse files
improved readability
1 parent da7331f commit fc4cca8

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

xlumina/vectorized_optics.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@
88
from .toolbox import profile
99
from .wave_optics import build_grid, RS_propagation_jit, build_CZT_grid, CZT_jit, CZT_for_high_NA_jit
1010

11-
# Comment this line if float32 is enough precision for you.
12-
config.update("jax_enable_x64", True)
11+
# Set this to False if f64 is enough precision for you.
12+
enable_float64 = True
13+
d_type = jnp.complex64
14+
if enable_float64:
15+
config.update("jax_enable_x64", True)
16+
d_type = jnp.complex128
1317

1418
"""
1519
Module for vectorized optical fields:
@@ -39,9 +43,9 @@ def __init__(self, x=None, y=None, wavelength=None):
3943
self.k = 2 * jnp.pi / wavelength
4044
self.n = 1
4145
shape = (jnp.shape(x)[0], jnp.shape(y)[0])
42-
self.Ex = jnp.zeros(shape, dtype=jnp.complex128)
43-
self.Ey = jnp.zeros(shape, dtype=jnp.complex128)
44-
self.Ez = jnp.zeros(shape, dtype=jnp.complex128)
46+
self.Ex = jnp.zeros(shape, dtype=d_type)
47+
self.Ey = jnp.zeros(shape, dtype=d_type)
48+
self.Ez = jnp.zeros(shape, dtype=d_type)
4549
self.info = 'Vectorized light'
4650

4751
def draw(self, xlim='', ylim='', kind='', extra_title='', save_file=False, filename=''):
@@ -276,7 +280,7 @@ def VRS_propagation(self, z):
276280
light_out.Ey = E_out[:, :, 1]
277281
light_out.Ez = E_out[:, :, 2]
278282

279-
print("Time taken to perform one VRS propagation (in seconds):", time.perf_counter() - tic)
283+
print(f"Time taken to perform one VRS propagation (in seconds): {(time.perf_counter() - tic):.4f}")
280284
return light_out, quality_factor
281285

282286
def get_VRS_minimum_z(self, n=1, quality_factor=1):
@@ -353,7 +357,7 @@ def VCZT(self, z, xout, yout):
353357
light_out.Ey = E_out[:, :, 1]
354358
light_out.Ez = E_out[:, :, 2]
355359

356-
print("Time taken to perform one VCZT propagation (in seconds):", time.perf_counter() - tic)
360+
print(f"Time taken to perform one VCZT propagation (in seconds): {(time.perf_counter() - tic):.4f}")
357361
return light_out
358362

359363

xlumina/wave_optics.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77

88
from .toolbox import rotate_mask
99

10-
# Comment this line if float32 is enough precision for you.
11-
config.update("jax_enable_x64", True)
10+
# Set this to False if f64 is enough precision for you.
11+
enable_float64 = True
12+
if enable_float64:
13+
config.update("jax_enable_x64", True)
1214

1315
"""
1416
Module for scalar optical fields.
@@ -190,7 +192,7 @@ def RS_propagation(self, z):
190192

191193
propagated_light = ScalarLight(self.x, self.y, self.wavelength)
192194
propagated_light.field = RS_propagation_jit(self.field, z, nx, ny, dx, dy, Xext, Yext, self.k)
193-
print("Time taken to perform one RS propagation (in seconds):", time.perf_counter() - tic)
195+
print(f"Time taken to perform one RS propagation (in seconds): {(time.perf_counter() - tic):.4f}")
194196
return propagated_light, quality_factor
195197

196198
def get_RS_minimum_z(self, n=1, quality_factor=1):
@@ -257,7 +259,7 @@ def CZT(self, z, xout=None, yout=None):
257259
# Build ScalarLight object with output field.
258260
field_out = ScalarLight(xout, yout, self.wavelength)
259261
field_out.field = field_at_z
260-
print("Time taken to perform one CZT propagation (in seconds):", time.perf_counter() - tic)
262+
print(f"Time taken to perform one CZT propagation (in seconds): {(time.perf_counter() - tic):.4f}")
261263
return field_out
262264

263265
def build_grid(x, y):

0 commit comments

Comments
 (0)