Skip to content

Commit 5f31088

Browse files
authored
ENH: support more types in normalize_state_ids() (#166)
* DX: add test for `normalize_state_ids()` with `ProblemSet` * ENH: support `normalize_state_ids()` with mutable transitions and `ProblemSet`
1 parent 655f29b commit 5f31088

File tree

2 files changed

+48
-5
lines changed

2 files changed

+48
-5
lines changed

src/ampform_dpd/adapter/qrules.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,14 @@
1010
import attrs
1111
import qrules
1212
from qrules.quantum_numbers import InteractionProperties
13-
from qrules.topology import EdgeType, FrozenTransition, NodeType
14-
from qrules.transition import ReactionInfo, StateTransition, Topology
13+
from qrules.topology import (
14+
EdgeType,
15+
FrozenTransition,
16+
MutableTransition,
17+
NodeType,
18+
Transition,
19+
)
20+
from qrules.transition import ProblemSet, ReactionInfo, Topology
1521

1622
from ampform_dpd.decay import (
1723
FinalStateID,
@@ -240,15 +246,28 @@ def _(obj: ReactionInfo) -> ReactionInfo:
240246
)
241247

242248

243-
@_impl_normalize_state_ids.register(FrozenTransition) # type:ignore[attr-defined]
244-
def _(obj: StateTransition) -> StateTransition:
249+
_Transition = TypeVar("_Transition", FrozenTransition, MutableTransition)
250+
251+
252+
@_impl_normalize_state_ids.register(FrozenTransition)
253+
@_impl_normalize_state_ids.register(MutableTransition)
254+
def _(obj: _Transition) -> _Transition:
245255
return attrs.evolve(
246256
obj,
247257
topology=_impl_normalize_state_ids(obj.topology),
248258
states={new: obj.states[old] for new, old in enumerate(sorted(obj.states))},
249259
)
250260

251261

262+
@_impl_normalize_state_ids.register(ProblemSet)
263+
def _(obj: ProblemSet) -> ProblemSet:
264+
return ProblemSet(
265+
initial_facts=_impl_normalize_state_ids(obj.initial_facts),
266+
solving_settings=_impl_normalize_state_ids(obj.solving_settings),
267+
topology=_impl_normalize_state_ids(obj.topology),
268+
)
269+
270+
252271
@_impl_normalize_state_ids.register(Topology) # type:ignore[attr-defined]
253272
def _(obj: Topology) -> Topology:
254273
mapping = {old: new for new, old in enumerate(sorted(obj.edges))}
@@ -260,7 +279,15 @@ def _(obj: abc.Iterable[T]) -> list[T]:
260279
return [_impl_normalize_state_ids(x) for x in obj]
261280

262281

263-
T = TypeVar("T", ReactionInfo, StateTransition, Topology)
282+
T = TypeVar(
283+
"T",
284+
FrozenTransition,
285+
MutableTransition,
286+
ProblemSet,
287+
ReactionInfo,
288+
Topology,
289+
Transition,
290+
)
264291
"""Type variable for the input and output of :func:`normalize_state_ids`."""
265292

266293

tests/adapter/test_qrules.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import TYPE_CHECKING, Callable, SupportsFloat
66

77
import pytest
8+
from qrules import InteractionType, StateTransitionManager
89

910
from ampform_dpd.adapter.qrules import (
1011
_convert_transition,
@@ -117,6 +118,21 @@ def test_normalize_state_ids_reaction(jpsi2pksigma_reaction: ReactionInfo):
117118
assert transition012.states[i] == transition123.states[i + 1]
118119

119120

121+
def test_normalize_state_ids_problem_set():
122+
stm = StateTransitionManager(
123+
initial_state=[("J/psi(1S)", [-1, +1])],
124+
final_state=["K0", "Sigma+", "p~"],
125+
allowed_intermediate_particles=["N(1700)", "Sigma(1750)"],
126+
formalism="helicity",
127+
mass_conservation_factor=0,
128+
)
129+
stm.set_allowed_interaction_types([InteractionType.STRONG, InteractionType.EM])
130+
problem_sets = stm.create_problem_sets()
131+
some_problem_set = normalize_state_ids(problem_sets[3600.0][0])
132+
assert set(some_problem_set.initial_facts.initial_states) == {0}
133+
assert set(some_problem_set.initial_facts.final_states) == {1, 2, 3}
134+
135+
120136
def test_permute_equal_final_states(
121137
a2pipipi_reaction: ReactionInfo,
122138
jpsi2pksigma_reaction: ReactionInfo,

0 commit comments

Comments
 (0)