Skip to content

Commit e97f55b

Browse files
authored
Merge pull request #353 from NNPDF/merge-channel-factors
Add `Grid::merge_channel_factors` and `Channel::factor`
2 parents a8060ca + a1a75fb commit e97f55b

File tree

11 files changed

+209
-28
lines changed

11 files changed

+209
-28
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1515
`pineappl_grid_evolve_info`, and `pineappl_grid_evolve` to evolve grids
1616
- C API: added `pineappl_fktable_optimize` to optimize FK Table-like objects
1717
given an optimization assumption
18+
- added methods `Grid::merge_channel_factors` and `Channel::factor`
1819

1920
## [1.0.0] - 10/06/2025
2021

pineappl/src/boc.rs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,6 +1086,44 @@ impl Channel {
10861086
}
10871087
})
10881088
}
1089+
1090+
/// Finds the factor with the smallest absolute value in the channel and
1091+
/// divides all coefficients by this value.
1092+
///
1093+
/// # Returns
1094+
///
1095+
/// A tuple containing:
1096+
/// - the factored-out coefficient
1097+
/// - a new `Channel` with all coefficients divided by the factored value
1098+
///
1099+
/// # Panics
1100+
///
1101+
/// TODO
1102+
#[must_use]
1103+
pub fn factor(&self) -> (f64, Self) {
1104+
let factor = self
1105+
.entry
1106+
.iter()
1107+
.map(|(_, f)| *f)
1108+
.min_by(|a, b| {
1109+
a.abs()
1110+
.partial_cmp(&b.abs())
1111+
// UNWRAP: if we can't compare the numbers the data structure is bugged
1112+
.unwrap()
1113+
})
1114+
// UNWRAP: every `Channel` has at least one entry
1115+
.unwrap();
1116+
1117+
let new_channel = Self::new(
1118+
self.entry
1119+
.iter()
1120+
.cloned()
1121+
.map(|(e, f)| (e, f / factor))
1122+
.collect(),
1123+
);
1124+
1125+
(factor, new_channel)
1126+
}
10891127
}
10901128

10911129
impl FromStr for Channel {

pineappl/src/fk_table.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
use super::boc::{Channel, Kinematics, Order};
44
use super::convolutions::ConvolutionCache;
55
use super::error::{Error, Result};
6-
use super::grid::Grid;
7-
use super::pids::OptRules;
6+
use super::grid::{Grid, GridOptFlags};
7+
use super::pids::{OptRules, PidBasis};
88
use super::subgrid::{self, EmptySubgridV1, Subgrid};
99
use ndarray::{s, ArrayD};
1010
use std::collections::BTreeMap;
@@ -249,6 +249,14 @@ impl FkTable {
249249
)
250250
}
251251

252+
/// Rotate the FK Table into the specified basis.
253+
pub fn rotate_pid_basis(&mut self, pid_basis: PidBasis) {
254+
self.grid.rotate_pid_basis(pid_basis);
255+
self.grid.split_channels();
256+
self.grid.merge_channel_factors();
257+
self.grid.optimize_using(GridOptFlags::all());
258+
}
259+
252260
/// Optimize the size of this FK-table by throwing away heavy quark flavors assumed to be zero
253261
/// at the FK-table's scales and calling [`Grid::optimize`].
254262
pub fn optimize(&mut self, assumptions: FkAssumptions) {

pineappl/src/grid.rs

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ use std::{iter, mem};
2727
const BIN_AXIS: Axis = Axis(1);
2828

2929
// const ORDER_AXIS: Axis = Axis(0);
30-
// const CHANNEL_AXIS: Axis = Axis(2);
30+
31+
const CHANNEL_AXIS: Axis = Axis(2);
3132

3233
#[derive(Clone, Deserialize, Serialize)]
3334
struct Mmv4;
@@ -1410,6 +1411,21 @@ impl Grid {
14101411
})
14111412
.collect();
14121413
}
1414+
1415+
/// Merges the factors of the channels into the subgrids to normalize channel coefficients.
1416+
///
1417+
/// This method factors out the smallest absolute coefficient from each channel using
1418+
/// [`boc::Channel::factor`] and then scales the corresponding subgrids by these factors.
1419+
pub fn merge_channel_factors(&mut self) {
1420+
let (factors, new_channels): (Vec<_>, Vec<_>) =
1421+
self.channels().iter().map(Channel::factor).unzip();
1422+
1423+
for (mut subgrids_bo, &factor) in self.subgrids.axis_iter_mut(CHANNEL_AXIS).zip(&factors) {
1424+
subgrids_bo.map_inplace(|subgrid| subgrid.scale(factor));
1425+
}
1426+
1427+
self.channels = new_channels;
1428+
}
14131429
}
14141430

