|
2 | 2 | from jax import jit, vmap, config |
3 | 3 | from .__init__ import um |
4 | 4 |
|
5 | | -# Comment this line if float32 is enough precision for you. |
6 | | -config.update("jax_enable_x64", True) |
| 5 | +# Set this to False if f64 is enough precision for you. |
| 6 | +enable_float64 = True |
| 7 | +if enable_float64: |
| 8 | + config.update("jax_enable_x64", True) |
7 | 9 |
|
8 | 10 | """ Loss functions: |
9 | 11 |
|
10 | 12 | - small_area_hybrid |
11 | 13 | - vectorized_loss_hybrid |
12 | | - - small_area_STED |
13 | | - - small_area |
14 | 14 | - mean_batch_MSE_Intensity |
15 | 15 | - vMSE_Amplitude |
16 | 16 | - vMSE_Phase |
@@ -49,39 +49,6 @@ def vectorized_loss_hybrid(detected_intensities): |
49 | 49 | # Returns (M, 1, 1) shape |
50 | 50 | return loss_val |
51 | 51 |
|
52 | | - |
53 | | - |
54 | | - |
55 | | -def small_area_STED(sted_i_effective): |
56 | | - """ |
57 | | - Computes the fraction of intensity comprised inside the area of a mask for STED-like output. |
58 | | -
|
59 | | - Parameters: |
60 | | - sted_i_effective (jnp.array): Effective intensity in the focal plane of the objective lens for STED. |
61 | | - + epsilon (float): fraction of minimum intensity comprised inside the area. |
62 | | - |
63 | | - Return loss function (jnp.array). |
64 | | - """ |
65 | | - epsilon = 0.5 |
66 | | - I = sted_i_effective / jnp.sum(sted_i_effective) |
67 | | - mask = jnp.where(I > epsilon*jnp.max(I), 1, 0) |
68 | | - return jnp.sum(mask) / (jnp.sum(mask * I)) |
69 | | - |
70 | | -def small_area(focused_field): |
71 | | - """ |
72 | | - Computes the fraction of intensity comprised inside the area of a mask. |
73 | | -
|
74 | | - Parameters: |
75 | | - focused_field (object): VectorizedLight in the focal plane of an objective lens. |
76 | | - + epsilon (float): fraction of minimum intensity comprised inside the area. |
77 | | - |
78 | | - Return type jnp.array. |
79 | | - """ |
80 | | - epsilon = 0.5 |
81 | | - I = jnp.abs(focused_field.Ez)**2 / jnp.sum(jnp.abs(focused_field.Ez)**2) |
82 | | - mask = jnp.where(I > epsilon*jnp.max(I), 1, 0) |
83 | | - return jnp.sum(mask) / jnp.sum(mask * I) |
84 | | - |
85 | 52 | def mean_batch_MSE_Intensity(optimized, target): |
86 | 53 | """ |
87 | 54 | [Computed for batch optimization in 4f system]. Vectorized version of MSE_Intensity. |
|
0 commit comments