Skip to content

Commit 1b40c4c

Browse files
committed
Correct bug in linear function
1 parent 9f286c6 commit 1b40c4c

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

validphys2/src/validphys/theorycovariance/higher_twist_functions.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,18 +113,17 @@ def linear_bin_function(
113113
a: npt.ArrayLike, y_shift: npt.ArrayLike, bin_edges: npt.ArrayLike
114114
) -> np.ndarray:
115115
"""
116-
This function defines the linear bin function used to construct the prior. The bins of the
117-
function are constructed using pairs of consecutive points. For instance, given the set of
118-
points [0.0, 0.1, 0.3, 0.5], there will be three bins with edges [[0.0, 0.1], [0.1, 0.3],
119-
0.3, 0.5]]. Each bin is coupled with a shift, which correspond to the y-value of the bin.
116+
This function defines the linear bin function used to construct the prior. Specifically,
117+
the prior is constructed using a triangular function whose value at the peak of the node
118+
is linked to the right and left nodes using a straight line.
120119
121120
Parameters
122121
----------
123122
a: ArrayLike of float
124123
A one-dimensional array of points at which the function is evaluated.
125124
y_shift: ArrayLike of float
126125
A one-dimensional array whose elements represent the y-value of each bin
127-
bin_edges: ArrayLike of float
126+
bin_nodes: ArrayLike of float
128127
A one-dimensional array containing the edges of the bins. The bins are
129128
constructed using pairs of consecutive points.
130129
@@ -153,7 +152,9 @@ def linear_bin_function(
153152
bin_high = bin_mid
154153
m1 = shift / (bin_mid - bin_low)
155154
m2 = 0.0
156-
cond_low = np.multiply(a >= bin_low, a < bin_mid)
155+
cond_low = np.multiply(
156+
a >= bin_low, a < bin_mid if shift_pos != len(y_shift) - 1 else a <= bin_mid
157+
)
157158
cond_high = np.multiply(
158159
a >= bin_mid, a < bin_high if shift_pos != len(y_shift) - 1 else a <= bin_high
159160
)
@@ -1052,7 +1053,6 @@ def average(y_values_pc2_p, y_values_pcL_p, y_values_pc3_p):
10521053
# When this happens, this part must be updated.
10531054
eta = cd_table['kin1'].to_numpy()
10541055
pT = cd_table['kin2'].to_numpy()
1055-
q2 = pT * pT
10561056

10571057
pc_func = JET_pc(pc_jet_nodes, pT, eta, pc_func_type)
10581058
for pars_pc in pars_combs:

0 commit comments

Comments
 (0)