Skip to content

Commit 7c33331

Browse files
vectorized loss across detectors is implemented
1 parent 6f8d224 commit 7c33331

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

xlumina/loss_functions.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
""" Loss functions:
99
10+
- small_area_hybrid
11+
- vectorized_loss_hybrid
1012
- small_area_STED
1113
- small_area
1214
- mean_batch_MSE_Intensity
@@ -19,6 +21,37 @@
1921
2022
"""
2123

24+
def small_area_hybrid(detected_intensity):
25+
"""
26+
[Small area loss function valid for hybrid (topology + optical parameters) optimization:]
27+
28+
Computes the fraction of intensity comprised inside the area of a mask.
29+
30+
Parameters:
31+
detected_intensity (jnp.array): Detected intensity array
32+
+ epsilon (float): fraction of minimum intensity comprised inside the area.
33+
34+
Return type jnp.array.
35+
"""
36+
epsilon = 0.7
37+
eps = 1e-08
38+
I = detected_intensity / (jnp.sum(detected_intensity) + eps)
39+
mask = jnp.where(I > epsilon*jnp.max(I), 1, 0)
40+
return jnp.sum(mask) / (jnp.sum(mask * I) + eps)
41+
42+
@jit
43+
def vectorized_loss_hybrid(detected_intensities):
44+
"""[For loss_hybrid]: vectorizes loss function to be used across various detectors"""
45+
# Input field has (M, N, N) shape
46+
vloss = vmap(small_area_hybrid, in_axes = (0))
47+
# Call the vectorized function
48+
loss_val = vloss(detected_intensities)
49+
# Returns (M, 1, 1) shape
50+
return loss_val
51+
52+
53+
54+
2255
def small_area_STED(sted_i_effective):
2356
"""
2457
Computes the fraction of intensity comprised inside the area of a mask for STED-like output.

0 commit comments

Comments
 (0)