Skip to content

Commit 8d929c6

Browse files
Ensures that EKOs are passed in Q2 slices in grid.evolve of the Python API
1 parent 5a0fde1 commit 8d929c6

File tree

2 files changed

+23
-15
lines changed

2 files changed

+23
-15
lines changed

pineappl_py/src/grid.rs

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use pineappl::grid::Grid;
1717
use pineappl::pids::PidBasis;
1818
use pyo3::exceptions::PyValueError;
1919
use pyo3::prelude::*;
20+
use pyo3::types::PyTuple;
2021
use std::collections::BTreeMap;
2122
use std::fs::File;
2223
use std::io::BufReader;
@@ -468,9 +469,8 @@ impl PyGrid {
468469
///
469470
/// Parameters
470471
/// ----------
471-
/// slices : list(list(tuple(PyOperatorSliceInfo, PyReadOnlyArray4)))
472-
/// list of EKOs where each element is a list of (PyOperatorSliceInfo, 4D array)
473-
/// describing each convolution
472+
/// slices : list(Generator(tuple(PyOperatorSliceInfo, PyReadOnlyArray4)))
473+
/// list of EKOs where each element is in turn a list of (PyOperatorSliceInfo, 4D array)
474474
/// order_mask : numpy.ndarray(bool)
475475
/// boolean mask to activate orders
476476
/// xi : (float, float)
@@ -486,7 +486,7 @@ impl PyGrid {
486486
/// produced FK table
487487
pub fn evolve(
488488
&self,
489-
slices: Vec<Vec<(PyOperatorSliceInfo, PyReadonlyArray4<f64>)>>,
489+
slices: Vec<Bound<PyAny>>,
490490
order_mask: Vec<bool>,
491491
xi: (f64, f64, f64),
492492
ren1: Vec<f64>,
@@ -496,12 +496,19 @@ impl PyGrid {
496496
.grid
497497
.evolve(
498498
slices
499-
.iter()
499+
.into_iter()
500500
.map(|subslice| {
501-
subslice.iter().map(|(info, op)| {
501+
// create lazy iterators from Python object
502+
subslice.try_iter().unwrap().map(|item| {
503+
let item = item.unwrap();
504+
let op_tuple = item.downcast::<PyTuple>().unwrap();
505+
let info: PyOperatorSliceInfo =
506+
op_tuple.get_item(0).unwrap().extract().unwrap();
507+
let op: PyReadonlyArray4<f64> =
508+
op_tuple.get_item(1).unwrap().extract().unwrap();
509+
502510
Ok::<_, std::io::Error>((
503-
info.info.clone(),
504-
// TODO: avoid copying
511+
info.info,
505512
CowArray::from(op.as_array().to_owned()),
506513
))
507514
})
@@ -512,7 +519,6 @@ impl PyGrid {
512519
&AlphasTable { ren1, alphas },
513520
)
514521
.map(|fk_table| PyFkTable { fk_table })
515-
// TODO: avoid unwrap and convert `Result` into `PyResult`
516522
.unwrap())
517523
}
518524

pineappl_py/tests/test_grid.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -601,10 +601,8 @@ def test_evolve_with_two_ekos(
601601
h2 = ConvType(polarized=False, time_like=False)
602602
convolution_types = [h1, h2]
603603

604-
input_xgrid = np.geomspace(2e-7, 1, num=50)
605-
slices = []
606-
for conv_id, cvtype in enumerate(convolution_types):
607-
sub_slices = []
604+
def _q2_slices(cvtype):
605+
"""Pass one Q2 at a time using Generators."""
608606
for q2 in evinfo.fac1:
609607
info = OperatorSliceInfo(
610608
fac0=1.0,
@@ -626,8 +624,12 @@ def test_evolve_with_two_ekos(
626624
input_xgrid.size,
627625
),
628626
)
629-
sub_slices.append((info, op))
630-
slices.append(sub_slices)
627+
yield (info, op)
628+
629+
input_xgrid = np.geomspace(2e-7, 1, num=50)
630+
631+
# construct the slice object -- list of generator of slices
632+
slices = [_q2_slices(cvtype=cvtype) for cvtype in convolution_types]
631633

632634
fktable = g.evolve(
633635
slices=slices,

0 commit comments

Comments
 (0)