14151431
#[cfg(test)]
@@ -1748,6 +1764,32 @@ mod tests {
17481764
assert_eq!(grid.orders().len(), 1);
17491765
}
17501766

1767+
#[test]
1768+
fn grid_merge_channel_factors() {
1769+
let mut grid = Grid::new(
1770+
BinsWithFillLimits::from_fill_limits([0.0, 1.0].to_vec()).unwrap(),
1771+
vec![Order::new(0, 2, 0, 0, 0)],
1772+
vec![Channel::new(vec![(vec![1, -1], 0.5), (vec![2, -2], 2.5)])],
1773+
PidBasis::Pdg,
1774+
vec![Conv::new(ConvType::UnpolPDF, 2212); 2],
1775+
v0::default_interps(false, 2),
1776+
vec![Kinematics::Scale(0), Kinematics::X(0), Kinematics::X(1)],
1777+
Scales {
1778+
ren: ScaleFuncForm::Scale(0),
1779+
fac: ScaleFuncForm::Scale(0),
1780+
frg: ScaleFuncForm::NoScale,
1781+
},
1782+
);
1783+
1784+
grid.merge_channel_factors();
1785+
grid.channels().iter().all(|channel| {
1786+
channel
1787+
.entry()
1788+
.iter()
1789+
.all(|(_, fac)| (*fac - 1.0).abs() < f64::EPSILON)
1790+
});
1791+
}
1792+
17511793
#[test]
17521794
fn grid_convolutions() {
17531795
let mut grid = Grid::new(

pineappl_cli/src/write.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ enum OpsArg {
3737
DeleteOrders(Vec<RangeInclusive<usize>>),
3838
DivBinNormDims(Vec<usize>),
3939
MergeBins(Vec<RangeInclusive<usize>>),
40+
MergeChannelFactors(bool),
4041
MulBinNorm(f64),
4142
Optimize(bool),
4243
OptimizeFkTable(FkAssumptions),
@@ -84,7 +85,7 @@ impl FromArgMatches for MoreArgs {
8485
});
8586
}
8687
}
87-
"optimize" | "split_channels" | "upgrade" => {
88+
"merge_channel_factors" | "optimize" | "split_channels" | "upgrade" => {
8889
let arguments: Vec<Vec<_>> = matches
8990
.remove_occurrences(&id)
9091
.unwrap()
@@ -95,6 +96,7 @@ impl FromArgMatches for MoreArgs {
9596
for (index, arg) in indices.into_iter().zip(arguments.into_iter()) {
9697
assert_eq!(arg.len(), 1);
9798
args[index] = Some(match id.as_str() {
99+
"merge_channel_factors" => OpsArg::MergeChannelFactors(arg[0]),
98100
"optimize" => OpsArg::Optimize(arg[0]),
99101
"split_channels" => OpsArg::SplitChannels(arg[0]),
100102
"upgrade" => OpsArg::Upgrade(arg[0]),
@@ -346,6 +348,17 @@ impl Args for MoreArgs {
346348
.value_name("BIN1-BIN2,...")
347349
.value_parser(helpers::parse_integer_range),
348350
)
351+
.arg(
352+
Arg::new("merge_channel_factors")
353+
.action(ArgAction::Append)
354+
.default_missing_value("true")
355+
.help("Merge channel factors into the grid")
356+
.long("merge-channel-factors")
357+
.num_args(0..=1)
358+
.require_equals(true)
359+
.value_name("ON")
360+
.value_parser(clap::value_parser!(bool)),
361+
)
349362
.arg(
350363
Arg::new("mul_bin_norm")
351364
.action(ArgAction::Append)
@@ -551,6 +564,7 @@ impl Subcommand for Opts {
551564
grid.merge_bins(range)?;
552565
}
553566
}
567+
OpsArg::MergeChannelFactors(true) => grid.merge_channel_factors(),
554568
OpsArg::MulBinNorm(factor) => {
555569
grid.set_bwfl(
556570
BinsWithFillLimits::new(
@@ -603,8 +617,10 @@ impl Subcommand for Opts {
603617
}
604618
OpsArg::SplitChannels(true) => grid.split_channels(),
605619
OpsArg::Upgrade(true) => grid.upgrade(),
606-
OpsArg::Optimize(false) | OpsArg::SplitChannels(false) | OpsArg::Upgrade(false) => {
607-
}
620+
OpsArg::MergeChannelFactors(false)
621+
| OpsArg::Optimize(false)
622+
| OpsArg::SplitChannels(false)
623+
| OpsArg::Upgrade(false) => {}
608624
}
609625
}
610626

pineappl_cli/tests/write.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ Options:
2020
--delete-key <KEY> Delete an internal key-value pair
2121
--div-bin-norm-dims <DIM1,...> Divide each bin normalizations by the bin lengths for the given dimensions
2222
--merge-bins <BIN1-BIN2,...> Merge specific bins together
23+
--merge-channel-factors[=<ON>] Merge channel factors into the grid [possible values: true, false]
2324
--mul-bin-norm <NORM> Multiply all bin normalizations with the given factor
2425
--optimize[=<ENABLE>] Optimize internal data structure to minimize memory and disk usage [possible values: true, false]
2526
--optimize-fk-table <OPTIMI> Optimize internal data structure of an FkTable to minimize memory and disk usage [possible values: Nf6Ind, Nf6Sym, Nf5Ind, Nf5Sym, Nf4Ind, Nf4Sym, Nf3Ind, Nf3Sym]

pineappl_py/src/fk_table.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -224,12 +224,8 @@ impl PyFkTable {
224224
/// ----------
225225
/// pid_basis: PyPidBasis
226226
/// PID basis of the resulting FK Table
227-
pub fn rotate_pid_basis(&mut self, pid_basis: PyPidBasis) -> PyGrid {
228-
let mut grid_mut = self.fk_table.grid().clone();
229-
grid_mut.rotate_pid_basis(pid_basis.into());
230-
PyGrid {
231-
grid: grid_mut.clone(),
232-
}
227+
pub fn rotate_pid_basis(&mut self, pid_basis: PyPidBasis) {
228+
self.fk_table.rotate_pid_basis(pid_basis.into());
233229
}
234230

235231
/// Write to file.

pineappl_py/src/grid.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,11 @@ impl PyGrid {
712712
self.grid.rotate_pid_basis(pid_basis.into());
713713
}
714714

715+
/// Merge the factors of all the channels.
716+
pub fn merge_channel_factors(&mut self) {
717+
self.grid.merge_channel_factors();
718+
}
719+
715720
/// Scale all subgrids.
716721
///
717722
/// Parameters

pineappl_py/tests/test_boc.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,13 @@ def _generated_bwfl_fields(n_bins: int, n_dimensions: int) -> BwflFields:
5959

6060
class TestChannel:
6161
def test_init(self):
62-
le = Channel([([2, 2], 0.5)])
63-
assert isinstance(le, Channel)
64-
assert le.into_array() == [([2, 2], 0.5)]
62+
channel = Channel([([2, -2], 0.5)])
63+
assert isinstance(channel, Channel)
64+
assert channel.into_array() == [([2, -2], 0.5)]
65+
66+
channels = Channel([([2, -2], 0.5), ([3, -3], 1.5)])
67+
assert isinstance(channels, Channel)
68+
assert channels.into_array() == [([2, -2], 0.5), ([3, -3], 1.5)]
6569

6670

6771
class TestKinematics:

pineappl_py/tests/test_fk_table.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
"""
66

77
import numpy as np
8-
import tempfile
98

109
from pineappl.boc import Channel, Order
1110
from pineappl.convolutions import Conv, ConvType
@@ -15,7 +14,7 @@
1514

1615

1716
class TestFkTable:
18-
def test_convolve(self, fake_grids):
17+
def test_convolve(self, fake_grids, tmp_path):
1918
# Define convolution types and the initial state hadrons
2019
# We consider an initial state Polarized Proton
2120
h = ConvType(polarized=True, time_like=False)
@@ -65,9 +64,9 @@ def test_convolve(self, fake_grids):
6564
)
6665

6766
# Test writing/dumping the FK table into disk
68-
with tempfile.TemporaryDirectory() as tmpdir:
69-
fk.write(f"{tmpdir}/toy_fktable.pineappl")
70-
fk.write_lz4(f"{tmpdir}/toy_fktable.pineappl.lz4")
67+
path = f"{tmp_path}/toy_fktable.pineappl"
68+
fk.write(path)
69+
fk.write_lz4(path)
7170

7271
def test_fktable(
7372
self,
@@ -113,8 +112,38 @@ def test_fktable(
113112

114113
# Check that FK table is in the Evolution basis and rotate into PDG
115114
assert fk.pid_basis == PidBasis.Evol
116-
new_fk = fk.rotate_pid_basis(PidBasis.Pdg)
117-
assert new_fk.pid_basis == PidBasis.Pdg
115+
fk.rotate_pid_basis(PidBasis.Pdg)
116+
assert fk.pid_basis == PidBasis.Pdg
117+
118+
def test_fktable_rotations(
119+
self,
120+
pdf,
121+
download_objects,
122+
tmp_path,
123+
fkname: str = "FKTABLE_CMSTTBARTOT8TEV-TOPDIFF8TEVTOT.pineappl.lz4",
124+
):
125+
expected_results = [3.72524538e04] # Numbers computed using `v0.8.6`
126+
127+
fk_table = download_objects(f"{fkname}")
128+
fk = FkTable.read(fk_table)
129+
130+
# rotate in the PDG basis and check that all the factors are unity
131+
fk.rotate_pid_basis(PidBasis.Pdg)
132+
assert fk.pid_basis == PidBasis.Pdg
133+
134+
# check that the convolutions are still the same
135+
np.testing.assert_allclose(
136+
fk.convolve(
137+
pdg_convs=fk.convolutions,
138+
xfxs=[pdf.unpolarized_pdf, pdf.unpolarized_pdf],
139+
),
140+
expected_results,
141+
)
142+
143+
# check that the FK table can be loaded properly
144+
path = f"{tmp_path}/rotated_fktable.pineappl.lz4"
145+
fk.write_lz4(path)
146+
_ = FkTable.read(path)
118147

119148
def test_unpolarized_convolution(
120149
self,
@@ -125,7 +154,7 @@ def test_unpolarized_convolution(
125154
"""Check the convolution of an actual FK table that involves two
126155
symmetrical unpolarized protons:
127156
"""
128-
expected_results = [3.72524538e04]
157+
expected_results = [3.72524538e04] # Numbers computed using `v0.8.6`
129158
fk_table = download_objects(f"{fkname}")
130159
fk = FkTable.read(fk_table)
131160

0 commit comments

Comments
 (0)