Skip to content

Commit 833dbd8

Browse files
committed
speed up convolutions when using pineappl theories by exploiting the
fact that the fktables are already ordered in x1/x2
1 parent 0b3dff7 commit 833dbd8

File tree

4 files changed

+56
-8
lines changed

4 files changed

+56
-8
lines changed

validphys2/src/validphys/convolution.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,13 @@ def _gv_hadron_predictions(loaded_fk, gv1func, gv2func=None):
373373
# possible x1-x2 combinations (f1, f2, x1, x2)
374374
luminosity = np.einsum("ijk, ijl->ijkl", expanded_gv1, expanded_gv2)
375375

376+
if not loaded_fk.legacy:
377+
lx = len(xgrid)
378+
lc = len(fl1)
379+
fktab = sigma.values.reshape(-1, lx, lx, lc)
380+
ret = np.einsum("rcab, nabc->nr", luminosity, fktab)
381+
return pd.DataFrame(ret, index=loaded_fk.data_index)
382+
376383
def appl(df):
377384
# x1 and x2 are encoded as the first and second index levels.
378385
xx1 = df.index.get_level_values(1)
@@ -381,6 +388,12 @@ def appl(df):
381388
partial_lumi = luminosity[..., xx1, xx2]
382389
return pd.Series(np.einsum("ijk,kj->i", partial_lumi, df.values))
383390

391+
# The gv1/gv2 grids are arrays of shape (replicas, flavours<14>, xarray)
392+
# the expanded gv1/gv2 instead are shaped according to the channels (which will match)
393+
# therefore the luminosity is an array of shape (replicas, channels, x1, x2)
394+
# this needs to be matched with the fktable which for the old interface were not ordered
395+
# and so the full dataframe needs to be used instead to keep track of the index
396+
384397
return sigma.groupby(level=0).apply(appl)
385398

386399

@@ -397,6 +410,12 @@ def _gv_dis_predictions(loaded_fk, gvfunc):
397410
if sigma.empty:
398411
return pd.DataFrame(columns=range(gv.shape[0]))
399412

413+
if not loaded_fk.legacy:
414+
lx = len(xgrid)
415+
fktab = sigma.values.reshape(-1, lx, len(fm))
416+
ret = np.einsum("rfa, naf->nr", gv, fktab)
417+
return pd.DataFrame(ret, index=loaded_fk.data_index)
418+
400419
def appl(df):
401420
# x is encoded as the first index level.
402421
xind = df.index.get_level_values(1)

validphys2/src/validphys/coredata.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,16 +56,25 @@ class FKTableData:
5656
The most common use-case is when a total cross section is used
5757
as a normalization table for a differential cross section,
5858
in legacy code (<= NNPDF4.0) both fktables would be cut using the differential index.
59+
60+
data_index: pd.Series
61+
index of the data points
62+
63+
legacy: bool
64+
If False, this corresponds to an FkTable read from the old applgrid interface.
65+
Deprecated and support will be dropped during the 4.1.X series of tags.
5966
"""
6067

6168
hadronic: bool
6269
Q0: float
6370
ndata: int
6471
xgrid: np.ndarray
6572
sigma: pd.DataFrame
73+
data_index: pd.Series
6674
convolution_types: tuple[str] = None
6775
metadata: dict = dataclasses.field(default_factory=dict, repr=False)
6876
protected: bool = False
77+
legacy: bool = False
6978

7079
def with_cfactor(self, cfactor):
7180
"""Returns a copy of the FKTableData object with cfactors applied to the fktable"""
@@ -123,11 +132,12 @@ def with_cuts(self, cuts):
123132
newndata = len(cuts)
124133
try:
125134
newsigma = self.sigma.loc[cuts]
135+
newdata_idx = self.data_index.loc[cuts]
126136
except KeyError as e:
127137
# This will be an ugly erorr msg, but it should be scary anyway
128138
log.error(f"Problem applying cuts to {self.metadata}")
129139
raise e
130-
return dataclasses.replace(self, ndata=newndata, sigma=newsigma)
140+
return dataclasses.replace(self, ndata=newndata, sigma=newsigma, data_index=newdata_idx)
131141

132142
@property
133143
def luminosity_mapping(self):
@@ -168,8 +178,8 @@ def get_np_fktable(self):
168178
# Make the dataframe into a dense numpy array
169179

170180
# First get the data index out of the way
171-
# this is necessary because cuts/shifts and for performance reasons
172-
# otherwise we will be putting things in a numpy array in very awkward orders
181+
# this is necessary because cuts/shifts and because old fktables are not necessarily ordered
182+
# in addition, for performance reason, we want to order the np array as (ndata, basis, x1, x2)
173183
ns = self.sigma.unstack(level=("data",), fill_value=0)
174184
x1 = ns.index.get_level_values(0)
175185

@@ -244,5 +254,5 @@ class CFactorData:
244254
"""
245255

