1111import jax .numpy as jnp
1212from __init__ import um , nm , mm , degrees , radians
1313from xlumina .wave_optics import *
14+ from wave_optics import *
1415from xlumina .toolbox import space
16+ import h5py
1517
1618"""
1719Synthetic data batches generation: 4f system with magnification 2x.
18- - input_fields = jnp.array(in1, in2, ...)
19- - target_fields = jnp.array(out1, out2, ...)
20+ - input_masks = jnp.array(in1, in2, ...)
21+ - target_intensity = jnp.array(out1, out2, ...)
2022"""
2123
2224# System characteristics:
2325sensor_lateral_size = 1024 # Pixel resolution
2426wavelength = 632.8 * nm
2527x_total = 1500 * um
2628x , y = space (x_total , sensor_lateral_size )
29+ X , Y = jnp .meshgrid (x ,y )
2730
2831# Define the light source:
2932w0 = (1200 * um , 1200 * um )
3033gb = LightSource (x , y , wavelength )
31- gb .gaussian_beam ( w0 = w0 , E0 = 1 )
34+ gb .plane_wave ( A = 0.5 )
3235
33- # Data generation functions:
3436def generate_synthetic_circles (gb , num_samples ):
3537 in_circles = []
3638 out_circles = []
3739 for i in range (num_samples ):
3840 r1 = jnp .array (np .random .uniform (100 , 1000 ))
3941 r2 = jnp .array (np .random .uniform (100 , 1000 ))
4042
41- in_circle = gb .apply_circular_mask (r = (r1 , r2 ))
42- in_circles .append (in_circle .field )
43- # Magnification is 2x
43+ # Store only the mask (binary)
44+ in_circle = circular_mask (X , Y , r = (r1 , r2 ))
45+ in_circles .append (in_circle )
46+
47+ # Magnification is 2, store only the itensity
4448 out_circle = gb .apply_circular_mask (r = (2 * r1 , 2 * r2 ))
45- out_circles .append (out_circle .field )
49+ out_circles .append (jnp . abs ( out_circle .field ) ** 2 )
4650 return jnp .array (in_circles ), jnp .array (out_circles )
4751
4852def generate_synthetic_squares (gb , num_samples ):
@@ -53,11 +57,14 @@ def generate_synthetic_squares(gb, num_samples):
5357 height = jnp .array (np .random .uniform (100 , 1000 ))
5458 angle = jnp .array (np .random .uniform (0 , 2 * jnp .pi ))
5559
56- in_square = gb .apply_rectangular_mask (center = (0 ,0 ), width = width , height = height , angle = angle )
57- in_squares .append (in_square .field )
58- # Magnification is 2x
60+ # Binary mask only
61+ in_square = rectangular_mask (X , Y , center = (0 ,0 ), width = width , height = height , angle = angle )
62+ in_squares .append (in_square )
63+
64+ # Magnification is 2 - we only need intensity
5965 out_square = gb .apply_rectangular_mask (center = (0 ,0 ), width = 2 * width , height = 2 * height , angle = - angle )
60- out_squares .append (out_square .field )
66+ out_squares .append (jnp .abs (out_square .field )** 2 )
67+
6168 return jnp .array (in_squares ), jnp .array (out_squares )
6269
6370def generate_synthetic_annular (gb , num_samples ):
@@ -67,11 +74,14 @@ def generate_synthetic_annular(gb, num_samples):
6774 di = jnp .array (np .random .uniform (100 , 500 ))
6875 do = jnp .array (np .random .uniform (550 , 1000 ))
6976
70- in_annular = gb .apply_annular_aperture (di = di , do = do )
71- in_annulars .append (in_annular .field )
72- # Magnification is 2x
77+ # Binary mask only:
78+ in_annular = annular_aperture (di , do , X , Y )
79+ in_annulars .append (in_annular )
80+
81+ # Magnification is 2 - we only need intensity
7382 out_annular = gb .apply_annular_aperture (di = 2 * di , do = 2 * do )
74- out_annulars .append (out_annular .field )
83+ out_annulars .append (jnp .abs (out_annular .field )** 2 )
84+
7585 return jnp .array (in_annulars ), jnp .array (out_annulars )
7686
7787# Data generation loop:
@@ -85,5 +95,9 @@ def generate_synthetic_annular(gb, num_samples):
8595 input_fields = jnp .vstack ([input_circles , input_squares , input_annular ])
8696 target_fields = jnp .vstack ([target_circles , target_squares , target_annular ])
8797
88- filename = f"training_data_4f/synthetic_data_{ s } .npy"
89- np .save (filename , {"Input fields" : input_fields , "Target fields" : target_fields })
98+ filename = f"training_data_4f/synthetic_data_{ s } .hdf5"
99+
100+ with h5py .File (filename , 'w' ) as hdf :
101+ # Create datasets for your data
102+ hdf .create_dataset ("Input fields" , data = input_fields )
103+ hdf .create_dataset ("Target fields" , data = target_fields )
0 commit comments