Skip to content

Commit 0ebdc8a

Browse files
authored
FEAT: add DefinedExpression.subexpressions attribute (#191)
1 parent 2ef3479 commit 0ebdc8a

File tree

1 file changed

+53
-40
lines changed

1 file changed

+53
-40
lines changed

src/ampform_dpd/__init__.py

Lines changed: 53 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
import functools
66
import operator
77
from collections import abc
8-
from functools import cache
8+
from collections.abc import Callable
9+
from functools import cache, wraps
910
from itertools import product
10-
from typing import TYPE_CHECKING, Protocol
11+
from typing import TYPE_CHECKING, Any, Protocol
1112
from warnings import warn
1213

14+
import attrs
1315
import sympy as sp
1416
from ampform.kinematics.phasespace import compute_third_mandelstam
1517
from ampform.sympy import PoolSum
@@ -35,7 +37,7 @@
3537
from ampform_dpd.spin import create_spin_range
3638

3739
if TYPE_CHECKING:
38-
from collections.abc import Iterable
40+
from collections.abc import Callable, Iterable
3941
from typing import Any, Literal
4042

4143

@@ -181,12 +183,10 @@ def formulate_subsystem_amplitude( # noqa: PLR0914
181183
self.decay.final_state[3].spin,
182184
)
183185
λR = sp.Symbol(R"\lambda_R", rational=True)
184-
terms = []
185-
parameter_defaults: dict[sp.Basic, complex] = {}
186+
amplitude_sum = DefinedExpression(0)
186187
for chain in self.decay.get_subsystem(subsystem_id).chains:
187188
formulate_dynamics = self.dynamics_choices.get_builder(chain.resonance.name)
188-
dynamics = formulate_dynamics(chain)
189-
parameter_defaults.update(dynamics.parameters)
189+
amplitude = formulate_dynamics(chain)
190190
resonance_spin = sp.Rational(chain.resonance.spin)
191191
resonance_helicities = create_spin_range(resonance_spin)
192192
for λR_val in resonance_helicities:
@@ -200,26 +200,25 @@ def formulate_subsystem_amplitude( # noqa: PLR0914
200200
)
201201
if isinstance(scaling_factors, tuple):
202202
h_prod, h_dec = scaling_factors
203-
parameter_defaults[h_prod] = 1 + 0j
204-
parameter_defaults[h_dec] = 1
203+
amplitude.parameters[h_prod] = 1 + 0j
204+
amplitude.parameters[h_dec] = 1
205205
else:
206-
parameter_defaults[scaling_factors] = 1 + 0j
206+
amplitude.parameters[scaling_factors] = 1 + 0j
207207
scaling_factors = _create_scaling_factors(
208208
chain,
209209
(self.use_production_helicity_couplings, λR, λ[k]),
210210
(self.use_decay_helicity_couplings, λ[i], λ[j]),
211211
one_scalar_per_chain=use_coefficients,
212212
)
213-
sub_amp_expr = (
213+
amplitude *= (
214214
sp.KroneckerDelta(λ[0], λR - λ[k])
215215
* (-1) ** (spin[k] - λ[k])
216-
* dynamics.expression
217216
* Wigner.d(resonance_spin, λR, λ[i] - λ[j], θij)
218217
* _product(scaling_factors)
219218
* (-1) ** (spin[j] - λ[j])
220219
)
221220
if not self.use_decay_helicity_couplings:
222-
sub_amp_expr *= _formulate_clebsch_gordan_factors(
221+
amplitude *= _formulate_clebsch_gordan_factors(
223222
chain.decay_node,
224223
helicities={
225224
self.decay.final_state[i]: λ[i],
@@ -228,27 +227,25 @@ def formulate_subsystem_amplitude( # noqa: PLR0914
228227
)
229228
if not self.use_production_helicity_couplings:
230229
production_isobar = chain.decay
231-
sub_amp_expr *= _formulate_clebsch_gordan_factors(
230+
amplitude *= _formulate_clebsch_gordan_factors(
232231
production_isobar,
233232
helicities={
234233
chain.resonance: λR,
235234
self.decay.final_state[k]: λ[k],
236235
},
237236
)
238-
sub_amp = PoolSum(
239-
sub_amp_expr,
240-
(λR, resonance_helicities),
237+
amplitude_sum += attrs.evolve(
238+
amplitude,
239+
expression=PoolSum(amplitude.expression, (λR, resonance_helicities)),
241240
)
242-
terms.append(sub_amp)
243241
A = _generate_amplitude_index_bases()
244242
amp_symbol = A[subsystem_id][λ0, λ1, λ2, λ3]
245-
amp_expr = sp.Add(*terms)
246243
return AmplitudeModel(
247244
decay=self.decay,
248245
intensity=sp.Abs(amp_symbol) ** 2,
249-
amplitudes={amp_symbol: amp_expr},
250-
variables={θij: θij_expr},
251-
parameter_defaults=parameter_defaults,
246+
amplitudes={amp_symbol: amplitude_sum.expression},
247+
variables=amplitude_sum.subexpressions | {θij: θij_expr},
248+
parameter_defaults=amplitude_sum.parameters,
252249
)
253250

254251
def formulate_aligned_amplitude(
@@ -480,27 +477,43 @@ class DynamicsBuilder(Protocol):
480477
def __call__(self, decay_chain: ThreeBodyDecayChain) -> DefinedExpression: ...
481478

482479

480+
def _binary_operation(op: Callable[[Any, Any], Any]):
481+
def decorator(func):
482+
@wraps(func)
483+
def wrapper(self: DefinedExpression, other):
484+
if isinstance(other, DefinedExpression):
485+
return DefinedExpression(
486+
expression=op(self.expression, other.expression),
487+
parameters=self.parameters | other.parameters,
488+
subexpressions=self.subexpressions | other.subexpressions,
489+
)
490+
return DefinedExpression(
491+
expression=op(self.expression, other),
492+
parameters=self.parameters,
493+
subexpressions=self.subexpressions,
494+
)
495+
496+
return wrapper
497+
498+
return decorator
499+
500+
483501
@define
484502
class DefinedExpression:
485-
expression: sp.Expr = sp.S.One
503+
expression: sp.Expr = field(converter=sp.sympify, default=sp.S.One) # type:ignore[misc]
486504
parameters: dict[sp.Symbol, complex | float] = field(factory=dict)
487-
488-
def __mul__(self, other: Any) -> DefinedExpression:
489-
if isinstance(other, DefinedExpression):
490-
return DefinedExpression(
491-
expression=self.expression * other.expression,
492-
parameters={**self.parameters, **other.parameters},
493-
)
494-
if isinstance(other, abc.Sequence) and len(other) == 2: # noqa: PLR2004
495-
expression, definitions = other
496-
return DefinedExpression(
497-
expression=self.expression * expression,
498-
parameters={**self.parameters, **definitions},
499-
)
500-
return DefinedExpression(
501-
expression=self.expression * other,
502-
parameters=self.parameters,
503-
)
505+
subexpressions: dict[sp.Symbol, sp.Expr] = field(factory=dict)
506+
507+
@_binary_operation(operator.mul)
508+
def __mul__(self, other) -> DefinedExpression: ... # type:ignore[empty-body]
509+
@_binary_operation(operator.add)
510+
def __add__(self, other) -> DefinedExpression: ... # type:ignore[empty-body]
511+
@_binary_operation(operator.sub)
512+
def __sub__(self, other) -> DefinedExpression: ... # type:ignore[empty-body]
513+
@_binary_operation(operator.truediv)
514+
def __truediv__(self, other) -> DefinedExpression: ... # type:ignore[empty-body]
515+
@_binary_operation(operator.pow)
516+
def __pow__(self, other) -> DefinedExpression: ... # type:ignore[empty-body]
504517

505518

506519
def create_mass_symbol_mapping(decay: ThreeBodyDecay) -> dict[sp.Symbol, float]:

0 commit comments

Comments
 (0)