246256
description: str
247-
central_value: np.array
248-
uncertainty: np.array
257+
central_value: np.ndarray
258+
uncertainty: np.ndarray

validphys2/src/validphys/fkparser.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
res = load_fktable(fk)
1919
"""
2020

21+
# TODO: this module is deprecated and support for older theories will be removed
22+
2123
import dataclasses
2224
import functools
2325
import io
@@ -313,9 +315,17 @@ def parse_fktable(f):
313315
hadronic = res['GridInfo'].hadronic
314316
ndata = res['GridInfo'].ndata
315317
xgrid = res.pop('xGrid')
318+
data_idx = sigma.index.get_level_values("data").unique().to_series()
316319

317320
return FKTableData(
318-
sigma=sigma, ndata=ndata, Q0=Q0, metadata=res, hadronic=hadronic, xgrid=xgrid
321+
sigma=sigma,
322+
ndata=ndata,
323+
Q0=Q0,
324+
metadata=res,
325+
hadronic=hadronic,
326+
xgrid=xgrid,
327+
data_index=data_idx,
328+
legacy=True,
319329
)
320330
elif header_name in _KNOWN_SEGMENTS:
321331
parser = _KNOWN_SEGMENTS[header_name]

validphys2/src/validphys/pineparser.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ def pineappl_reader(fkspec):
203203

204204
partial_fktables = []
205205
ndata = 0
206+
full_data_index = []
206207
for fkname, p in zip(fknames, pines):
207208
# Start by reading possible cfactors if cfactor is not empty
208209
cfprod = 1.0
@@ -247,6 +248,7 @@ def pineappl_reader(fkspec):
247248
partial_fktables.append(pd.DataFrame(df_fktable, columns=lumi_columns, index=idx))
248249

249250
ndata += n
251+
full_data_index.append(data_idx)
250252

251253
# Finallly concatenate all fktables, sort by flavours and fill any holes
252254
sigma = pd.concat(partial_fktables, sort=True, copy=False).fillna(0.0)
@@ -265,8 +267,14 @@ def pineappl_reader(fkspec):
265267
ndata = 1
266268

267269
if ndata == 1:
268-
# There's no doubt
269-
protected = divisor == name
270+
# When the number of points is 1 and the fktable is a divisor, protect it from cuts
271+
if divisor == name:
272+
protected = True
273+
full_data_index = [[0]]
274+
275+
# Keeping the data index as a series is exploited to speed up certain operations (e.g. hadronic conv)
276+
fid = np.concatenate(full_data_index)
277+
data_index = pd.Series(fid, index=fid, name="data")
270278

271279
return FKTableData(
272280
sigma=sigma,
@@ -277,4 +285,5 @@ def pineappl_reader(fkspec):
277285
hadronic=hadronic,
278286
xgrid=xgrid,
279287
protected=protected,
288+
data_index=data_index,
280289
)

0 commit comments

Comments
 (0)