Skip to content

Commit e892e31

Browse files
tests for wave optics module
1 parent 8e0f969 commit e892e31

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

tests/test_wave_optics.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Test for wave optics module
2+
import os
3+
import sys
4+
5+
# Setting the path for XLuminA modules:
6+
current_path = os.path.abspath(os.path.join('..'))
7+
module_path = os.path.join(current_path)
8+
9+
if module_path not in sys.path:
10+
sys.path.append(module_path)
11+
12+
import unittest
13+
import jax.numpy as jnp
14+
import numpy as np
15+
from xlumina.wave_optics import ScalarLight, LightSource
16+
17+
class TestWaveOptics(unittest.TestCase):
18+
def setUp(self):
19+
self.wavelength = 633e-3 #nm
20+
self.resolution = 1024
21+
self.x = np.linspace(-1500, 1500, self.resolution)
22+
self.y = np.linspace(-1500, 1500, self.resolution)
23+
self.k = 2 * jnp.pi / self.wavelength
24+
25+
def test_scalar_light(self):
26+
light = ScalarLight(self.x, self.y, self.wavelength)
27+
self.assertEqual(light.wavelength, self.wavelength)
28+
self.assertEqual(light.k, self.k)
29+
self.assertEqual(light.field.shape, (self.resolution, self.resolution))
30+
31+
def test_light_source_gb(self):
32+
source = LightSource(self.x, self.y, self.wavelength)
33+
source.gaussian_beam(w0=(1200, 1200), E0=1)
34+
self.assertEqual(source.wavelength, self.wavelength)
35+
self.assertEqual(source.field.shape, (self.resolution, self.resolution))
36+
self.assertGreater(jnp.sum(jnp.abs(source.field)**2), 0)
37+
38+
def test_light_source_pw(self):
39+
source = LightSource(self.x, self.y, self.wavelength)
40+
source.plane_wave(A=1, theta=0, phi=0, z0=0)
41+
self.assertGreater(jnp.sum(jnp.abs(source.field)**2), 0)
42+
43+
def test_rs_propagation(self):
44+
light = LightSource(self.x, self.y, self.wavelength)
45+
light.gaussian_beam(w0=(1200, 1200), E0=1)
46+
propagated, _ = light.RS_propagation(z=1000)
47+
self.assertEqual(propagated.field.shape, (self.resolution, self.resolution))
48+
49+
def test_czt(self):
50+
light = LightSource(self.x, self.y, self.wavelength)
51+
light.gaussian_beam(w0=(1200, 1200), E0=1)
52+
propagated = light.CZT(z=1000)
53+
self.assertEqual(propagated.field.shape, (self.resolution, self.resolution))
54+
55+
if __name__ == '__main__':
56+
unittest.main()

0 commit comments

Comments
 (0)