Skip to content

Commit da7331f

Browse files
old loss functions are removed
1 parent c394827 commit da7331f

File tree

1 file changed

+4
-37
lines changed

1 file changed

+4
-37
lines changed

xlumina/loss_functions.py

Lines changed: 4 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
from jax import jit, vmap, config
33
from .__init__ import um
44

5-
# Comment this line if float32 is enough precision for you.
6-
config.update("jax_enable_x64", True)
5+
# Set this to False if f64 is enough precision for you.
6+
enable_float64 = True
7+
if enable_float64:
8+
config.update("jax_enable_x64", True)
79

810
""" Loss functions:
911
1012
- small_area_hybrid
1113
- vectorized_loss_hybrid
12-
- small_area_STED
13-
- small_area
1414
- mean_batch_MSE_Intensity
1515
- vMSE_Amplitude
1616
- vMSE_Phase
@@ -49,39 +49,6 @@ def vectorized_loss_hybrid(detected_intensities):
4949
# Returns (M, 1, 1) shape
5050
return loss_val
5151

52-
53-
54-
55-
def small_area_STED(sted_i_effective):
56-
"""
57-
Computes the fraction of intensity comprised inside the area of a mask for STED-like output.
58-
59-
Parameters:
60-
sted_i_effective (jnp.array): Effective intensity in the focal plane of the objective lens for STED.
61-
+ epsilon (float): fraction of minimum intensity comprised inside the area.
62-
63-
Return loss function (jnp.array).
64-
"""
65-
epsilon = 0.5
66-
I = sted_i_effective / jnp.sum(sted_i_effective)
67-
mask = jnp.where(I > epsilon*jnp.max(I), 1, 0)
68-
return jnp.sum(mask) / (jnp.sum(mask * I))
69-
70-
def small_area(focused_field):
71-
"""
72-
Computes the fraction of intensity comprised inside the area of a mask.
73-
74-
Parameters:
75-
focused_field (object): VectorizedLight in the focal plane of an objective lens.
76-
+ epsilon (float): fraction of minimum intensity comprised inside the area.
77-
78-
Return type jnp.array.
79-
"""
80-
epsilon = 0.5
81-
I = jnp.abs(focused_field.Ez)**2 / jnp.sum(jnp.abs(focused_field.Ez)**2)
82-
mask = jnp.where(I > epsilon*jnp.max(I), 1, 0)
83-
return jnp.sum(mask) / jnp.sum(mask * I)
84-
8552
def mean_batch_MSE_Intensity(optimized, target):
8653
"""
8754
[Computed for batch optimization in 4f system]. Vectorized version of MSE_Intensity.

0 commit comments

Comments
 (0)