diff --git a/.github/workflows/ubuntu.yml b/.github/workflows/ubuntu.yml index 31d135317..a6eabd51e 100644 --- a/.github/workflows/ubuntu.yml +++ b/.github/workflows/ubuntu.yml @@ -33,6 +33,11 @@ jobs: - uses: actions/checkout@v6 + - name: Setup Python + uses: actions/setup-python@v6 + with: + python-version: "3.11" + - name: install dependencies run: | .github/workflows/dependencies/gcc-openmpi.sh diff --git a/cmake/dependencies/ABLASTR.cmake b/cmake/dependencies/ABLASTR.cmake index 6e4d246d0..4496d82a5 100644 --- a/cmake/dependencies/ABLASTR.cmake +++ b/cmake/dependencies/ABLASTR.cmake @@ -178,7 +178,7 @@ set(ImpactX_openpmd_src "" set(ImpactX_ablastr_repo "https://github.com/BLAST-WarpX/warpx.git" CACHE STRING "Repository URI to pull and build ABLASTR from if(ImpactX_ablastr_internal)") -set(ImpactX_ablastr_branch "26.03" +set(ImpactX_ablastr_branch "af0a5227df5b7aa3366aca530a9669d3366b1f3a" CACHE STRING "Repository branch for ImpactX_ablastr_repo if(ImpactX_ablastr_internal)") @@ -186,7 +186,7 @@ set(ImpactX_ablastr_branch "26.03" set(ImpactX_amrex_repo "https://github.com/AMReX-Codes/amrex.git" CACHE STRING "Repository URI to pull and build AMReX from if(ImpactX_amrex_internal)") -set(ImpactX_amrex_branch "26.03" +set(ImpactX_amrex_branch "9219ba416b7ba2073dd1b12bf19fdce27391f17b" CACHE STRING "Repository branch for ImpactX_amrex_repo if(ImpactX_amrex_internal)") diff --git a/cmake/dependencies/pyAMReX.cmake b/cmake/dependencies/pyAMReX.cmake index dcfecd8e5..61c26cc1d 100644 --- a/cmake/dependencies/pyAMReX.cmake +++ b/cmake/dependencies/pyAMReX.cmake @@ -39,6 +39,12 @@ function(find_pyamrex) if(ImpactX_pyamrex_internal OR ImpactX_pyamrex_src) set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) + # safe compile time + set(pyAMReX_CODES "ImpactX" CACHE INTERNAL "Fine-tune the pre-compiled particle containers for downstream codes") + + # skip pyAMReX's own tests (e.g., pytest.AMReX) in the ImpactX superbuild + set(pyAMReX_BUILD_TESTING OFF CACHE BOOL "Run the pyAMReX tests" FORCE) + if(ImpactX_pyamrex_src) add_subdirectory(${ImpactX_pyamrex_src} _deps/localpyamrex-build/) else() @@ -59,7 +65,7 @@ function(find_pyamrex) endif() elseif(NOT ImpactX_pyamrex_internal) # TODO: MPI control - find_package(pyAMReX 26.03 CONFIG REQUIRED) + find_package(pyAMReX 26.03 CONFIG REQUIRED COMPONENTS CODES_ImpactX) message(STATUS "pyAMReX: Found version '${pyAMReX_VERSION}'") endif() endfunction() @@ -74,7 +80,7 @@ option(ImpactX_pyamrex_internal "Download & build pyAMReX" ON) set(ImpactX_pyamrex_repo "https://github.com/AMReX-Codes/pyamrex.git" CACHE STRING "Repository URI to pull and build pyamrex from if(ImpactX_pyamrex_internal)") -set(ImpactX_pyamrex_branch "26.03" +set(ImpactX_pyamrex_branch "edf12bfc6ab5426a4b206e40f988afbadb93d437" CACHE STRING "Repository branch for ImpactX_pyamrex_repo if(ImpactX_pyamrex_internal)") diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index ff99e2566..d35724c0d 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -720,6 +720,19 @@ add_impactx_test(rfcavity-ref-part-hook.py OFF # no plot script yet ) +# Proton acceleration by RF Cavities (using phase optimization) ##################### +# +file(COPY ${ImpactX_SOURCE_DIR}/examples/rfcavity/onaxis_data.in + DESTINATION ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/rfcavity-ref-part-opt.py) +file(COPY ${ImpactX_SOURCE_DIR}/examples/rfcavity/phase_opt.py + DESTINATION ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/rfcavity-ref-part-opt.py) +add_impactx_test(rfcavity-ref-part-opt.py + examples/rfcavity/run_rfcavity_ref_part_opt.py + OFF # ImpactX MPI-parallel + examples/rfcavity/analysis_rfcavity_ref_part_opt.py + OFF # no plot script yet +) + # Ideal, Hard-Edge Solenoid ################################################### # add_impactx_test(solenoid diff --git a/examples/rfcavity/README.rst b/examples/rfcavity/README.rst index 735f782f1..9c63c8298 100644 --- a/examples/rfcavity/README.rst +++ b/examples/rfcavity/README.rst @@ -160,3 +160,59 @@ We run the following script to analyze correctness: .. literalinclude:: analysis_rfcavity_ref_part_hook.py :language: python3 :caption: You can copy this file from ``examples/rfcavity/analysis_rfcavity_ref_part_hook.py``. + + + +.. _examples-rfcavity-ref-part-opt: + +Proton acceleration by RF Cavities (Using RF phase optimization) +================================================================ + +This test is similar to the :ref:`test above `, except that the test is performed for protons (at the same energy). As a result, there is a large variation of relativistic +beta (in the first RF cavity), and the analytical methods of the previous example cannot be used. + +To treat this case, the callback hook feature :py:attr:`~impactx.ImpactX.hook` is combined with an optimization loop (using scipy.minimize_scalar) to determine, for each RF cavity, the input phase setting +that results in maximum energy gain. The phase is reset to this value (for each cavity). + +The functions used for optimization of the RF phase are contained in the file ``phase_opt.py``. + +In this test, the initial and final reference values of :math:`s` and :math:`\gamma` must agree with nominal values. + + +Run +--- + +This example can be run as: + +* **Python** script: ``python3 run_rfcavity_ref_part_opt.py`` + +For `MPI-parallel `__ runs, prefix these lines with ``mpiexec -n 4 ...`` or ``srun -n 4 ...``, depending on the system. + +.. tab-set:: + + .. tab-item:: Python: Script + + .. literalinclude:: run_rfcavity_ref_part_opt.py + :language: python3 + :caption: You can copy this file from ``examples/rfcavity/run_rfcavity_ref_part_opt.py``. + +This script makes use of functions defined in the script ``phase_opt.py`` below, which can be re-used by the user: + +.. tab-set:: + + .. tab-item:: Python: Script + + .. literalinclude:: phase_opt.py + :language: python3 + :caption: You can copy this file from ``examples/rfcavity/phase_opt.py``. + +Analyze +------- + +We run the following script to analyze correctness: + +.. dropdown:: Script ``analysis_rfcavity_ref_part_opt.py`` + + .. literalinclude:: analysis_rfcavity_ref_part_opt.py + :language: python3 + :caption: You can copy this file from ``examples/rfcavity/analysis_rfcavity_ref_part_opt.py``. diff --git a/examples/rfcavity/analysis_rfcavity_ref_part_opt.py b/examples/rfcavity/analysis_rfcavity_ref_part_opt.py new file mode 100755 index 000000000..1260b1b60 --- /dev/null +++ b/examples/rfcavity/analysis_rfcavity_ref_part_opt.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 +# +# Copyright 2022-2023 ImpactX contributors +# Authors: Axel Huebl, Chad Mitchell +# License: BSD-3-Clause-LBNL +# + +import glob +import re + +import numpy as np +import pandas as pd + + +def read_file(file_pattern): + for filename in glob.glob(file_pattern): + df = pd.read_csv(filename, delimiter=r"\s+") + if "step" not in df.columns: + step = int(re.findall(r"[0-9]+", filename)[0]) + df["step"] = step + yield df + + +def read_time_series(file_pattern): + """Read in all CSV files from each MPI rank (and potentially OpenMP + thread). Concatenate into one Pandas dataframe. + + Returns + ------- + pandas.DataFrame + """ + return pd.concat( + read_file(file_pattern), + axis=0, + ignore_index=True, + ) # .set_index('id') + + +# read reference particle data +rbc = read_time_series("diags/ref_particle.*") + +s = rbc["s"] +gamma = rbc["gamma"] + +si = s.iloc[0] +gammai = gamma.iloc[0] + +sf = s.iloc[-1] +gammaf = gamma.iloc[-1] + +print("") +print("Initial Beam:") +print(f" s_ref={si:e} gamma_ref={gammai:e}") + +atol = 1.0e-4 # ignored +print(f" atol={atol}") + +assert np.allclose( + [si, gammai], + [ + 0.000000, + 1.2451314527015738, + ], + atol=atol, +) + + +print("") +print("Final Beam:") +print(f" s_ref={sf:e} gamma_ref={gammaf:e}") + +atol = 1.0e-4 # ignored +print(f" atol={atol}") + +assert np.allclose( + [sf, gammaf], + [ + 5.9391682799999987, + 39.858859594214152, + ], + atol=atol, +) diff --git a/examples/rfcavity/phase_opt.py b/examples/rfcavity/phase_opt.py new file mode 100644 index 000000000..bd4fddc33 --- /dev/null +++ b/examples/rfcavity/phase_opt.py @@ -0,0 +1,78 @@ +import numpy as np +from scipy.optimize import minimize_scalar + +from impactx import push + + +def objective(parameter, ref, element): + """ + A function that is evaluated by the optimizer. + + Parameters + ---------- + parameter: + rf cavity phase + + Returns + ------- + Negative of the RF cavity energy gain in MeV (to minimize). + """ + + # adjust the RF cavity phase + phase_opt = parameter + element.phase = phase_opt + + # store the incoming energy and copy the reference particle + KE_in = ref.kin_energy_MeV + ref_copy = ref.copy() + + # push the copy of the reference particle + push(ref_copy, element) + KE_fin = ref_copy.kin_energy_MeV + + # evaluate the objective + loss = KE_in - KE_fin + + if np.isnan(loss): + loss = 1.0e99 + + return loss + + +def optimize(ref, element): + """ + Maximize the energy gain of a reference particle + + Using (ref) in an RFCavity (element): the optimization is performed by minimizing + -(change in kinetic energy in MeV). + + Parameters + ---------- + ref: + the reference particle at RF entry + element: + the RF cavity element + + Returns + ------- + The optimized phase and energy gain at that phase (phase_opt, e_gain). + """ + + # optimizer specific options + options = {"maxiter": 2000, "disp": 1} + + # Call the optimizer + res = minimize_scalar( + objective, + method="bounded", + args=(ref, element), + tol=1.0e-8, + options=options, + bounds=(-180, 180), + ) + + # Optimization result + phase_opt = res.x + e_gain = -1.0 * res.fun + + return phase_opt, e_gain diff --git a/examples/rfcavity/run_rfcavity_ref_part_opt.py b/examples/rfcavity/run_rfcavity_ref_part_opt.py new file mode 100755 index 000000000..4264d40b2 --- /dev/null +++ b/examples/rfcavity/run_rfcavity_ref_part_opt.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +# +# Copyright 2022-2023 ImpactX contributors +# Authors: Marco Garten, Axel Huebl, Chad Mitchell +# License: BSD-3-Clause-LBNL +# +# -*- coding: utf-8 -*- + +import numpy as np +from phase_opt import optimize + +from impactx import ImpactX, elements + +sim = ImpactX() + +# set numerical parameters and IO control +sim.space_charge = False +sim.slice_step_diagnostics = True + +# domain decomposition & space charge mesh +sim.init_grids() + +# reference kinetic energy (initial) +kin_energy_MeV = 230.0 # reference energy + +# reference particle +ref = sim.particle_container().ref_particle() +ref.set_charge_qe(1.0).set_mass_MeV(938.27208816).set_kin_energy_MeV(kin_energy_MeV) + +# design the accelerator lattice + +# access RF cavity on-axis field data +data_in = np.loadtxt("onaxis_data.in") +z = data_in[:, 0] +ez_onaxis = data_in[:, 1] +ncoef = 25 + +# Drift elements +dr1 = elements.Drift(name="dr1", ds=0.4, nslice=1) +dr2 = elements.Drift(name="dr2", ds=0.032997, nslice=1) + +# RF cavity element +rf = elements.RFCavity( + name="rf", + ds=1.31879807, + escale=20.0, + z=z, + field_on_axis=ez_onaxis, + ncoef=ncoef, + freq=1.3e9, + phase=0.0, + mapsteps=100, + nslice=4, +) + + +# add beam diagnostics +monitor = elements.BeamMonitor("monitor", backend="h5") + +sim.lattice.extend( + [ + monitor, + dr1, + dr2, + rf, + dr2, + dr2, + rf, + dr2, + dr2, + rf, + dr2, + dr2, + rf, + dr2, + monitor, + ] +) + + +def hook_before_element(sim): + element = sim.tracking_element + if type(element) is elements.RFCavity: + beam = sim.particle_container() + ref = beam.ref_particle() + print( + f" Beam at s={ref.s:.2f}m, t={ref.t:.2f}s, gamma at entry={ref.gamma:.2f}", + flush=True, + ) + + # Interpret input RF phase as measured relative to max accelerating phase: + phase_shift = element.phase + + # Find RF phase that maximizes energy gain: + phase_opt, e_gain = optimize(ref, element) + + # Reset input RF phase to appropriate value for tracking: + element.phase = phase_opt + phase_shift + print( + f" RF cavity updated (reset) values of phase={phase_opt:.2f}, energy gain (MeV) ={e_gain:.2f}", + flush=True, + ) + + +sim.hook["before_element"] = hook_before_element + +# run simulation +sim.track_reference(ref) + +# clean shutdown +sim.finalize() diff --git a/src/elements/ChrQuad.H b/src/elements/ChrQuad.H index 91126a7e3..3f9703ba6 100644 --- a/src/elements/ChrQuad.H +++ b/src/elements/ChrQuad.H @@ -397,13 +397,16 @@ lambday = -gyro_const * ( px*inv_delta1*(1_prt - cos_omega_ds) + x*omega*sin_ome amrex::ParticleReal const pt_ref = refpart.pt; amrex::ParticleReal const betgam2 = powi<2>(pt_ref) - 1_prt; + // normalize quad units to MAD-X convention if needed + amrex::ParticleReal const g = m_unit == 1 ? m_k / refpart.rigidity_Tm() : m_k; + // compute phase advance per unit length in s (in rad/m) - amrex::ParticleReal const omega = std::sqrt(std::abs(m_k)); + amrex::ParticleReal const omega = std::sqrt(std::abs(g)); // initialize linear map matrix elements Map6x6 R = Map6x6::Identity(); - if (m_k > 0.0) { + if (g > 0.0) { R(1,1) = std::cos(omega*slice_ds); R(1,2) = std::sin(omega*slice_ds)/omega; R(2,1) = -omega*std::sin(omega*slice_ds); @@ -413,7 +416,7 @@ lambday = -gyro_const * ( px*inv_delta1*(1_prt - cos_omega_ds) + x*omega*sin_ome R(4,3) = omega*std::sinh(omega*slice_ds); R(4,4) = std::cosh(omega*slice_ds); R(5,6) = slice_ds/betgam2; - } else if (m_k < 0.0) { + } else if (g < 0.0) { R(1,1) = std::cosh(omega*slice_ds); R(1,2) = std::sinh(omega*slice_ds)/omega; R(2,1) = omega*std::sinh(omega*slice_ds); diff --git a/src/initialization/InitParser.cpp b/src/initialization/InitParser.cpp index 8d4b0feab..56966d411 100644 --- a/src/initialization/InitParser.cpp +++ b/src/initialization/InitParser.cpp @@ -18,6 +18,11 @@ namespace impactx::initialization { amrex::ParmParse pp_amrex("amrex"); + // throw exceptions in asserts, to enable optional error handling, especially in Python + // https://amrex-codes.github.io/amrex/docs_html/RuntimeParameters.html#amrex.throw_exception + bool throw_exception = true; // AMReX' default: false + pp_amrex.queryAdd("throw_exception", throw_exception); + // https://amrex-codes.github.io/amrex/docs_html/GPU.html#inputs-parameters bool abort_on_out_of_gpu_memory = true; // AMReX' default: false pp_amrex.queryAdd("abort_on_out_of_gpu_memory", abort_on_out_of_gpu_memory); diff --git a/src/python/impactx/__init__.pyi b/src/python/impactx/__init__.pyi index 815046420..541e4b112 100644 --- a/src/python/impactx/__init__.pyi +++ b/src/python/impactx/__init__.pyi @@ -50,6 +50,7 @@ from impactx.impactx_pybind import ( push, wakeconvolution, ) +from impactx.impactx_pybind.elements import FilteredElementsList from impactx.madx_to_impactx import read_beam from . import ( @@ -66,6 +67,7 @@ __all__: list[str] = [ "Config", "CoordSystem", "Envelope", + "FilteredElementsList", "ImpactX", "ImpactXParConstIter", "ImpactXParIter", diff --git a/src/python/impactx/extensions/KnownElementsList.pyi b/src/python/impactx/extensions/KnownElementsList.pyi index eb23e29bc..4fce85a40 100644 --- a/src/python/impactx/extensions/KnownElementsList.pyi +++ b/src/python/impactx/extensions/KnownElementsList.pyi @@ -11,10 +11,14 @@ from __future__ import annotations import os as os import re as re +import weakref as weakref +import impactx.impactx_pybind.elements from impactx.impactx_pybind import elements +from impactx.impactx_pybind.elements import FilteredElementsList __all__: list[str] = [ + "FILTERED_ELEMENTS_LIST_INVALID_MSG", "FilteredElementsList", "count_by_kind", "elements", @@ -26,101 +30,9 @@ __all__: list[str] = [ "re", "register_KnownElementsList_extension", "select", + "weakref", ] -class FilteredElementsList: - """ - A selection result class for ElementsList that maintains references to original elements. - - References to the original elements in a lattice are needed to allow modification of the original elements. - """ - def __getitem__(self, key): ... - def __init__(self, original_list, indices): ... - def __iter__(self): ... - def __len__(self): ... - def __repr__(self): ... - def __str__(self): ... - def count_by_kind(self, kind_pattern) -> int: - """ - Count elements of a specific kind in the filtered list. - - Args: - kind_pattern: The element kind to count. Can be: - - String name (e.g., "Drift", "Quad") - supports exact match - - Regex pattern (e.g., r".*Quad") - supports pattern matching - - Element type (e.g., elements.Drift) - supports exact type match - - Returns: - int: Number of elements of the specified kind. - """ - def get_kinds(self) -> list[type]: - """ - Get all unique element kinds in the filtered list. - - Returns: - list[type]: List of unique element types (sorted by name). - """ - def has_kind(self, kind_pattern) -> bool: - """ - Check if filtered list contains elements of a specific kind. - - Args: - kind_pattern: The element kind to check for. Can be: - - String name (e.g., "Drift", "Quad") - supports exact match - - Regex pattern (e.g., r".*Quad") - supports pattern matching - - Element type (e.g., elements.Drift) - supports exact type match - - Returns: - bool: True if at least one element of the specified kind exists. - """ - def select(self, *, kind=None, name=None): - """ - Apply filtering to this filtered list. - - This method applies additional filtering to an already filtered list, - maintaining references to the original elements and enabling chaining. - - **Filtering Logic:** - - - **Within a single filter**: OR logic (e.g., ``kind=["Drift", "Quad"]`` matches Drift OR Quad) - - **Between different filters**: OR logic (e.g., ``kind="Quad", name="quad1"`` matches Quad OR named "quad1") - - **Chaining filters**: AND logic (e.g., ``lattice.select(kind="Drift").select(name="drift1")`` matches Drift AND named "drift1") - - :param kind: Element type(s) to filter by. Can be a single string/type or a list/tuple - of strings/types for OR-based filtering. String values support exact matches - and regex patterns. Examples: "Drift", r".*Quad", elements.Drift, ["Drift", r".*Quad"], [elements.Drift, elements.Quad] - :type kind: str or type or list[str | type] or tuple[str | type, ...] or None, optional - - :param name: Element name(s) to filter by. Can be a single string, regex pattern string, or - a list/tuple of strings and/or regex pattern strings for OR-based filtering. - Examples: "quad1", r"quad\\d+", ["quad1", "quad2"], [r"quad\\d+", "bend1"] - :type name: str or list[str] or tuple[str, ...] or None, optional - - :return: FilteredElementsList containing references to original elements - :rtype: FilteredElementsList - - :raises TypeError: If kind/name parameters have wrong types - - **Examples:** - - Additional filtering on already filtered results: - - .. code-block:: python - - drift_elements = lattice.select( - kind="Drift" - ) # or lattice.select(kind=elements.Drift) - first_drift = drift_elements.select( - name="drift1" - ) # Further filter drifts by name - quad_elements = lattice.select( - kind="Quad" - ) # or lattice.select(kind=elements.Quad) - strong_quads = quad_elements.select( - name=r"quad\\d+" - ) # Filter quads by regex pattern - """ - def _check_element_match(element, kind, name): """ Check if an element matches the given kind and name criteria. @@ -134,11 +46,46 @@ def _check_element_match(element, kind, name): bool: True if element matches any criteria (OR logic) """ +def _clone_element(template): + """ + Deep-clone a lattice element via ``to_dict`` (pybind elements are not copy.copy-able). + """ + +def _commit_lattice_rebuild(original, new_elements) -> None: + """ + Replace lattice contents with ``new_elements`` and invalidate all FilteredElementsList views. + """ + +def _drift_class_for_replace_with_drifts(model: str, old_el) -> type: + """ + Map ``model`` and ``old_el`` to the Drift / ChrDrift / ExactDrift class to insert. + + For ``model=="match"``, the class follows ``_model_key_from_element_typename``; otherwise + ``model`` must already be validated against ``_DRIFT_MODEL_CLASSES``. + """ + +def _invalidate_all_registered_views(lattice) -> None: + """ + Mark every registered FilteredElementsList for this lattice as invalid. + """ + def _is_regex_pattern(pattern: str) -> bool: """ Check if a string looks like a regex pattern by testing if it contains regex metacharacters. """ +def _make_drift_from_old( + cls, old_el, *, keep_name, keep_ds, keep_alignment, keep_aperture +): + """ + Build a drift (``cls`` is Drift / ChrDrift / ExactDrift) from thick-element fields on ``old_el``. + + When ``keep_ds`` is False, ``ds`` is set to 0. When ``keep_name`` is False, ``name`` is None. + If ``keep_ds`` is True but ``old_el`` has no ``ds`` attribute (thin element), ``ds`` defaults to 0. + When ``keep_alignment`` is True, copy dx/dy/rotation from the old element. + When ``keep_aperture`` is True, copy aperture_x/aperture_y from the old element. + """ + def _matches_kind_pattern(element, kind_pattern): """ Check if an element matches a kind pattern. @@ -168,6 +115,16 @@ def _matches_string(text: str, string_pattern: str) -> bool: Check if text matches a string pattern (exact match or regex). """ +def _model_key_from_element_typename(type_name: str) -> str: + """ + Return the drift-model key for an element class name (linear / paraxial / exact). + """ + +def _registry_for(lattice): + """ + Return the WeakSet of FilteredElementsList instances for this lattice. + """ + def _validate_select_parameters(kind, name): """ Validate parameters for select methods. @@ -233,7 +190,9 @@ def register_KnownElementsList_extension(kel): KnownElementsList helper methods """ -def select(self, *, kind=None, name=None) -> FilteredElementsList: +def select( + self, *, kind=None, name=None +) -> impactx.impactx_pybind.elements.FilteredElementsList: """ Filter elements by type and name with OR-based logic. @@ -324,3 +283,13 @@ def select(self, *, kind=None, name=None) -> FilteredElementsList: quad_elements[0].k = 1.5 # Modify first quad's strength # All modifications affect the original lattice elements """ + +FILTERED_ELEMENTS_LIST_INVALID_MSG: str = "This lattice selection is no longer valid because the lattice was modified; call select() again on the lattice." +_DRIFT_MODEL_CLASSES: dict = { + "linear": impactx.impactx_pybind.elements.Drift, + "paraxial": impactx.impactx_pybind.elements.ChrDrift, + "exact": impactx.impactx_pybind.elements.ExactDrift, +} +_filtered_views_by_lattice: ( + weakref.WeakKeyDictionary +) # value = diff --git a/src/python/impactx/impactx_pybind/elements/__init__.pyi b/src/python/impactx/impactx_pybind/elements/__init__.pyi index ae48564ee..c12e9782f 100644 --- a/src/python/impactx/impactx_pybind/elements/__init__.pyi +++ b/src/python/impactx/impactx_pybind/elements/__init__.pyi @@ -8,7 +8,6 @@ import collections.abc import typing import amrex.space3d.amrex_3d_pybind -import impactx.extensions.KnownElementsList import impactx.impactx_pybind from . import mixin, transformation @@ -31,6 +30,7 @@ __all__: list[str] = [ "ExactMultipole", "ExactQuad", "ExactSbend", + "FilteredElementsList", "Kicker", "KnownElementsList", "LinearMap", @@ -1344,6 +1344,135 @@ class ExactSbend(mixin.Named, mixin.Thick, mixin.Alignment, mixin.PipeAperture): @phi.setter def phi(self, arg1: typing.SupportsFloat) -> None: ... +class FilteredElementsList: + """ + Result of ``KnownElementsList.select(...)`` or chained ``.select()`` calls: a filtered + view of the same underlying lattice. + + Indexing (``self[i]``) returns elements from the original ``KnownElementsList``; changing + fields on those elements modifies the lattice in place. You can narrow the filter again with + ``.select(...)`` (AND logic between chained calls). After ``delete``, ``replace_each``, or + ``replace_with_drifts``, obtain a new selection from the lattice; earlier filter objects must + not be used. + """ + def __getitem__(self, key): ... + def __init__(self, original_list, indices): ... + def __iter__(self): ... + def __len__(self): ... + def __repr__(self): ... + def __str__(self): ... + def _require_valid(self) -> None: + """ + Raise if this view was invalidated after a lattice mutation. + """ + def count_by_kind(self, kind_pattern) -> int: + """ + Count elements of a specific kind in the filtered list. + + Args: + kind_pattern: The element kind to count. Can be: + - String name (e.g., "Drift", "Quad") - supports exact match + - Regex pattern (e.g., r".*Quad") - supports pattern matching + - Element type (e.g., elements.Drift) - supports exact type match + + Returns: + int: Number of elements of the specified kind. + """ + def delete(self) -> None: + """ + Remove selected elements from the underlying lattice. Invalidates this and all other + live selections on the same lattice. Returns None. + """ + def get_kinds(self) -> list[type]: + """ + Get all unique element kinds in the filtered list. + + Returns: + list[type]: List of unique element types (sorted by name). + """ + def has_kind(self, kind_pattern) -> bool: + """ + Check if filtered list contains elements of a specific kind. + + Args: + kind_pattern: The element kind to check for. Can be: + - String name (e.g., "Drift", "Quad") - supports exact match + - Regex pattern (e.g., r".*Quad") - supports pattern matching + - Element type (e.g., elements.Drift) - supports exact type match + + Returns: + bool: True if at least one element of the specified kind exists. + """ + def replace_each(self, element, *, keep_name=True, keep_ds=False): + """ + Replace each selected element with a copy of ``element``, optionally keeping name and + ``ds`` from the replaced element (``keep_ds`` defaults to False). Invalidates prior views; + returns a new selection over the same indices. + """ + def replace_with_drifts( + self, *, model="match", keep_alignment=True, keep_aperture=False + ): + """ + Replace each selected element with a drift of the matching physics family. + + When ``model="match"``: ``Exact*`` elements become ``ExactDrift``, ``Chr*`` elements + become ``ChrDrift``, and all other (linear) elements become ``Drift``. When + ``model`` is ``"linear"``, ``"paraxial"``, or ``"exact"``, every selected slot uses + that drift model. Names and segment length ``ds`` are always taken from the replaced + element. + + By default, alignment errors (dx, dy, rotation) are preserved and apertures are + cleared. Use ``keep_alignment=False`` to zero alignment errors, or + ``keep_aperture=True`` to preserve aperture_x/aperture_y. + """ + def select(self, *, kind=None, name=None): + """ + Apply filtering to this filtered list. + + This method applies additional filtering to an already filtered list, + maintaining references to the original elements and enabling chaining. + + **Filtering Logic:** + + - **Within a single filter**: OR logic (e.g., ``kind=["Drift", "Quad"]`` matches Drift OR Quad) + - **Between different filters**: OR logic (e.g., ``kind="Quad", name="quad1"`` matches Quad OR named "quad1") + - **Chaining filters**: AND logic (e.g., ``lattice.select(kind="Drift").select(name="drift1")`` matches Drift AND named "drift1") + + :param kind: Element type(s) to filter by. Can be a single string/type or a list/tuple + of strings/types for OR-based filtering. String values support exact matches + and regex patterns. Examples: "Drift", r".*Quad", elements.Drift, ["Drift", r".*Quad"], [elements.Drift, elements.Quad] + :type kind: str or type or list[str | type] or tuple[str | type, ...] or None, optional + + :param name: Element name(s) to filter by. Can be a single string, regex pattern string, or + a list/tuple of strings and/or regex pattern strings for OR-based filtering. + Examples: "quad1", r"quad\\d+", ["quad1", "quad2"], [r"quad\\d+", "bend1"] + :type name: str or list[str] or tuple[str, ...] or None, optional + + :return: FilteredElementsList containing references to original elements + :rtype: FilteredElementsList + + :raises TypeError: If kind/name parameters have wrong types + + **Examples:** + + Additional filtering on already filtered results: + + .. code-block:: python + + drift_elements = lattice.select( + kind="Drift" + ) # or lattice.select(kind=elements.Drift) + first_drift = drift_elements.select( + name="drift1" + ) # Further filter drifts by name + quad_elements = lattice.select( + kind="Quad" + ) # or lattice.select(kind=elements.Quad) + strong_quads = quad_elements.select( + name=r"quad\\d+" + ) # Filter quads by regex pattern + """ + class Kicker(mixin.Named, mixin.Thin, mixin.Alignment): def __init__( self, @@ -1678,9 +1807,7 @@ class KnownElementsList: """ Return and remove the last element of the list. """ - def select( - self, *, kind=None, name=None - ) -> impactx.extensions.KnownElementsList.FilteredElementsList: + def select(self, *, kind=None, name=None) -> FilteredElementsList: """ Filter elements by type and name with OR-based logic.