Skip to content

Commit a0fe320

Browse files
new synthetic data generator for 4f
1 parent 71e9bb5 commit a0fe320

File tree

1 file changed

+32
-18
lines changed

1 file changed

+32
-18
lines changed

experiments/generate_synthetic_data.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,38 +11,42 @@
1111
import jax.numpy as jnp
1212
from __init__ import um, nm, mm, degrees, radians
1313
from xlumina.wave_optics import *
14+
from wave_optics import *
1415
from xlumina.toolbox import space
16+
import h5py
1517

1618
"""
1719
Synthetic data batches generation: 4f system with magnification 2x.
18-
- input_fields = jnp.array(in1, in2, ...)
19-
- target_fields = jnp.array(out1, out2, ...)
20+
- input_masks = jnp.array(in1, in2, ...)
21+
- target_intensity = jnp.array(out1, out2, ...)
2022
"""
2123

2224
# System characteristics:
2325
sensor_lateral_size = 1024 # Pixel resolution
2426
wavelength = 632.8*nm
2527
x_total = 1500*um
2628
x, y = space(x_total, sensor_lateral_size)
29+
X, Y = jnp.meshgrid(x,y)
2730

2831
# Define the light source:
2932
w0 = (1200*um , 1200*um)
3033
gb = LightSource(x, y, wavelength)
31-
gb.gaussian_beam(w0=w0, E0=1)
34+
gb.plane_wave(A=0.5)
3235

33-
# Data generation functions:
3436
def generate_synthetic_circles(gb, num_samples):
3537
in_circles = []
3638
out_circles = []
3739
for i in range(num_samples):
3840
r1 = jnp.array(np.random.uniform(100, 1000))
3941
r2 = jnp.array(np.random.uniform(100, 1000))
4042

41-
in_circle = gb.apply_circular_mask(r=(r1, r2))
42-
in_circles.append(in_circle.field)
43-
# Magnification is 2x
43+
# Store only the mask (binary)
44+
in_circle = circular_mask(X, Y, r=(r1, r2))
45+
in_circles.append(in_circle)
46+
47+
# Magnification is 2, store only the itensity
4448
out_circle = gb.apply_circular_mask(r=(2*r1, 2*r2))
45-
out_circles.append(out_circle.field)
49+
out_circles.append(jnp.abs(out_circle.field)**2)
4650
return jnp.array(in_circles), jnp.array(out_circles)
4751

4852
def generate_synthetic_squares(gb, num_samples):
@@ -53,11 +57,14 @@ def generate_synthetic_squares(gb, num_samples):
5357
height = jnp.array(np.random.uniform(100, 1000))
5458
angle = jnp.array(np.random.uniform(0, 2*jnp.pi))
5559

56-
in_square = gb.apply_rectangular_mask(center=(0,0), width=width, height=height, angle=angle)
57-
in_squares.append(in_square.field)
58-
# Magnification is 2x
60+
# Binary mask only
61+
in_square = rectangular_mask(X, Y, center=(0,0), width=width, height=height, angle=angle)
62+
in_squares.append(in_square)
63+
64+
# Magnification is 2 - we only need intensity
5965
out_square = gb.apply_rectangular_mask(center=(0,0), width=2*width, height=2*height, angle=-angle)
60-
out_squares.append(out_square.field)
66+
out_squares.append(jnp.abs(out_square.field)**2)
67+
6168
return jnp.array(in_squares), jnp.array(out_squares)
6269

6370
def generate_synthetic_annular(gb, num_samples):
@@ -67,11 +74,14 @@ def generate_synthetic_annular(gb, num_samples):
6774
di = jnp.array(np.random.uniform(100, 500))
6875
do = jnp.array(np.random.uniform(550, 1000))
6976

70-
in_annular = gb.apply_annular_aperture(di=di, do=do)
71-
in_annulars.append(in_annular.field)
72-
# Magnification is 2x
77+
# Binary mask only:
78+
in_annular = annular_aperture(di, do, X, Y)
79+
in_annulars.append(in_annular)
80+
81+
# Magnification is 2 - we only need intensity
7382
out_annular = gb.apply_annular_aperture(di=2*di, do=2*do)
74-
out_annulars.append(out_annular.field)
83+
out_annulars.append(jnp.abs(out_annular.field)**2)
84+
7585
return jnp.array(in_annulars), jnp.array(out_annulars)
7686

7787
# Data generation loop:
@@ -85,5 +95,9 @@ def generate_synthetic_annular(gb, num_samples):
8595
input_fields = jnp.vstack([input_circles, input_squares, input_annular])
8696
target_fields = jnp.vstack([target_circles, target_squares, target_annular])
8797

88-
filename = f"training_data_4f/synthetic_data_{s}.npy"
89-
np.save(filename, {"Input fields": input_fields, "Target fields": target_fields})
98+
filename = f"training_data_4f/synthetic_data_{s}.hdf5"
99+
100+
with h5py.File(filename, 'w') as hdf:
101+
# Create datasets for your data
102+
hdf.create_dataset("Input fields", data=input_fields)
103+
hdf.create_dataset("Target fields", data=target_fields)

0 commit comments

Comments
 (0)