@@ -135,17 +135,31 @@ def linear_bin_function(
135135 """
136136 res = np .zeros_like (a )
137137 for shift_pos , shift in enumerate (y_shift ):
138- bin_low = bin_edges [shift_pos ]
139- bin_high = bin_edges [shift_pos + 1 ]
140- bin_mid = 0.5 * (bin_low + bin_high )
138+ if shift_pos > 0 and shift_pos < len (y_shift ) - 1 :
139+ bin_low = bin_edges [shift_pos - 1 ]
140+ bin_high = bin_edges [shift_pos + 1 ]
141+ bin_mid = bin_edges [shift_pos ]
142+ m1 = shift / (bin_mid - bin_low )
143+ m2 = shift / (bin_high - bin_mid )
144+ elif shift_pos == 0 : # Left-most bin
145+ bin_high = bin_edges [shift_pos + 1 ]
146+ bin_mid = bin_edges [shift_pos ]
147+ bin_low = bin_mid
148+ m1 = 0.0
149+ m2 = shift / (bin_high - bin_mid )
150+ else : # Right-most bin
151+ bin_low = bin_edges [shift_pos - 1 ]
152+ bin_mid = bin_edges [shift_pos ]
153+ bin_high = bin_mid
154+ m1 = shift / (bin_mid - bin_low )
155+ m2 = 0.0
141156 cond_low = np .multiply (a >= bin_low , a < bin_mid )
142157 cond_high = np .multiply (
143158 a >= bin_mid , a < bin_high if shift_pos != len (y_shift ) - 1 else a <= bin_high
144159 )
145- m = 2 * shift / (bin_high - bin_low )
146- res = np .add (res , [m * (val - bin_low ) if cond else 0.0 for val , cond in zip (a , cond_low )])
160+ res = np .add (res , [m1 * (val - bin_low ) if cond else 0.0 for val , cond in zip (a , cond_low )])
147161 res = np .add (
148- res , [- m * (val - bin_high ) if cond else 0.0 for val , cond in zip (a , cond_high )]
162+ res , [- m2 * (val - bin_high ) if cond else 0.0 for val , cond in zip (a , cond_high )]
149163 )
150164 return res
151165
0 commit comments