diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 872de76..06e0332 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -3,9 +3,8 @@ on: push: branches: - main - tags: ['*'] + tags: '*' pull_request: - workflow_dispatch: concurrency: # Skip intermediate builds: always. # Cancel intermediate builds: only if it is a pull request build. @@ -13,29 +12,30 @@ concurrency: cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} jobs: test: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - runs-on: ${{ matrix.os }} - timeout-minutes: 60 - permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created - actions: write - contents: read + name: Julia ${{ matrix.version }} - ${{ matrix.platform.os }} - ${{ matrix.platform.arch }} - ${{ github.event_name }} + runs-on: ${{ matrix.platform.os }} strategy: fail-fast: false matrix: + platform: + - os: ubuntu-latest + arch: x64 + - os: ubuntu-24.04-arm + arch: aarch64 version: - - '1.11' - - '1.6' - - 'pre' - os: - - ubuntu-latest - arch: - - x64 + - 'lts' + - '1' steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} - arch: ${{ matrix.arch }} + arch: ${{ matrix.platform.arch }} - uses: julia-actions/cache@v2 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v5 + with: + files: lcov.info + token: ${{ secrets.codecov_token }} diff --git a/.gitignore b/.gitignore index 230c5ed..54179f3 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,12 @@ -/Manifest*.toml +Manifest.toml +*.DS_Store +*.root +examples/flavour-tagging/data/wc_pt_7classes_12_04_2023/fccee_fl.swp +.vscode/ + +# Misc files +.DS_Store +/notebooks/* +**/profile/* +/statprof/* +/debug/* diff --git a/Project.toml b/Project.toml index 42d45fd..e1956f0 100644 --- a/Project.toml +++ b/Project.toml @@ 
-1,10 +1,26 @@ name = "JetTaggingFCC" uuid = "12f500dc-52b1-4d85-9b70-29fa5347616c" authors = ["Harry and contributors"] -version = "1.0.0-DEV" +version = "0.1.0" + +[deps] +EDM4hep = "eb32b910-dde9-4347-8fce-cd6be3498f0c" +JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +JetReconstruction = "44e8cb2c-dfab-4825-9c70-d4808a591196" +LorentzVectorHEP = "f612022c-142a-473f-8cfd-a09cf3793c6c" +ONNXRunTime = "e034b28e-924e-41b2-b98f-d2bbeb830c6a" +PhysicalConstants = "5ad8b20f-a522-5ce9-bfc9-ddf1d5bda6ab" +StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" [compat] -julia = "1.6.7" +EDM4hep = "0.4.0" +JSON = "0.21" +JetReconstruction = "1.0" +LorentzVectorHEP = "0.1.6" +ONNXRunTime = "1.3.1" +PhysicalConstants = "0.2.4" +StructArrays = "0.6.18, 0.7" +julia = "1.9" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/README.md b/README.md index 2c9d6e5..4196a5e 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,22 @@ -# JetTaggingFCC +# JetTaggingFCC.jl [![Build Status](https://github.com/JuliaHEP/JetTaggingFCC.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/JuliaHEP/JetTaggingFCC.jl/actions/workflows/CI.yml?query=branch%3Amain) + +Julia package for jet flavour tagging at Future Circular Collider (FCC) +experiments using machine learning. + +## Overview + +JetTaggingFCC.jl provides tools for identifying the quark flavour content of jets +in high-energy physics experiments. It uses ONNX neural network models to +classify jets based on their constituent particles. + +## Features + +- Load and process EDM4hep event data to extract physics features from jet constituents +- Run ONNX neural network inference for flavour classification + +See the `examples/` directory for detailed usage examples. + +**Important:** This package needs the pre-release version of `JetReconstruction`, +so currently `Pkg.develop("JetReconstruction")` is required. 
diff --git a/examples/flavour-tagging/Project.toml b/examples/flavour-tagging/Project.toml new file mode 100644 index 0000000..c2667a1 --- /dev/null +++ b/examples/flavour-tagging/Project.toml @@ -0,0 +1,21 @@ +[deps] +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" +CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193" +CodecZstd = "6b39b394-51ab-5f42-8807-6242bab2b4c2" +EDM4hep = "eb32b910-dde9-4347-8fce-cd6be3498f0c" +EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" +IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a" +JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +JetReconstruction = "44e8cb2c-dfab-4825-9c70-d4808a591196" +JetTaggingFCC = "12f500dc-52b1-4d85-9b70-29fa5347616c" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +LorentzVectorHEP = "f612022c-142a-473f-8cfd-a09cf3793c6c" +LorentzVectors = "3f54b04b-17fc-5cd4-9758-90c048d965e3" +MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" +ONNXRunTime = "e034b28e-924e-41b2-b98f-d2bbeb830c6a" +PhysicalConstants = "5ad8b20f-a522-5ce9-bfc9-ddf1d5bda6ab" +StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + +[sources] +JetTaggingFCC = {path = "../.."} diff --git a/examples/flavour-tagging/README.md b/examples/flavour-tagging/README.md new file mode 100644 index 0000000..b2415c3 --- /dev/null +++ b/examples/flavour-tagging/README.md @@ -0,0 +1,13 @@ +# Jet Reconstruction Flavour Tagging Module Examples + +## `simple-flavour-tagging.jl` + +```bash +julia --project simple-flavour-tagging.jl +``` +This script will perform a simple flavour tagging on one event only. +It will use the `FlavourTagging` module to tag the jets in the event and print the results. + +## `simple-flavour-tagging.ipynb` + +The same as the above example, but using a Jupyter notebook. 
diff --git a/examples/flavour-tagging/data/wc_pt_7classes_12_04_2023/fccee_flavtagging_edm4hep_wc.json b/examples/flavour-tagging/data/wc_pt_7classes_12_04_2023/fccee_flavtagging_edm4hep_wc.json new file mode 100644 index 0000000..a3e8c6b --- /dev/null +++ b/examples/flavour-tagging/data/wc_pt_7classes_12_04_2023/fccee_flavtagging_edm4hep_wc.json @@ -0,0 +1,407 @@ +{ + "output_names": [ + "recojet_isG", + "recojet_isU", + "recojet_isS", + "recojet_isC", + "recojet_isB", + "recojet_isTAU", + "recojet_isD" + ], + "input_names": [ + "pf_points", + "pf_features", + "pf_vectors", + "pf_mask" + ], + "pf_points": { + "var_names": [ + "pfcand_thetarel", + "pfcand_phirel" + ], + "var_infos": { + "pfcand_thetarel": { + "median": 0.34896841645240784, + "norm_factor": 1.623847928023531, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_phirel": { + "median": 0.00031830096850171685, + "norm_factor": 0.4692355255199169, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + } + }, + "var_length": 75 + }, + "pf_features": { + "var_names": [ + "pfcand_erel_log", + "pfcand_thetarel", + "pfcand_phirel", + "pfcand_dptdpt", + "pfcand_detadeta", + "pfcand_dphidphi", + "pfcand_dxydxy", + "pfcand_dzdz", + "pfcand_dxydz", + "pfcand_dphidxy", + "pfcand_dlambdadz", + "pfcand_dxyc", + "pfcand_dxyctgtheta", + "pfcand_phic", + "pfcand_phidz", + "pfcand_phictgtheta", + "pfcand_cdz", + "pfcand_cctgtheta", + "pfcand_mtof", + "pfcand_dndx", + "pfcand_charge", + "pfcand_isMu", + "pfcand_isEl", + "pfcand_isChargedHad", + "pfcand_isGamma", + "pfcand_isNeutralHad", + "pfcand_dxy", + "pfcand_dz", + "pfcand_btagSip2dVal", + "pfcand_btagSip2dSig", + "pfcand_btagSip3dVal", + "pfcand_btagSip3dSig", + "pfcand_btagJetDistVal", + "pfcand_btagJetDistSig", + "pfcand_type" + ], + "var_infos": { + "pfcand_erel_log": { + "median": -1.8002910614013672, + "norm_factor": 1.5911575382168794, + "replace_inf_value": 0, + "lower_bound": -5, + 
"upper_bound": 5, + "pad": 0 + }, + "pfcand_thetarel": { + "median": 0.34896841645240784, + "norm_factor": 1.623847928023531, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_phirel": { + "median": 0.00031830096850171685, + "norm_factor": 0.4692355255199169, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dptdpt": { + "median": -9.0, + "norm_factor": 0.11111111111111001, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_detadeta": { + "median": -9.0, + "norm_factor": 0.11111111039722958, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dphidphi": { + "median": -9.0, + "norm_factor": 0.11111107873678364, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dxydxy": { + "median": -9.0, + "norm_factor": 0.11110580581124187, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dzdz": { + "median": -9.0, + "norm_factor": 0.11111063977819315, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dxydz": { + "median": -9.0, + "norm_factor": 0.11111103621962409, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dphidxy": { + "median": -9.0, + "norm_factor": 0.11111118663623733, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dlambdadz": { + "median": -9.0, + "norm_factor": 0.11111111196002012, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dxyc": { + "median": -9.0, + "norm_factor": 0.1111111110812715, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dxyctgtheta": { + "median": -9.0, + "norm_factor": 0.11111111114037472, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + 
"pfcand_phic": { + "median": -9.0, + "norm_factor": 0.11111111111115293, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_phidz": { + "median": -9.0, + "norm_factor": 0.11111111110696342, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_phictgtheta": { + "median": -9.0, + "norm_factor": 0.11111111049504692, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_cdz": { + "median": -9.0, + "norm_factor": 0.11111111111018916, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_cctgtheta": { + "median": -9.0, + "norm_factor": 0.11111111111112591, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_mtof": { + "median": 0.07032066211104393, + "norm_factor": 5.0446882868657825, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dndx": { + "median": 0.0, + "norm_factor": 0.7165380157879775, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_charge": { + "median": 0.0, + "norm_factor": 1.0, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_isMu": { + "median": 0.0, + "norm_factor": 1.0, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_isEl": { + "median": 0.0, + "norm_factor": 1.0, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_isChargedHad": { + "median": 0.0, + "norm_factor": 1.0, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_isGamma": { + "median": 0.0, + "norm_factor": 1.0, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_isNeutralHad": { + "median": 0.0, + "norm_factor": 1.0, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + 
"pfcand_dxy": { + "median": -9.0, + "norm_factor": 0.11104036349544703, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dz": { + "median": -9.0, + "norm_factor": 0.1110838280661911, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_btagSip2dVal": { + "median": -9.0, + "norm_factor": 0.11100882796219, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_btagSip2dSig": { + "median": -9.0, + "norm_factor": 0.10339552523584994, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_btagSip3dVal": { + "median": -9.0, + "norm_factor": 0.11100379369461871, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_btagSip3dSig": { + "median": -9.0, + "norm_factor": 0.10327416627313275, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_btagJetDistVal": { + "median": -9.0, + "norm_factor": 0.11106689226469424, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_btagJetDistSig": { + "median": -9.0, + "norm_factor": 0.10776558597370688, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_type": { + "median": 22.0, + "norm_factor": 0.045454545454545456, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + } + }, + "var_length": 75 + }, + "pf_vectors": { + "var_names": [ + "pfcand_e", + "pfcand_p", + "pfcand_e", + "pfcand_e" + ], + "var_infos": { + "pfcand_e": { + "median": 0, + "norm_factor": 1, + "replace_inf_value": 0, + "lower_bound": -1e+32, + "upper_bound": 1e+32, + "pad": 0 + }, + "pfcand_p": { + "median": 0, + "norm_factor": 1, + "replace_inf_value": 0, + "lower_bound": -1e+32, + "upper_bound": 1e+32, + "pad": 0 + } + }, + "var_length": 75 + }, + "pf_mask": { + "var_names": [ + "pfcand_mask" + ], + "var_infos": { + 
"pfcand_mask": { + "median": 0, + "norm_factor": 1, + "replace_inf_value": 0, + "lower_bound": -1e+32, + "upper_bound": 1e+32, + "pad": 0 + } + }, + "var_length": 75 + } +} \ No newline at end of file diff --git a/examples/flavour-tagging/data/wc_pt_7classes_12_04_2023/fccee_flavtagging_edm4hep_wc.onnx b/examples/flavour-tagging/data/wc_pt_7classes_12_04_2023/fccee_flavtagging_edm4hep_wc.onnx new file mode 100644 index 0000000..0fec730 Binary files /dev/null and b/examples/flavour-tagging/data/wc_pt_7classes_12_04_2023/fccee_flavtagging_edm4hep_wc.onnx differ diff --git a/examples/flavour-tagging/data/wc_pt_7classes_12_04_2023/fccee_flavtagging_edm4hep_wc.yaml b/examples/flavour-tagging/data/wc_pt_7classes_12_04_2023/fccee_flavtagging_edm4hep_wc.yaml new file mode 100644 index 0000000..bc9fcfa --- /dev/null +++ b/examples/flavour-tagging/data/wc_pt_7classes_12_04_2023/fccee_flavtagging_edm4hep_wc.yaml @@ -0,0 +1,430 @@ +treename: null +selection: null +test_time_selection: null +preprocess: + method: auto + data_fraction: 0.1 + params: + pfcand_thetarel: + length: 75 + pad_mode: wrap + center: 0.34896841645240784 + scale: 1.623847928023531 + min: -5 + max: 5 + pad_value: 0 + pfcand_phirel: + length: 75 + pad_mode: wrap + center: 0.00031830096850171685 + scale: 0.4692355255199169 + min: -5 + max: 5 + pad_value: 0 + pfcand_erel_log: + length: 75 + pad_mode: wrap + center: -1.8002910614013672 + scale: 1.5911575382168794 + min: -5 + max: 5 + pad_value: 0 + pfcand_dptdpt: + length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.11111111111111001 + min: -5 + max: 5 + pad_value: 0 + pfcand_detadeta: + length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.11111111039722958 + min: -5 + max: 5 + pad_value: 0 + pfcand_dphidphi: + length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.11111107873678364 + min: -5 + max: 5 + pad_value: 0 + pfcand_dxydxy: + length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.11110580581124187 + min: -5 + max: 5 + pad_value: 0 + pfcand_dzdz: + 
length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.11111063977819315 + min: -5 + max: 5 + pad_value: 0 + pfcand_dxydz: + length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.11111103621962409 + min: -5 + max: 5 + pad_value: 0 + pfcand_dphidxy: + length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.11111118663623733 + min: -5 + max: 5 + pad_value: 0 + pfcand_dlambdadz: + length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.11111111196002012 + min: -5 + max: 5 + pad_value: 0 + pfcand_dxyc: + length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.1111111110812715 + min: -5 + max: 5 + pad_value: 0 + pfcand_dxyctgtheta: + length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.11111111114037472 + min: -5 + max: 5 + pad_value: 0 + pfcand_phic: + length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.11111111111115293 + min: -5 + max: 5 + pad_value: 0 + pfcand_phidz: + length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.11111111110696342 + min: -5 + max: 5 + pad_value: 0 + pfcand_phictgtheta: + length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.11111111049504692 + min: -5 + max: 5 + pad_value: 0 + pfcand_cdz: + length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.11111111111018916 + min: -5 + max: 5 + pad_value: 0 + pfcand_cctgtheta: + length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.11111111111112591 + min: -5 + max: 5 + pad_value: 0 + pfcand_mtof: + length: 75 + pad_mode: wrap + center: 0.07032066211104393 + scale: 5.0446882868657825 + min: -5 + max: 5 + pad_value: 0 + pfcand_dndx: + length: 75 + pad_mode: wrap + center: 0.0 + scale: 0.7165380157879775 + min: -5 + max: 5 + pad_value: 0 + pfcand_charge: + length: 75 + pad_mode: wrap + center: 0.0 + scale: 1.0 + min: -5 + max: 5 + pad_value: 0 + pfcand_isMu: + length: 75 + pad_mode: wrap + center: 0.0 + scale: 1.0 + min: -5 + max: 5 + pad_value: 0 + pfcand_isEl: + length: 75 + pad_mode: wrap + center: 0.0 + scale: 1.0 + min: -5 + max: 5 + pad_value: 0 + pfcand_isChargedHad: + length: 75 + pad_mode: wrap + 
center: 0.0 + scale: 1.0 + min: -5 + max: 5 + pad_value: 0 + pfcand_isGamma: + length: 75 + pad_mode: wrap + center: 0.0 + scale: 1.0 + min: -5 + max: 5 + pad_value: 0 + pfcand_isNeutralHad: + length: 75 + pad_mode: wrap + center: 0.0 + scale: 1.0 + min: -5 + max: 5 + pad_value: 0 + pfcand_dxy: + length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.11104036349544703 + min: -5 + max: 5 + pad_value: 0 + pfcand_dz: + length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.1110838280661911 + min: -5 + max: 5 + pad_value: 0 + pfcand_btagSip2dVal: + length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.11100882796219 + min: -5 + max: 5 + pad_value: 0 + pfcand_btagSip2dSig: + length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.10339552523584994 + min: -5 + max: 5 + pad_value: 0 + pfcand_btagSip3dVal: + length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.11100379369461871 + min: -5 + max: 5 + pad_value: 0 + pfcand_btagSip3dSig: + length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.10327416627313275 + min: -5 + max: 5 + pad_value: 0 + pfcand_btagJetDistVal: + length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.11106689226469424 + min: -5 + max: 5 + pad_value: 0 + pfcand_btagJetDistSig: + length: 75 + pad_mode: wrap + center: -9.0 + scale: 0.10776558597370688 + min: -5 + max: 5 + pad_value: 0 + pfcand_type: + length: 75 + pad_mode: wrap + center: 22.0 + scale: 0.045454545454545456 + min: -5 + max: 5 + pad_value: 0 + pfcand_e: + length: 75 + pad_mode: wrap + center: null + scale: 1 + min: -5 + max: 5 + pad_value: 0 + pfcand_p: + length: 75 + pad_mode: wrap + center: null + scale: 1 + min: -5 + max: 5 + pad_value: 0 + pfcand_mask: + length: 75 + pad_mode: constant + center: null + scale: 1 + min: -5 + max: 5 + pad_value: 0 +new_variables: + pfcand_mask: ak.ones_like(pfcand_e) +inputs: + pf_points: + pad_mode: wrap + length: 75 + vars: + - pfcand_thetarel + - pfcand_phirel + pf_features: + pad_mode: wrap + length: 75 + vars: + - pfcand_erel_log + - pfcand_thetarel + - 
pfcand_phirel + - pfcand_dptdpt + - pfcand_detadeta + - pfcand_dphidphi + - pfcand_dxydxy + - pfcand_dzdz + - pfcand_dxydz + - pfcand_dphidxy + - pfcand_dlambdadz + - pfcand_dxyc + - pfcand_dxyctgtheta + - pfcand_phic + - pfcand_phidz + - pfcand_phictgtheta + - pfcand_cdz + - pfcand_cctgtheta + - pfcand_mtof + - pfcand_dndx + - pfcand_charge + - pfcand_isMu + - pfcand_isEl + - pfcand_isChargedHad + - pfcand_isGamma + - pfcand_isNeutralHad + - pfcand_dxy + - pfcand_dz + - pfcand_btagSip2dVal + - pfcand_btagSip2dSig + - pfcand_btagSip3dVal + - pfcand_btagSip3dSig + - pfcand_btagJetDistVal + - pfcand_btagJetDistSig + - pfcand_type + pf_vectors: + length: 75 + pad_mode: wrap + vars: + - - pfcand_e + - null + - - pfcand_p + - null + - - pfcand_e + - null + - - pfcand_e + - null + pf_mask: + length: 75 + pad_mode: constant + vars: + - - pfcand_mask + - null +labels: + type: simple + value: + - recojet_isG + - recojet_isU + - recojet_isS + - recojet_isC + - recojet_isB + - recojet_isTAU + - recojet_isD +observers: [] +monitor_variables: [] +weights: + use_precomputed_weights: false + reweight_method: flat + reweight_vars: + jet_phi: + - -10.0 + - 10.0 + jet_theta: + - -10.0 + - 10.0 + reweight_classes: + - recojet_isG + - recojet_isU + - recojet_isS + - recojet_isC + - recojet_isB + - recojet_isTAU + - recojet_isD + class_weights: + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + reweight_hists: + recojet_isG: + - - 0.803920567035675 + recojet_isU: + - - 0.7541366219520569 + recojet_isS: + - - 0.802278459072113 + recojet_isC: + - - 0.7855589985847473 + recojet_isB: + - - 0.7855589985847473 + recojet_isTAU: + - - 0.8999999761581421 + recojet_isD: + - - 0.7541366219520569 diff --git a/examples/flavour-tagging/data/wc_pt_7classes_12_04_2023/fccee_flavtagging_edm4hep_wc_v1.json b/examples/flavour-tagging/data/wc_pt_7classes_12_04_2023/fccee_flavtagging_edm4hep_wc_v1.json new file mode 100644 index 0000000..83b4da0 --- /dev/null +++ 
b/examples/flavour-tagging/data/wc_pt_7classes_12_04_2023/fccee_flavtagging_edm4hep_wc_v1.json @@ -0,0 +1,396 @@ +{ + "output_names": [ + "recojet_isG", + "recojet_isQ", + "recojet_isS", + "recojet_isC", + "recojet_isB" + ], + "input_names": [ + "pf_points", + "pf_features", + "pf_vectors", + "pf_mask" + ], + "pf_points": { + "var_names": [ + "pfcand_thetarel", + "pfcand_phirel" + ], + "var_infos": { + "pfcand_thetarel": { + "median": 0.3606320321559906, + "norm_factor": 1.6306203960800496, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_phirel": { + "median": 0.0002556022081989795, + "norm_factor": 0.46948837373539853, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + } + }, + "var_length": 75 + }, + "pf_features": { + "var_names": [ + "pfcand_erel_log", + "pfcand_thetarel", + "pfcand_phirel", + "pfcand_dptdpt", + "pfcand_detadeta", + "pfcand_dphidphi", + "pfcand_dxydxy", + "pfcand_dzdz", + "pfcand_dxydz", + "pfcand_dphidxy", + "pfcand_dlambdadz", + "pfcand_dxyc", + "pfcand_dxyctgtheta", + "pfcand_phic", + "pfcand_phidz", + "pfcand_phictgtheta", + "pfcand_cdz", + "pfcand_cctgtheta", + "pfcand_mtof", + "pfcand_dndx", + "pfcand_charge", + "pfcand_isMu", + "pfcand_isEl", + "pfcand_isChargedHad", + "pfcand_isGamma", + "pfcand_isNeutralHad", + "pfcand_dxy", + "pfcand_dz", + "pfcand_btagSip2dVal", + "pfcand_btagSip2dSig", + "pfcand_btagSip3dVal", + "pfcand_btagSip3dSig", + "pfcand_btagJetDistVal", + "pfcand_btagJetDistSig" + ], + "var_infos": { + "pfcand_erel_log": { + "median": -1.8151949644088745, + "norm_factor": 1.6230400937923801, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_thetarel": { + "median": 0.3606320321559906, + "norm_factor": 1.6306203960800496, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_phirel": { + "median": 0.0002556022081989795, + "norm_factor": 0.46948837373539853, + 
"replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dptdpt": { + "median": -9.0, + "norm_factor": 0.11111111111110991, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_detadeta": { + "median": -9.0, + "norm_factor": 0.11111111034557275, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dphidphi": { + "median": -9.0, + "norm_factor": 0.11111107770056672, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dxydxy": { + "median": -9.0, + "norm_factor": 0.1111054124522878, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dzdz": { + "median": -9.0, + "norm_factor": 0.1111105908437637, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dxydz": { + "median": -9.0, + "norm_factor": 0.11111102914593461, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dphidxy": { + "median": -9.0, + "norm_factor": 0.11111119697916892, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dlambdadz": { + "median": -9.0, + "norm_factor": 0.11111111196585105, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dxyc": { + "median": -9.0, + "norm_factor": 0.11111111107993017, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dxyctgtheta": { + "median": -9.0, + "norm_factor": 0.11111111115695796, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_phic": { + "median": -9.0, + "norm_factor": 0.11111111111116635, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_phidz": { + "median": -9.0, + "norm_factor": 0.11111111111495718, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 
+ }, + "pfcand_phictgtheta": { + "median": -9.0, + "norm_factor": 0.11111111043836697, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_cdz": { + "median": -9.0, + "norm_factor": 0.11111111111012681, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_cctgtheta": { + "median": -9.0, + "norm_factor": 0.1111111111111261, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_mtof": { + "median": 0.07332593947649002, + "norm_factor": 5.140589694831398, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dndx": { + "median": 0.0, + "norm_factor": 0.72003525410647, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_charge": { + "median": 0.0, + "norm_factor": 1.0, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_isMu": { + "median": 0.0, + "norm_factor": 1.0, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_isEl": { + "median": 0.0, + "norm_factor": 1.0, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_isChargedHad": { + "median": 0.0, + "norm_factor": 1.0, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_isGamma": { + "median": 0.0, + "norm_factor": 1.0, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_isNeutralHad": { + "median": 0.0, + "norm_factor": 1.0, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dxy": { + "median": -9.0, + "norm_factor": 0.11103669093395387, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_dz": { + "median": -9.0, + "norm_factor": 0.11108313245378311, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + 
"pfcand_btagSip2dVal": { + "median": -9.0, + "norm_factor": 0.11100134875917528, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_btagSip2dSig": { + "median": -9.0, + "norm_factor": 0.10324104907259508, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_btagSip3dVal": { + "median": -9.0, + "norm_factor": 0.11099737353525808, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_btagSip3dSig": { + "median": -9.0, + "norm_factor": 0.10317010115501339, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_btagJetDistVal": { + "median": -9.0, + "norm_factor": 0.11106430427459978, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + }, + "pfcand_btagJetDistSig": { + "median": -9.0, + "norm_factor": 0.10770140600206164, + "replace_inf_value": 0, + "lower_bound": -5, + "upper_bound": 5, + "pad": 0 + } + }, + "var_length": 75 + }, + "pf_vectors": { + "var_names": [ + "pfcand_e", + "pfcand_p", + "pfcand_e", + "pfcand_e" + ], + "var_infos": { + "pfcand_e": { + "median": 0, + "norm_factor": 1, + "replace_inf_value": 0, + "lower_bound": -1e+32, + "upper_bound": 1e+32, + "pad": 0 + }, + "pfcand_p": { + "median": 0, + "norm_factor": 1, + "replace_inf_value": 0, + "lower_bound": -1e+32, + "upper_bound": 1e+32, + "pad": 0 + } + }, + "var_length": 75 + }, + "pf_mask": { + "var_names": [ + "pfcand_mask" + ], + "var_infos": { + "pfcand_mask": { + "median": 0, + "norm_factor": 1, + "replace_inf_value": 0, + "lower_bound": -1e+32, + "upper_bound": 1e+32, + "pad": 0 + } + }, + "var_length": 75 + } +} \ No newline at end of file diff --git a/examples/flavour-tagging/data/wc_pt_7classes_12_04_2023/fccee_flavtagging_edm4hep_wc_v1.onnx b/examples/flavour-tagging/data/wc_pt_7classes_12_04_2023/fccee_flavtagging_edm4hep_wc_v1.onnx new file mode 100644 index 0000000..cf439a4 Binary files /dev/null and 
b/examples/flavour-tagging/data/wc_pt_7classes_12_04_2023/fccee_flavtagging_edm4hep_wc_v1.onnx differ diff --git a/examples/flavour-tagging/simple-flavour-tagging.ipynb b/examples/flavour-tagging/simple-flavour-tagging.ipynb new file mode 100644 index 0000000..917f2bd --- /dev/null +++ b/examples/flavour-tagging/simple-flavour-tagging.ipynb @@ -0,0 +1,319 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "21a84728", + "metadata": {}, + "outputs": [], + "source": [ + "using EDM4hep\n", + "using EDM4hep.RootIO\n", + "using LorentzVectorHEP\n", + "using JSON\n", + "using ONNXRunTime\n", + "using PhysicalConstants\n", + "using StructArrays\n", + "using JetReconstruction\n", + "using JetTaggingFCC" + ] + }, + { + "cell_type": "markdown", + "id": "55l0znnzdmv", + "metadata": {}, + "source": [ + "# Simple Jet Flavour Tagging Example\n", + "\n", + "This notebook demonstrates how to:\n", + "1. Load EDM4hep event data\n", + "2. Reconstruct jets using JetReconstruction\n", + "3. Extract features for flavour tagging\n", + "4. Run ONNX neural network inference\n", + "5. 
Get flavour probabilities for each jet" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "379f4304", + "metadata": {}, + "outputs": [], + "source": [ + "# Paths to model files\n", + "model_dir = \"data/wc_pt_7classes_12_04_2023\"\n", + "onnx_path = joinpath(model_dir, \"fccee_flavtagging_edm4hep_wc_v1.onnx\")\n", + "json_path = joinpath(model_dir, \"fccee_flavtagging_edm4hep_wc_v1.json\")\n", + "\n", + "# Check if model files exist\n", + "if !isfile(onnx_path)\n", + " error(\"ONNX model not found at: $onnx_path\")\n", + "end\n", + "if !isfile(json_path)\n", + " error(\"JSON config not found at: $json_path\")\n", + "end\n", + "\n", + "println(\"Loading flavour tagging model...\")\n", + "model, config = JetTaggingFCC.setup_onnx_runtime(onnx_path, json_path)\n", + "\n", + "println(\"\\nThe model predicts these flavour classes:\")\n", + "for class_name in config[\"output_names\"]\n", + " println(\" - $class_name\")\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51c71156", + "metadata": {}, + "outputs": [], + "source": [ + "# Path to ROOT file with EDM4hep data\n", + "edm4hep_path = \"data/events_080263084.root\"\n", + "if !isfile(edm4hep_path)\n", + " error(\"EDM4hep data file not found at: $edm4hep_path\")\n", + "end\n", + "\n", + "println(\"\\nLoading EDM4hep data...\")\n", + "reader = RootIO.Reader(edm4hep_path)\n", + "events = RootIO.get(reader, \"events\")\n", + "println(\"Loaded $(length(events)) events\")\n", + "\n", + "# Process a specific event (event #12 as in the script)\n", + "event_id = 12\n", + "println(\"\\nProcessing event #$event_id\")\n", + "evt = events[event_id]\n", + "\n", + "# Get reconstructed particles and tracks\n", + "recps = RootIO.get(reader, evt, \"ReconstructedParticles\")\n", + "tracks = RootIO.get(reader, evt, \"EFlowTrack_1\")\n", + "\n", + "# Get MC particles and links for vertex information\n", + "mcps = RootIO.get(reader, evt, \"Particle\")\n", + "MCRecoLinks = 
RootIO.get(reader, evt, \"MCRecoAssociations\")\n", + "\n", + "# Extract MC vertices for each reconstructed particle\n", + "mc_vertices = Vector{LorentzVector{Float32}}(undef, length(recps))\n", + "reco_to_mc = Dict(link.rec_idx.index => link.sim_idx.index for link in MCRecoLinks)\n", + "for (rec_idx, mc_idx) in reco_to_mc\n", + " if rec_idx < length(recps) && mc_idx < length(mcps)\n", + " mc_vertices[rec_idx + 1] = LorentzVector(Float32(mcps[mc_idx + 1].vertex.x),\n", + " Float32(mcps[mc_idx + 1].vertex.y),\n", + " Float32(mcps[mc_idx + 1].vertex.z),\n", + " Float32(mcps[mc_idx + 1].time))\n", + " end\n", + "end\n", + "\n", + "# Fill any missing vertices with (0,0,0,0)\n", + "for i in 1:length(recps)\n", + " if !isassigned(mc_vertices, i)\n", + " mc_vertices[i] = LorentzVector(0.0f0, 0.0f0, 0.0f0, 0.0f0)\n", + " end\n", + "end\n", + "\n", + "# Get needed collections for feature extraction\n", + "bz = RootIO.get(reader, evt, \"magFieldBz\", register = false)[1]\n", + "trackdata = RootIO.get(reader, evt, \"EFlowTrack\")\n", + "trackerhits = RootIO.get(reader, evt, \"TrackerHits\")\n", + "gammadata = RootIO.get(reader, evt, \"EFlowPhoton\")\n", + "nhdata = RootIO.get(reader, evt, \"EFlowNeutralHadron\")\n", + "calohits = RootIO.get(reader, evt, \"CalorimeterHits\")\n", + "dNdx = RootIO.get(reader, evt, \"EFlowTrack_2\")\n", + "track_L = RootIO.get(reader, evt, \"EFlowTrack_L\", register = false)\n", + "\n", + "println(\" - $(length(recps)) reconstructed particles\")\n", + "println(\" - $(length(tracks)) tracks\")\n", + "println(\" - Magnetic field Bz = $bz T\")\n", + "\n", + "# Print the primary vertex that will be used\n", + "primary_vertex = LorentzVector(0.0f0, 0.0f0, 0.0f0, 0.0f0)\n", + "for vertex in mc_vertices\n", + " if vertex.x != 0.0 || vertex.y != 0.0 || vertex.z != 0.0\n", + " primary_vertex = vertex\n", + " break\n", + " end\n", + "end\n", + "println(\" - Primary vertex: ($(round(primary_vertex.x, digits=3)), $(round(primary_vertex.y, digits=3)), 
$(round(primary_vertex.z, digits=3))) mm\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f96fe81b", + "metadata": {}, + "outputs": [], + "source": [ + "# Reconstruct jets\n", + "println(\"\\nReconstructing jets...\")\n", + "cs = jet_reconstruct(recps; p = 1.0, R = 2.0, algorithm = JetAlgorithm.EEKt)\n", + "\n", + "# Get 2 exclusive jets\n", + "jets = exclusive_jets(cs; njets = 2, T = EEJet)\n", + "println(\"Found $(length(jets)) jets\")\n", + "\n", + "# Print jet properties\n", + "for (i, jet) in enumerate(jets)\n", + " println(\"\\nJet $i:\")\n", + " println(\" - Energy: $(round(jet.E, digits=2)) GeV\")\n", + " println(\" - Pt: $(round(JetReconstruction.pt(jet), digits=2)) GeV\")\n", + " println(\" - Eta: $(round(JetReconstruction.eta(jet), digits=3))\")\n", + " println(\" - Phi: $(round(JetReconstruction.phi(jet), digits=3))\")\n", + " println(\" - Mass: $(round(JetReconstruction.mass(jet), digits=2)) GeV\")\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a50c7e7", + "metadata": {}, + "outputs": [], + "source": [ + "# Get jet constituents\n", + "println(\"\\nExtracting jet constituents...\")\n", + "constituent_indices = [constituent_indexes(jet, cs) for jet in jets]\n", + "\n", + "jet_constituents = JetTaggingFCC.build_constituents_cluster(recps, constituent_indices)\n", + "\n", + "for (i, constituents) in enumerate(jet_constituents)\n", + " println(\" - Jet $i has $(length(constituents)) constituents\")\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4f46c6c3", + "metadata": {}, + "outputs": [], + "source": [ + "# Extract features for flavour tagging\n", + "println(\"\\nExtracting features for flavour tagging...\")\n", + "feature_data = JetTaggingFCC.extract_features(jets,\n", + " jet_constituents,\n", + " tracks,\n", + " bz,\n", + " track_L,\n", + " config,\n", + " trackdata,\n", + " trackerhits,\n", + " gammadata,\n", + " nhdata,\n", + " calohits,\n", + " 
dNdx,\n", + " mc_vertices)\n", + "\n", + "# Prepare input tensors\n", + "println(\"Preparing input tensors...\")\n", + "input_tensors = JetTaggingFCC.prepare_input_tensor(jet_constituents,\n", + " jets,\n", + " config,\n", + " feature_data)\n", + "\n", + "# Run inference\n", + "println(\"Running neural network inference...\")\n", + "weights = JetTaggingFCC.get_weights(0, # Thread slot\n", + " feature_data,\n", + " jets,\n", + " jet_constituents,\n", + " config,\n", + " model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "95451c59", + "metadata": {}, + "outputs": [], + "source": [ + "# Extract and display results\n", + "println(\"\\n\" * \"=\"^60)\n", + "println(\"FLAVOUR TAGGING RESULTS\")\n", + "println(\"=\"^60)\n", + "\n", + "for (jet_idx, jet) in enumerate(jets)\n", + " println(\"\\nJet $jet_idx (E=$(round(jet.E, digits=1)) GeV, pT=$(round(JetReconstruction.pt(jet), digits=1)) GeV):\")\n", + " println(\"-\"^40)\n", + " \n", + " # Collect scores for this jet\n", + " scores = Float32[]\n", + " labels = String[]\n", + " \n", + " for (i, score_name) in enumerate(config[\"output_names\"])\n", + " score = JetTaggingFCC.get_weight(weights, i - 1)[jet_idx]\n", + " push!(scores, score)\n", + " push!(labels, score_name)\n", + " end\n", + " \n", + " # Sort by probability (descending)\n", + " sorted_indices = sortperm(scores, rev = true)\n", + " \n", + " # Display scores\n", + " for idx in sorted_indices\n", + " label = labels[idx]\n", + " score = scores[idx]\n", + " \n", + " # Handle NaN or invalid scores\n", + " if isnan(score) || isinf(score)\n", + " flavor_map = Dict(\"recojet_isG\" => \"Gluon \",\n", + " \"recojet_isQ\" => \"Light q \",\n", + " \"recojet_isS\" => \"Strange \",\n", + " \"recojet_isC\" => \"Charm \",\n", + " \"recojet_isB\" => \"Bottom \")\n", + " formatted_label = get(flavor_map, label, label)\n", + " println(\" $formatted_label: [Invalid score]\")\n", + " continue\n", + " end\n", + " \n", + " bar_length = Int(round(score * 
30))\n", + " bar = \"█\"^bar_length\n", + " percentage = round(score * 100, digits = 1)\n", + " \n", + " # Format label\n", + " flavor_map = Dict(\"recojet_isG\" => \"Gluon \",\n", + " \"recojet_isQ\" => \"Light q \",\n", + " \"recojet_isS\" => \"Strange \",\n", + " \"recojet_isC\" => \"Charm \",\n", + " \"recojet_isB\" => \"Bottom \")\n", + " \n", + " formatted_label = get(flavor_map, label, label)\n", + " println(\" $formatted_label: $bar $(percentage)%\")\n", + " end\n", + " \n", + " # Identify most likely flavour\n", + " max_idx = argmax(scores)\n", + " max_label = labels[max_idx]\n", + " max_score = scores[max_idx]\n", + " \n", + " flavour_name = Dict(\"recojet_isG\" => \"gluon\",\n", + " \"recojet_isQ\" => \"light quark\",\n", + " \"recojet_isS\" => \"strange\",\n", + " \"recojet_isC\" => \"charm\",\n", + " \"recojet_isB\" => \"bottom\")[max_label]\n", + " \n", + " println(\"\\n → Most likely: $(flavour_name) ($(round(max_score * 100, digits=1))% confidence)\")\n", + "end\n", + "\n", + "println(\"\\n\" * \"=\"^60)\n", + "println(\"Processing complete!\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Julia 1.11.6", + "language": "julia", + "name": "julia-1.11" + }, + "language_info": { + "file_extension": ".jl", + "mimetype": "application/julia", + "name": "julia", + "version": "1.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/flavour-tagging/simple-flavour-tagging.jl b/examples/flavour-tagging/simple-flavour-tagging.jl new file mode 100644 index 0000000..04e338f --- /dev/null +++ b/examples/flavour-tagging/simple-flavour-tagging.jl @@ -0,0 +1,271 @@ +#!/usr/bin/env julia + +""" +Simple Jet Flavour Tagging Example + +This script demonstrates how to: +1. Load EDM4hep event data +2. Reconstruct jets using JetReconstruction +3. Extract features for flavour tagging +4. Run ONNX neural network inference +5. 
Get flavour probabilities for each jet + +Run with: julia --project simple-flavour-tagging.jl +""" + +# Ensure we're using the development version of JetReconstruction +using Pkg +Pkg.develop("JetReconstruction") + +using EDM4hep +using EDM4hep.RootIO +using LorentzVectorHEP +using JSON +using ONNXRunTime +using PhysicalConstants +using StructArrays +using JetReconstruction +using JetTaggingFCC + +function main() + # Paths to model files + model_dir = "data/wc_pt_7classes_12_04_2023" + onnx_path = joinpath(model_dir, "fccee_flavtagging_edm4hep_wc_v1.onnx") + json_path = joinpath(model_dir, "fccee_flavtagging_edm4hep_wc_v1.json") + + # Check if model files exist + if !isfile(onnx_path) + error("ONNX model not found at: $onnx_path") + end + if !isfile(json_path) + error("JSON config not found at: $json_path") + end + + println("Loading flavour tagging model...") + model, config = JetTaggingFCC.setup_onnx_runtime(onnx_path, json_path) + + println("\nThe model predicts these flavour classes:") + for class_name in config["output_names"] + println(" - $class_name") + end + + # Path to ROOT file with EDM4hep data + edm4hep_path = "data/events_080263084.root" + # edm4hep_path = "/eos/experiment/fcc/ee/generation/DelphesEvents/winter2023/IDEA/wzp6_ee_nunuH_ecm240/events_080263084.root" + if !isfile(edm4hep_path) + error("EDM4hep data file not found at: $edm4hep_path") + end + + println("\nLoading EDM4hep data...") + reader = RootIO.Reader(edm4hep_path) + events = RootIO.get(reader, "events") + println("Loaded $(length(events)) events") + + # Process a specific event (event #12 as in the notebook) + event_id = 12 + println("\nProcessing event #$event_id") + evt = events[event_id] + + # Get reconstructed particles and tracks + recps = RootIO.get(reader, evt, "ReconstructedParticles") + tracks = RootIO.get(reader, evt, "EFlowTrack_1") + + # Get MC particles and links for vertex information + mcps = RootIO.get(reader, evt, "Particle") + MCRecoLinks = RootIO.get(reader, evt, 
"MCRecoAssociations") + + # Extract MC vertices for each reconstructed particle + mc_vertices = Vector{LorentzVector{Float32}}(undef, length(recps)) + reco_to_mc = Dict(link.rec_idx.index => link.sim_idx.index for link in MCRecoLinks) + for (rec_idx, mc_idx) in reco_to_mc + if rec_idx < length(recps) && mc_idx < length(mcps) + mc_vertices[rec_idx+1] = LorentzVector( + Float32(mcps[mc_idx+1].vertex.x), + Float32(mcps[mc_idx+1].vertex.y), + Float32(mcps[mc_idx+1].vertex.z), + Float32(mcps[mc_idx+1].time), + ) + end + end + + # Fill any missing vertices with (0,0,0,0) + for i = 1:length(recps) + if !isassigned(mc_vertices, i) + mc_vertices[i] = LorentzVector(0.0f0, 0.0f0, 0.0f0, 0.0f0) + end + end + + # Get needed collections for feature extraction + bz = RootIO.get(reader, evt, "magFieldBz", register = false)[1] + trackdata = RootIO.get(reader, evt, "EFlowTrack") + trackerhits = RootIO.get(reader, evt, "TrackerHits") + gammadata = RootIO.get(reader, evt, "EFlowPhoton") + nhdata = RootIO.get(reader, evt, "EFlowNeutralHadron") + calohits = RootIO.get(reader, evt, "CalorimeterHits") + dNdx = RootIO.get(reader, evt, "EFlowTrack_2") + track_L = RootIO.get(reader, evt, "EFlowTrack_L", register = false) + + println(" - $(length(recps)) reconstructed particles") + println(" - $(length(tracks)) tracks") + println(" - Magnetic field Bz = $bz T") + + # Print the primary vertex that will be used + primary_vertex = LorentzVector(0.0f0, 0.0f0, 0.0f0, 0.0f0) + for vertex in mc_vertices + if vertex.x != 0.0 || vertex.y != 0.0 || vertex.z != 0.0 + primary_vertex = vertex + break + end + end + println( + " - Primary vertex: ($(round(primary_vertex.x, digits=3)), $(round(primary_vertex.y, digits=3)), $(round(primary_vertex.z, digits=3))) mm", + ) + + # Reconstruct jets + println("\nReconstructing jets...") + cs = jet_reconstruct(recps; p = 1.0, R = 2.0, algorithm = JetAlgorithm.EEKt) + + # Get 2 exclusive jets + jets = exclusive_jets(cs, EEJet; njets = 2) + println("Found 
$(length(jets)) jets") + + # Print jet properties + for (i, jet) in enumerate(jets) + println("\nJet $i:") + println(" - Energy: $(round(jet.E, digits=2)) GeV") + println(" - Pt: $(round(JetReconstruction.pt(jet), digits=2)) GeV") + println(" - Eta: $(round(JetReconstruction.eta(jet), digits=3))") + println(" - Phi: $(round(JetReconstruction.phi(jet), digits=3))") + println(" - Mass: $(round(JetReconstruction.mass(jet), digits=2)) GeV") + end + + # Get jet constituents + println("\nExtracting jet constituents...") + constituent_indices = [constituent_indexes(jet, cs) for jet in jets] + + jet_constituents = JetTaggingFCC.build_constituents_cluster(recps, constituent_indices) + + for (i, constituents) in enumerate(jet_constituents) + println(" - Jet $i has $(length(constituents)) constituents") + end + + # Extract features for flavour tagging + println("\nExtracting features for flavour tagging...") + feature_data = JetTaggingFCC.extract_features( + jets, + jet_constituents, + tracks, + bz, + track_L, + config, + trackdata, + trackerhits, + gammadata, + nhdata, + calohits, + dNdx, + mc_vertices, + ) + + # Prepare input tensors + println("Preparing input tensors...") + input_tensors = + JetTaggingFCC.prepare_input_tensor(jet_constituents, jets, config, feature_data) + + # Run inference + println("Running neural network inference...") + weights = JetTaggingFCC.get_weights( + 0, # Thread slot + feature_data, + jets, + jet_constituents, + config, + model, + ) + + # Extract and display results + println("\n" * "="^60) + println("FLAVOUR TAGGING RESULTS") + println("="^60) + + for (jet_idx, jet) in enumerate(jets) + println( + "\nJet $jet_idx (E=$(round(jet.E, digits=1)) GeV, pT=$(round(JetReconstruction.pt(jet), digits=1)) GeV):", + ) + println("-"^40) + + # Collect scores for this jet + scores = Float32[] + labels = String[] + + for (i, score_name) in enumerate(config["output_names"]) + score = JetTaggingFCC.get_weight(weights, i - 1)[jet_idx] + push!(scores, score) + 
push!(labels, score_name) + end + + # Sort by probability (descending) + sorted_indices = sortperm(scores, rev = true) + + # Display scores + for idx in sorted_indices + label = labels[idx] + score = scores[idx] + + # Handle NaN or invalid scores + if isnan(score) || isinf(score) + flavor_map = Dict( + "recojet_isG" => "Gluon ", + "recojet_isQ" => "Light q ", + "recojet_isS" => "Strange ", + "recojet_isC" => "Charm ", + "recojet_isB" => "Bottom ", + ) + formatted_label = get(flavor_map, label, label) + println(" $formatted_label: [Invalid score]") + continue + end + + bar_length = Int(round(score * 30)) + bar = "█"^bar_length + percentage = round(score * 100, digits = 1) + + # Format label + flavor_map = Dict( + "recojet_isG" => "Gluon ", + "recojet_isQ" => "Light q ", + "recojet_isS" => "Strange ", + "recojet_isC" => "Charm ", + "recojet_isB" => "Bottom ", + ) + + formatted_label = get(flavor_map, label, label) + println(" $formatted_label: $bar $(percentage)%") + end + + # Identify most likely flavour + max_idx = argmax(scores) + max_label = labels[max_idx] + max_score = scores[max_idx] + + flavour_name = Dict( + "recojet_isG" => "gluon", + "recojet_isQ" => "light quark", + "recojet_isS" => "strange", + "recojet_isC" => "charm", + "recojet_isB" => "bottom", + )[max_label] + + println( + "\n → Most likely: $(flavour_name) ($(round(max_score * 100, digits=1))% confidence)", + ) + end + + println("\n" * "="^60) + println("Processing complete!") +end + +# Run the main function +if abspath(PROGRAM_FILE) == @__FILE__ + main() +end diff --git a/src/JetConstituentBuilder.jl b/src/JetConstituentBuilder.jl new file mode 100644 index 0000000..57f9ce0 --- /dev/null +++ b/src/JetConstituentBuilder.jl @@ -0,0 +1,71 @@ +# JetConstituentBuilder functions + +""" + build_constituents(jets::JetConstituents, + reco_particles::JetConstituents) -> Vector{JetConstituents} + +Build the collection of constituents (mapping jet -> reconstructed particles) for all jets in event. 
+ +# Returns +A vector of JetConstituents, each containing the constituents for a specific jet. +""" +# TODO: Fix this function to be interpolate with Julia pipeline. Specificly, what would be the input jets? + +""" + build_constituents_cluster(reco_particles::JetConstituents, + indices::Vector{Vector{Int}}) -> Vector{JetConstituents} + +Build the collection of constituents using cluster indices. + +# Arguments +- reco_particles: a vector of `JetConstituents` representing reconstructed particles. +- indices: a vector of vectors, where each inner vector contains indices of particles for a specific cluster. + +# Returns +A vector of JetConstituents, each containing the constituents for a specific cluster. +""" +function build_constituents_cluster( + reco_particles::JetConstituents, + indices::Vector{Vector{Int64}}, +) + return map(jet_indices -> reco_particles[jet_indices], indices) +end + +""" + get_jet_constituents(constituents_collection::Vector{JetConstituents}, jet_index::Int) -> JetConstituents + +Retrieve the constituents of an indexed jet in the event. +# Arguments +- constituents_collection: constituents collection, a vector of `JetConstituents`. +- jet_index: the index of the jet for which to retrieve constituents (1-based index) + +# Returns +The constituents of the specified jet, or an empty collection if the jet index is invalid. +""" +function get_jet_constituents( + constituents_collection::Vector{JetConstituents}, + jet_index::Int, +) + return constituents_collection[jet_index] +end + +""" + get_constituents(constituents_collection::Vector{JetConstituents}, jet_indices::Vector{Int}) -> Vector{JetConstituents} + +Retrieve the constituents of a collection of indexed jets in the event. + +# Arguments +- constituents_collection: constituents collection, a vector of `JetConstituents`. +- jet_indices: a vector of jet indices (1-based index) for which to retrieve constituents. 
+ +# Returns +A vector of `JetConstituents`, each containing the constituents for the specified jets. +""" +function get_constituents( + constituents_collection::Vector{JetConstituents}, + jet_indices::Vector{Int}, +) + # Filter valid indices and map to corresponding constituents + valid_indices = filter(idx -> 1 <= idx <= length(constituents_collection), jet_indices) + return map(idx -> constituents_collection[idx], valid_indices) +end diff --git a/src/JetConstituentUtils.jl b/src/JetConstituentUtils.jl new file mode 100644 index 0000000..4787c3d --- /dev/null +++ b/src/JetConstituentUtils.jl @@ -0,0 +1,2305 @@ +module JetConstituentUtils + +using EDM4hep +using JetReconstruction +using StructArrays: StructVector +using LorentzVectorHEP + +# Import physical constants +include("JetPhysicalConstants.jl") +using .JetPhysicalConstants + +const JetConstituents = StructVector{ReconstructedParticle,<:Any} +const JetConstituentsData = Vector{Float32} + +### Basic Kinematic (11) + +# get_pt - Transverse momentum +# get_p - Total momentum +# get_e - Energy +# get_mass - Mass +# get_type - Particle type +# get_charge - Electric charge +# get_theta - Polar angle +# get_phi - Azimuthal angle +# get_y - Rapidity +# get_eta - Pseudorapidity +# get_Bz - Magnetic field component + +""" + get_pt(jets_constituents::Vector{<:JetConstituents}) -> Vector{JetConstituentsData} + +Get the transverse momentum of each particle in each jet. 
+ +# Arguments +- jets_constituents: Vector of jet constituents (each element contains particles for one jet) + +# Returns +A vector of vectors of transverse momentum values (sqrt(px^2 + py^2)) +""" +function get_pt(jets_constituents::Vector{<:JetConstituents}) + return [ + begin + mom_x = jet_constituents.momentum.x + mom_y = jet_constituents.momentum.y + Float32[@inbounds sqrt(mom_x[i]^2 + mom_y[i]^2) for i in eachindex(mom_x)] + end for jet_constituents in jets_constituents + ] +end + +""" + get_p(jets_constituents::Vector{<:JetConstituents}) -> Vector{JetConstituentsData} + +Get the momentum magnitude of each particle in each jet. + +# Arguments +- jets_constituents: Vector of jet constituents + +# Returns +A vector of vectors of momentum magnitudes (sqrt(px^2 + py^2 + pz^2)) +""" +function get_p(jets_constituents::Vector{<:JetConstituents}) + return [ + begin + mom_x = jet_constituents.momentum.x + mom_y = jet_constituents.momentum.y + mom_z = jet_constituents.momentum.z + Float32[ + @inbounds sqrt(mom_x[i]^2 + mom_y[i]^2 + mom_z[i]^2) for + i in eachindex(mom_x) + ] + end for jet_constituents in jets_constituents + ] +end + +""" + get_e(jets_constituents::Vector{<:JetConstituents}) -> Vector{JetConstituentsData} + +Get the energy of each particle in each jet. + +# Arguments +- jets_constituents: Vector of jet constituents + +# Returns +A vector of vectors of energy values +""" +function get_e(jets_constituents::Vector{<:JetConstituents}) + return [jet_constituents.energy for jet_constituents in jets_constituents] +end + +""" + get_type(jets_constituents::Vector{<:JetConstituents}) -> Vector{JetConstituentsData} + +Get the PDG type of each particle in each jet. 
+ +# Arguments +- jets_constituents: Vector of jet constituents + +# Returns +A vector of vectors of particle types (PDG codes/Particle IDs) +""" +function get_type(jets_constituents::Vector{<:JetConstituents}) + return [jet_constituents.type for jet_constituents in jets_constituents] +end + +""" + get_mass(jets_constituents::Vector{<:JetConstituents}) -> Vector{JetConstituentsData} + +Get the mass of each particle in each jet. + +# Arguments +- jets_constituents: Vector of jet constituents + +# Returns +A vector of vectors of mass values +""" +function get_mass(jets_constituents::Vector{<:JetConstituents}) + return [jet_constituents.mass for jet_constituents in jets_constituents] +end + +""" + get_charge(jets_constituents::Vector{<:JetConstituents}) -> Vector{JetConstituentsData} + +Get the charge of each particle in each jet. + +# Arguments +- jets_constituents: Vector of jet constituents + +# Returns +A vector of vectors of charge values +""" +function get_charge(jets_constituents::Vector{<:JetConstituents}) + return [jet_constituents.charge for jet_constituents in jets_constituents] +end + +""" + get_theta(jets_constituents::Vector{<:JetConstituents}) -> Vector{JetConstituentsData} + +Get the polar angle of each particle in each jet. + +# Arguments +- jets_constituents: Vector of jet constituents + +# Returns +A vector of vectors of polar angle values +""" +function get_theta(jets_constituents::Vector{<:JetConstituents}) + return [ + begin + mom_x = jet_constituents.momentum.x + mom_y = jet_constituents.momentum.y + mom_z = jet_constituents.momentum.z + Float32[ + @inbounds( + let x = mom_x[i], y = mom_y[i], z = mom_z[i] + (x == 0.0f0 && y == 0.0f0 && z == 0.0f0) ? 0.0f0 : + atan(sqrt(x^2 + y^2), z) + end + ) for i in eachindex(mom_x) + ] + end for jet_constituents in jets_constituents + ] +end + +""" + get_phi(jets_constituents::Vector{JetConstituents}) -> Vector{JetConstituentsData} + +Get the azimuthal angle of each particle in each jet. 
+ +# Arguments +- jets_constituents: Vector of jet constituents + +# Returns +A vector of vectors of azimuthal angle values +""" +function get_phi(jets_constituents::Vector{<:JetConstituents}) + return [ + begin + mom_x = jet_constituents.momentum.x + mom_y = jet_constituents.momentum.y + Float32[ + @inbounds(let x = mom_x[i], y = mom_y[i] + (x == 0.0f0 && y == 0.0f0) ? 0.0f0 : atan(y, x) + end) for i in eachindex(mom_x) + ] + end for jet_constituents in jets_constituents + ] +end + +""" + get_y(jets_constituents::Vector{<:JetConstituents}) -> Vector{JetConstituentsData} + +Get the rapidity of each particle in each jet. + +# Arguments +- jets_constituents: Vector of jet constituents + +# Returns +A vector of vectors of rapidity values +""" +function get_y(jets_constituents::Vector{<:JetConstituents}) + return [ + begin + energies = jet_constituents.energy + mom_z = jet_constituents.momentum.z + Float32[ + @inbounds(let e = energies[i], pz = mom_z[i] + 0.5f0 * log((e + pz) / (e - pz)) + end) for i in eachindex(energies) + ] + end for jet_constituents in jets_constituents + ] +end + +""" + get_eta(jets_constituents::Vector{<:JetConstituents}) -> Vector{JetConstituentsData} + +Get the pseudorapidity of each particle in each jet. 
+ +# Arguments +- jets_constituents: Vector of jet constituents + +# Returns +A vector of vectors of pseudorapidity values (eta = -ln(tan(theta/2))) +""" +function get_eta(jets_constituents::Vector{<:JetConstituents}) + return [ + begin + mom_x = jet_constituents.momentum.x + mom_y = jet_constituents.momentum.y + mom_z = jet_constituents.momentum.z + Float32[ + @inbounds(let x = mom_x[i], y = mom_y[i], z = mom_z[i] + p = sqrt(x^2 + y^2 + z^2) + if p == 0.0f0 + 0.0f0 + elseif p == abs(z) # particle along beam axis + sign(z) * Inf32 + else + 0.5f0 * log((p + z) / (p - z)) + end + end) for i in eachindex(mom_x) + ] + end for jet_constituents in jets_constituents + ] +end + +""" + get_Bz(jets_constituents::Vector{<:JetConstituents}, + tracks::StructVector{EDM4hep.TrackState}) -> Vector{JetConstituentsData} + +Calculate the magnetic field Bz for each particle based on track curvature and momentum. + +# Arguments +- jets_constituents: Vector of jet constituents +- tracks: Vector of track states (used to get the omega value) + +# Returns +A vector of vectors of Bz values. 
+""" +function get_Bz( + jets_constituents::Vector{<:JetConstituents}, + tracks::StructVector{EDM4hep.TrackState}, +) + a = C_LIGHT * MM_TO_M / FS_TO_S + n_tracks = length(tracks) + + # If tracks is a StructVector, we can access omega column directly + omega_values = tracks.omega + + return [ + begin + mom_x = jet_constituents.momentum.x + mom_y = jet_constituents.momentum.y + charges = jet_constituents.charge + track_indices = jet_constituents.tracks + + Float32[ + @inbounds( + let track_idx = track_indices[i].first + if track_idx < n_tracks + pt = sqrt(mom_x[i]^2 + mom_y[i]^2) + omega_values[track_idx+1] / a * + pt * + copysign(1.0f0, charges[i]) + else + UNDEF_VAL + end + end + ) for i in eachindex(mom_x) + ] + end for jet_constituents in jets_constituents + ] +end + +### Track Related Functions (5) + +## Track Parameter Transformations (XPtoPar) +# XPtoPar_dxy - Transformed transverse impact parameter +# XPtoPar_dz - Transformed longitudinal impact parameter +# XPtoPar_phi - Transformed azimuthal angle +# XPtoPar_C - Track curvature parameter +# XPtoPar_ct - c×tau parameter + +""" + get_dxy(jets_constituents::Vector{JetConstituents}, + tracks::StructVector{EDM4hep.TrackState}, + V::LorentzVector, Bz::Float32) -> Vector{JetConstituentsData} + +Calculate the transverse impact parameter dxy for each particle in each jet relative to vertex V. +Reference: FCCAnalyses c++ function XPtoPar_dxy, adapted for jet constituents. 
+ +# Arguments +- jets_constituents: Vector of jet constituents +- tracks: StructVector of TrackState objects +- V: LorentzVector representing the primary vertex +- Bz: The magnetic field in Tesla + +# Returns +Vector of vectors of dxy values (one vector per jet) +""" +function get_dxy( + jets_constituents::Vector{<:JetConstituents}, + tracks::StructVector{EDM4hep.TrackState}, + V::LorentzVector, + Bz::Float32, +) + cSpeed_Bz = C_LIGHT * NS_TO_S * Bz + n_tracks = length(tracks) + + Vx, Vy = Float32(V.x), Float32(V.y) + + D0_values = tracks.D0 + phi_values = tracks.phi + + n_jets = length(jets_constituents) + result = Vector{Vector{Float32}}(undef, n_jets) + + @inbounds for j = 1:n_jets + jet_constituents = jets_constituents[j] + mom_x = jet_constituents.momentum.x + mom_y = jet_constituents.momentum.y + charges = jet_constituents.charge + track_indices = jet_constituents.tracks + n_particles = length(mom_x) + + # Single allocation per jet + dxy_values = Vector{Float32}(undef, n_particles) + + @simd for i = 1:n_particles + track_idx = track_indices[i].first + if track_idx < n_tracks + idx = track_idx + 1 + D0 = D0_values[idx] + phi0 = phi_values[idx] + + sin_phi, cos_phi = sincos(phi0) + x1 = -D0 * sin_phi - Vx + x2 = D0 * cos_phi - Vy + + px_val = mom_x[i] + py_val = mom_y[i] + + a = -charges[i] * cSpeed_Bz + pt = hypot(px_val, py_val) + r2 = x1^2 + x2^2 + cross = x1 * py_val - x2 * px_val + + # Compute impact parameter + discriminant = pt^2 - 2 * a * cross + a^2 * r2 + if discriminant > 0 + t = sqrt(discriminant) + if pt < 10.0f0 + dxy_values[i] = (t - pt) / a + else + dxy_values[i] = (-2 * cross + a * r2) / (t + pt) + end + else + dxy_values[i] = UNDEF_VAL + end + else + dxy_values[i] = UNDEF_VAL + end + end + + result[j] = dxy_values + end + + return result +end + +""" + get_dz(jets_constituents::Vector{JetConstituents}, + tracks::StructVector{EDM4hep.TrackState}, + V::LorentzVector, Bz::Float32) -> Vector{JetConstituentsData} + +Calculate the longitudinal 
impact parameter dz for each particle in each jet relative to vertex V. +Reference: FCCAnalyses c++ function XPtoPar_dz, adapted for jet constituents. + +# Arguments +- jets_constituents: Vector of jet constituents +- tracks: StructVector of TrackState objects +- V: LorentzVector representing the primary vertex +- Bz: The magnetic field in Tesla + +# Returns +Vector of vectors of dz values (one vector per jet) +""" +function get_dz( + jets_constituents::Vector{<:JetConstituents}, + tracks::StructVector{EDM4hep.TrackState}, + V::LorentzVector, + Bz::Float32, +) + cSpeed_Bz = C_LIGHT * NS_TO_S * Bz + n_tracks = length(tracks) + + Vx, Vy, Vz = Float32(V.x), Float32(V.y), Float32(V.z) + + D0_values = tracks.D0 + Z0_values = tracks.Z0 + phi_values = tracks.phi + + return [ + begin + mom_x = jet_constituents.momentum.x + mom_y = jet_constituents.momentum.y + mom_z = jet_constituents.momentum.z + charges = jet_constituents.charge + track_indices = jet_constituents.tracks + + Float32[ + @inbounds(let track_idx = track_indices[i].first + if track_idx < n_tracks + idx = track_idx + 1 + D0 = D0_values[idx] + Z0 = Z0_values[idx] + phi0 = phi_values[idx] + + sin_phi, cos_phi = sincos(phi0) + + x1 = -D0 * sin_phi - Vx + x2 = D0 * cos_phi - Vy + x3 = Z0 - Vz + + px = mom_x[i] + py = mom_y[i] + pz = mom_z[i] + + # Compute intermediate values + a = -charges[i] * cSpeed_Bz + pt = sqrt(px^2 + py^2) + c = a / (2 * pt) + r2 = x1^2 + x2^2 + cross = x1 * py - x2 * px + t = sqrt(pt^2 - 2 * a * cross + a^2 * r2) + + d = if pt < 10.0f0 + (t - pt) / a + else + (-2 * cross + a * r2) / (t + pt) + end + + b_arg = max(r2 - d^2, 0.0f0) / (1 + 2 * c * d) + b = c * sqrt(b_arg) + if abs(b) > 1.0f0 + b = sign(b) + end + + # Calculate st and ct + st = asin(b) / c + ct = pz / pt + + # Calculate z0 + dot = x1 * px + x2 * py + if dot > 0.0f0 + x3 - ct * st + else + x3 + ct * st + end + else + UNDEF_VAL + end + end) for i in eachindex(mom_x) + ] + end for jet_constituents in jets_constituents + ] +end + 
+""" + get_phi0(jets_constituents::Vector{JetConstituents}, + tracks::StructVector{EDM4hep.TrackState}, + V::LorentzVector, Bz::Float32) -> Vector{JetConstituentsData} + +Calculate the phi angle at the point of closest approach for each particle relative to vertex V. +This is a Julia implementation of the C++ function XPtoPar_phi. + +# Arguments +- jets_constituents: Vector of jet constituents +- tracks: StructVector of TrackState objects +- V: LorentzVector representing the primary vertex +- Bz: The magnetic field in Tesla + +# Returns +Vector of vectors of phi values (one vector per jet) +""" +function get_phi0( + jets_constituents::Vector{<:JetConstituents}, + tracks::StructVector{EDM4hep.TrackState}, + V::LorentzVector, + Bz::Float32, +) + cSpeed_Bz = C_LIGHT * NS_TO_S * Bz + n_tracks = length(tracks) + + Vx, Vy = Float32(V.x), Float32(V.y) + + D0_values = tracks.D0 + phi_values = tracks.phi + + return [ + begin + mom_x = jet_constituents.momentum.x + mom_y = jet_constituents.momentum.y + charges = jet_constituents.charge + track_indices = jet_constituents.tracks + + Float32[ + @inbounds(let track_idx = track_indices[i].first + if track_idx < n_tracks + idx = track_idx + 1 + D0 = D0_values[idx] + phi0_track = phi_values[idx] + + sin_phi, cos_phi = sincos(phi0_track) + + x1 = -D0 * sin_phi - Vx + x2 = D0 * cos_phi - Vy + + px = mom_x[i] + py = mom_y[i] + + a = -charges[i] * cSpeed_Bz + + # Minimize redundant calculations + pt2 = px^2 + py^2 + r2 = x1^2 + x2^2 + cross = x1 * py - x2 * px + two_a_cross = 2 * a * cross + a2_r2 = a^2 * r2 + + t = sqrt(pt2 - two_a_cross + a2_r2) + inv_t = 1.0f0 / t + + a_x1 = a * x1 + a_x2 = a * x2 + atan((py - a_x1) * inv_t, (px + a_x2) * inv_t) + else + UNDEF_VAL + end + end) for i in eachindex(mom_x) + ] + end for jet_constituents in jets_constituents + ] +end + +""" + get_c(jets_constituents::Vector{JetConstituents}, + tracks::StructVector{EDM4hep.TrackState}, + Bz::Float32) -> Vector{JetConstituentsData} + +Calculate the track 
curvature for each particle in each jet.
Reference: FCCAnalyses c++ function XPtoPar_C, adapted for jet constituents.

# Arguments
- jets_constituents: Vector of jet constituents
- tracks: StructVector of TrackState objects (only its length is used, to validate track indices)
- Bz: The magnetic field in Tesla

# Returns
Vector of vectors of C values (one vector per jet). A constituent whose first
track index does not address an entry of `tracks` gets `UNDEF_VAL`.
"""
function get_c(
    jets_constituents::Vector{<:JetConstituents},
    tracks::StructVector{EDM4hep.TrackState},
    Bz::Float32,
)
    half_field_factor = C_LIGHT * MM_TO_M / FS_TO_S * Bz * 0.5f0
    track_count = length(tracks)

    return map(jets_constituents) do jet
        px_col = jet.momentum.x
        py_col = jet.momentum.y
        charge_col = jet.charge
        track_refs = jet.tracks

        curvature = Vector{Float32}(undef, length(px_col))

        @inbounds for i in eachindex(px_col)
            if track_refs[i].first < track_count
                px = px_col[i]
                py = py_col[i]
                inv_pt = 1.0f0 / sqrt(px^2 + py^2)
                # Magnitude from the field factor and 1/pt; sign copied from the charge.
                curvature[i] = copysign(half_field_factor * inv_pt, charge_col[i])
            else
                curvature[i] = UNDEF_VAL
            end
        end

        curvature
    end
end

"""
    get_ct(jets_constituents::Vector{JetConstituents},
           tracks::StructVector{EDM4hep.TrackState},
           Bz::Float32) -> Vector{JetConstituentsData}

Calculate ct = pz/pt (the cotangent of the polar angle) for each particle in each jet.
Reference: FCCAnalyses c++ function XPtoPar_ct, adapted for jet constituents.

# Arguments
- jets_constituents: Vector of jet constituents
- tracks: StructVector of TrackState objects (only its length is used, to validate track indices)
- Bz: The magnetic field in Tesla (not used by the computation)

# Returns
Vector of vectors of ct values (one vector per jet). A constituent whose first
track index does not address an entry of `tracks` gets `UNDEF_VAL`.
"""
function get_ct(
    jets_constituents::Vector{<:JetConstituents},
    tracks::StructVector{EDM4hep.TrackState},
    Bz::Float32,
)
    track_count = length(tracks)

    return map(jets_constituents) do jet
        px_col = jet.momentum.x
        py_col = jet.momentum.y
        pz_col = jet.momentum.z
        track_refs = jet.tracks

        ct = Vector{Float32}(undef, length(px_col))

        @inbounds for i in eachindex(px_col)
            if track_refs[i].first < track_count
                px = px_col[i]
                py = py_col[i]
                pz = pz_col[i]
                # Longitudinal over transverse momentum.
                ct[i] = pz / sqrt(px^2 + py^2)
            else
                ct[i] = UNDEF_VAL
            end
        end

        ct
    end
end

### Covariance Matrix Elements (15)

## Diagonal Elements
# get_omega_cov - Omega variance
# get_d0_cov - d0 variance
# get_z0_cov - z0 variance
# get_phi0_cov - phi0 variance
# get_tanlambda_cov - tanLambda variance

## Off-diagonal Elements
# get_d0_z0_cov
# get_phi0_d0_cov
# get_phi0_z0_cov
# get_tanlambda_phi0_cov
# get_tanlambda_d0_cov
# get_tanlambda_z0_cov
# get_omega_tanlambda_cov
# get_omega_phi0_cov
# get_omega_d0_cov
# get_omega_z0_cov

"""
    get_dxydxy(jets_constituents::Vector{<:JetConstituents},
               tracks::StructVector{EDM4hep.TrackState}) -> Vector{JetConstituentsData}

Get the d0 covariance (dxy/dxy) for each particle.

# Arguments
- jets_constituents: Vector of jet constituents
- tracks: StructVector of TrackState objects

# Returns
Vector of vectors of dxydxy values (`UNDEF_VAL` where the constituent has no valid track)
"""
function get_dxydxy(
    jets_constituents::Vector{<:JetConstituents},
    tracks::StructVector{EDM4hep.TrackState},
)
    return _track_cov_element(jets_constituents, tracks, 1)
end

# Shared implementation of all covariance getters in this section.
# For every constituent of every jet, reads element `k` (1-based) of the
# covariance matrix of the constituent's first associated track. The stored
# track index is 0-based; when it does not address an entry of `tracks`
# the result is UNDEF_VAL.
function _track_cov_element(
    jets_constituents::Vector{<:JetConstituents},
    tracks::StructVector{EDM4hep.TrackState},
    k::Int,
)
    n_tracks = length(tracks)
    result = Vector{JetConstituentsData}(undef, length(jets_constituents))

    @inbounds for j in eachindex(jets_constituents)
        track_indices = jets_constituents[j].tracks
        jet_result = Vector{Float32}(undef, length(track_indices))

        for i in eachindex(track_indices)
            track_idx = track_indices[i].first
            # A real branch is required here: `ifelse` is an ordinary function and
            # evaluates both value arguments, so it would index `tracks` out of
            # bounds (unchecked under @inbounds) for constituents without a track.
            jet_result[i] =
                track_idx < n_tracks ? tracks[track_idx+1].covMatrix[k] : UNDEF_VAL
        end

        result[j] = jet_result
    end

    return result
end

"""
    get_dphidxy(jets_constituents::Vector{<:JetConstituents},
                tracks::StructVector{EDM4hep.TrackState}) -> Vector{JetConstituentsData}

Get the phi0-d0 covariance (dphi/dxy) for each particle.

# Arguments
- jets_constituents: Vector of jet constituents
- tracks: StructVector of TrackState objects

# Returns
Vector of vectors of dphidxy values (`UNDEF_VAL` where the constituent has no valid track)
"""
function get_dphidxy(
    jets_constituents::Vector{<:JetConstituents},
    tracks::StructVector{EDM4hep.TrackState},
)
    return _track_cov_element(jets_constituents, tracks, 2)
end

"""
    get_dphidphi(jets_constituents::Vector{<:JetConstituents},
                 tracks::StructVector{EDM4hep.TrackState}) -> Vector{JetConstituentsData}

Get the phi covariance (dphi/dphi) for each particle in each jet from its associated track.
Reference: FCCAnalyses c++ function get_phi0_cov, adapted for jet constituents.

# Arguments
- jets_constituents: Vector of jet constituents
- tracks: StructVector of TrackState objects

# Returns
Vector of vectors of dphidphi values (`UNDEF_VAL` where the constituent has no valid track)
"""
function get_dphidphi(
    jets_constituents::Vector{<:JetConstituents},
    tracks::StructVector{EDM4hep.TrackState},
)
    return _track_cov_element(jets_constituents, tracks, 3)
end

"""
    get_dxyc(jets_constituents::Vector{<:JetConstituents},
             tracks::StructVector{EDM4hep.TrackState}) -> Vector{JetConstituentsData}

Get the d0-omega covariance (dxy/c) for each particle.

# Arguments
- jets_constituents: Vector of jet constituents
- tracks: StructVector of TrackState objects

# Returns
Vector of vectors of dxyc values (`UNDEF_VAL` where the constituent has no valid track)
"""
function get_dxyc(
    jets_constituents::Vector{<:JetConstituents},
    tracks::StructVector{EDM4hep.TrackState},
)
    return _track_cov_element(jets_constituents, tracks, 4)
end

"""
    get_phic(jets_constituents::Vector{<:JetConstituents},
             tracks::StructVector{EDM4hep.TrackState}) -> Vector{JetConstituentsData}

Get the phi0-omega covariance (phi/c) for each particle.

# Arguments
- jets_constituents: Vector of jet constituents
- tracks: StructVector of TrackState objects

# Returns
Vector of vectors of phiomega values (`UNDEF_VAL` where the constituent has no valid track)
"""
function get_phic(
    jets_constituents::Vector{<:JetConstituents},
    tracks::StructVector{EDM4hep.TrackState},
)
    return _track_cov_element(jets_constituents, tracks, 5)
end

"""
    get_dptdpt(jets_constituents::Vector{<:JetConstituents},
               tracks::StructVector{EDM4hep.TrackState}) -> Vector{JetConstituentsData}

Get the omega covariance (dpt/dpt) for each particle in each jet from its associated track.
Reference: FCCAnalyses c++ function get_omega_cov, adapted for jet constituents.

# Arguments
- jets_constituents: Vector of jet constituents (each element contains particles for one jet)
- tracks: StructVector of TrackState objects

# Returns
Vector of vectors of dptdpt values (`UNDEF_VAL` where the constituent has no valid track)
"""
function get_dptdpt(
    jets_constituents::Vector{<:JetConstituents},
    tracks::StructVector{EDM4hep.TrackState},
)
    return _track_cov_element(jets_constituents, tracks, 6)
end

"""
    get_dxydz(jets_constituents::Vector{<:JetConstituents},
              tracks::StructVector{EDM4hep.TrackState}) -> Vector{JetConstituentsData}

Get the d0-z0 covariance (dxy/dz) for each particle.

# Arguments
- jets_constituents: Vector of jet constituents
- tracks: StructVector of TrackState objects

# Returns
Vector of vectors of dxy/dz values (`UNDEF_VAL` where the constituent has no valid track)
"""
function get_dxydz(
    jets_constituents::Vector{<:JetConstituents},
    tracks::StructVector{EDM4hep.TrackState},
)
    return _track_cov_element(jets_constituents, tracks, 7)
end

"""
    get_phidz(jets_constituents::Vector{<:JetConstituents},
              tracks::StructVector{EDM4hep.TrackState}) -> Vector{JetConstituentsData}

Get the phi0-z0 covariance (dphi/dz) for each particle.

# Arguments
- jets_constituents: Vector of jet constituents
- tracks: StructVector of TrackState objects

# Returns
Vector of vectors of phidz values (`UNDEF_VAL` where the constituent has no valid track)
"""
function get_phidz(
    jets_constituents::Vector{<:JetConstituents},
    tracks::StructVector{EDM4hep.TrackState},
)
    return _track_cov_element(jets_constituents, tracks, 8)
end

"""
    get_cdz(jets_constituents::Vector{<:JetConstituents},
            tracks::StructVector{EDM4hep.TrackState}) -> Vector{JetConstituentsData}

Get the omega-z0 covariance (c/dz) for each particle.

# Arguments
- jets_constituents: Vector of jet constituents
- tracks: StructVector of TrackState objects

# Returns
Vector of vectors of cdz values (`UNDEF_VAL` where the constituent has no valid track)
"""
function get_cdz(
    jets_constituents::Vector{<:JetConstituents},
    tracks::StructVector{EDM4hep.TrackState},
)
    return _track_cov_element(jets_constituents, tracks, 9)
end

"""
    get_dzdz(jets_constituents::Vector{<:JetConstituents},
             tracks::StructVector{EDM4hep.TrackState}) -> Vector{JetConstituentsData}

Get the z0 covariance (dz/dz) for each particle.

# Arguments
- jets_constituents: Vector of jet constituents
- tracks: StructVector of TrackState objects

# Returns
Vector of vectors of z0 covariance values (`UNDEF_VAL` where the constituent has no valid track)
"""
function get_dzdz(
    jets_constituents::Vector{<:JetConstituents},
    tracks::StructVector{EDM4hep.TrackState},
)
    return _track_cov_element(jets_constituents, tracks, 10)
end

"""
    get_dxyctgtheta(jets_constituents::Vector{<:JetConstituents},
                    tracks::StructVector{EDM4hep.TrackState}) -> Vector{JetConstituentsData}

Get the d0-tanLambda covariance (dxy/ctgtheta) for each particle.

# Arguments
- jets_constituents: Vector of jet constituents
- tracks: StructVector of TrackState objects

# Returns
Vector of vectors of dxyctgtheta values (`UNDEF_VAL` where the constituent has no valid track)
"""
function get_dxyctgtheta(
    jets_constituents::Vector{<:JetConstituents},
    tracks::StructVector{EDM4hep.TrackState},
)
    return _track_cov_element(jets_constituents, tracks, 11)
end

"""
    get_phictgtheta(jets_constituents::Vector{<:JetConstituents},
                    tracks::StructVector{EDM4hep.TrackState}) -> Vector{JetConstituentsData}

Get the phi0-tanLambda covariance (phi/ctgtheta) for each particle.

# Arguments
- jets_constituents: Vector of jet constituents
- tracks: StructVector of TrackState objects

# Returns
Vector of vectors of phictgtheta values (`UNDEF_VAL` where the constituent has no valid track)
"""
function get_phictgtheta(
    jets_constituents::Vector{<:JetConstituents},
    tracks::StructVector{EDM4hep.TrackState},
)
    return _track_cov_element(jets_constituents, tracks, 12)
end

"""
    get_cctgtheta(jets_constituents::Vector{<:JetConstituents},
                  tracks::StructVector{EDM4hep.TrackState}) -> Vector{JetConstituentsData}

Get the omega-tanLambda covariance (c/ctgtheta) for each particle.

# Arguments
- jets_constituents: Vector of jet constituents
- tracks: StructVector of TrackState objects

# Returns
Vector of vectors of cctgtheta values (`UNDEF_VAL` where the constituent has no valid track)
"""
function get_cctgtheta(
    jets_constituents::Vector{<:JetConstituents},
    tracks::StructVector{EDM4hep.TrackState},
)
    return _track_cov_element(jets_constituents, tracks, 13)
end

"""
    get_dlambdadz(jets_constituents::Vector{<:JetConstituents},
                  tracks::StructVector{EDM4hep.TrackState}) -> Vector{JetConstituentsData}

Get the tanLambda-z0 covariance (dlambda/dz) for each particle.

# Arguments
- jets_constituents: Vector of jet constituents
- tracks: StructVector of TrackState objects

# Returns
Vector of vectors of dlambdadz values (`UNDEF_VAL` where the constituent has no valid track)
"""
function get_dlambdadz(
    jets_constituents::Vector{<:JetConstituents},
    tracks::StructVector{EDM4hep.TrackState},
)
    return _track_cov_element(jets_constituents, tracks, 14)
end

"""
    get_detadeta(jets_constituents::Vector{<:JetConstituents},
                 tracks::StructVector{EDM4hep.TrackState}) -> Vector{JetConstituentsData}

Get the tanLambda covariance (deta/deta) for each particle in each jet from its associated track.
Reference: FCCAnalyses c++ function get_tanlambda_cov, adapted for jet constituents.

# Arguments
- jets_constituents: Vector of jet constituents
- tracks: StructVector of TrackState objects

# Returns
Vector of vectors of detadeta values (`UNDEF_VAL` where the constituent has no valid track)
"""
function get_detadeta(
    jets_constituents::Vector{<:JetConstituents},
    tracks::StructVector{EDM4hep.TrackState},
)
    return _track_cov_element(jets_constituents, tracks, 15)
end

### Particle Identification (5)

# get_is_mu - Check if constituent is muon
# get_is_el - Check if constituent is electron
# get_is_charged_had - Check if constituent is charged hadron
# get_is_gamma - Check if constituent is photon
# get_is_neutral_had - Check if constituent is neutral hadron

"""
    get_is_mu(jets_constituents::Vector{<:JetConstituents}) -> Vector{JetConstituentsData}

Check if each constituent particle is a muon.

# Arguments
- jets_constituents: Vector of jet constituents

# Returns
Vector of vectors of is muon boolean values as Float32.
"""
function get_is_mu(jets_constituents::Vector{<:JetConstituents})
    # The helper returns (muon, electron, charged-hadron) flags; keep the muon entry.
    # This definition must directly follow the docstring above so that the
    # documentation is attached to `get_is_mu` and not to the internal helper.
    return _get_particle_ids_optimized(jets_constituents)[1]
end

# Internal helper: classifies every constituent as muon / electron / charged
# hadron in one pass over each jet. A particle gets flag 1.0f0 for a species
# when its charge is non-zero and its mass lies within that species' tolerance
# of the reference mass (MUON_MASS, ELECTRON_MASS, PION_MASS); otherwise 0.0f0.
# Returns the tuple (is_mu, is_el, is_charged_had), each a vector with one
# Vector{Float32} per jet.
# Note: every call computes all three classifications, so each public getter
# that indexes into the tuple repeats the full computation.
function _get_particle_ids_optimized(jets_constituents::Vector{<:JetConstituents})
    n_jets = length(jets_constituents)
    is_mu = Vector{Vector{Float32}}(undef, n_jets)
    is_el = Vector{Vector{Float32}}(undef, n_jets)
    is_charged_had = Vector{Vector{Float32}}(undef, n_jets)

    @inbounds for j = 1:n_jets
        jet_constituents = jets_constituents[j]
        charges = jet_constituents.charge
        masses = jet_constituents.mass
        n_particles = length(charges)

        # One output vector per species for this jet.
        is_mu_jet = Vector{Float32}(undef, n_particles)
        is_el_jet = Vector{Float32}(undef, n_particles)
        is_ch_jet = Vector{Float32}(undef, n_particles)

        @simd for i = 1:n_particles
            mass = masses[i]
            is_charged = abs(charges[i]) > 0

            # Non-short-circuit `&` keeps the loop body branch-free.
            near_mu = abs(mass - MUON_MASS) < MUON_TOLERANCE
            near_el = abs(mass - ELECTRON_MASS) < ELECTRON_TOLERANCE
            near_pi = abs(mass - PION_MASS) < PION_TOLERANCE

            is_mu_jet[i] = (is_charged & near_mu) ? 1.0f0 : 0.0f0
            is_el_jet[i] = (is_charged & near_el) ? 1.0f0 : 0.0f0
            is_ch_jet[i] = (is_charged & near_pi) ? 1.0f0 : 0.0f0
        end

        is_mu[j] = is_mu_jet
        is_el[j] = is_el_jet
        is_charged_had[j] = is_ch_jet
    end

    return (is_mu, is_el, is_charged_had)
end

"""
    get_is_el(jets_constituents::Vector{<:JetConstituents}) -> Vector{JetConstituentsData}

Check if each constituent particle is an electron.

# Arguments
- jets_constituents: Vector of jet constituents

# Returns
Vector of vectors of is electron boolean values as Float32.
"""
function get_is_el(jets_constituents::Vector{<:JetConstituents})
    # Second entry of the helper's (muon, electron, charged-hadron) tuple.
    _, electron_flags, _ = _get_particle_ids_optimized(jets_constituents)
    return electron_flags
end

"""
    get_is_charged_had(jets_constituents::Vector{<:JetConstituents}) -> Vector{JetConstituentsData}

Check if each constituent particle is a charged hadron.

# Arguments
- jets_constituents: Vector of jet constituents

# Returns
Vector of vectors of is charged hadron boolean values as Float32.
"""
function get_is_charged_had(jets_constituents::Vector{<:JetConstituents})
    # Third entry of the helper's (muon, electron, charged-hadron) tuple.
    _, _, charged_hadron_flags = _get_particle_ids_optimized(jets_constituents)
    return charged_hadron_flags
end

"""
    get_is_gamma(jets_constituents::Vector{<:JetConstituents}) -> Vector{JetConstituentsData}

Check if each constituent particle is a photon (gamma) (PDG 22).

# Arguments
- jets_constituents: Vector of jet constituents

# Returns
Vector of vectors of is photon boolean values as Float32.
"""
function get_is_gamma(jets_constituents::Vector{<:JetConstituents})
    # 1.0f0 where the constituent's type equals PDG_PHOTON, 0.0f0 elsewhere.
    return Vector{Float32}[
        Float32[t == PDG_PHOTON ? 1.0f0 : 0.0f0 for t in jet.type] for
        jet in jets_constituents
    ]
end

"""
    get_is_neutral_had(jets_constituents::Vector{<:JetConstituents}) -> Vector{JetConstituentsData}

Check if each constituent particle is a neutral hadron (PDG 130).

# Arguments
- jets_constituents: Vector of jet constituents

# Returns
Vector of vectors of is neutral hadron boolean values as Float32.
"""
function get_is_neutral_had(jets_constituents::Vector{<:JetConstituents})
    # 1.0f0 where the constituent's type equals PDG_K_LONG, 0.0f0 elsewhere.
    return Vector{Float32}[
        Float32[t == PDG_K_LONG ? 1.0f0 : 0.0f0 for t in jet.type] for
        jet in jets_constituents
    ]
end

### Relative Kinematics (5)

# get_erel_cluster - Relative energy for clustered jets
# get_erel_log_cluster - Log of relative energy for clustered jets
# get_thetarel_cluster - Relative polar angle for clustered jets
# get_phirel_cluster - Relative azimuthal angle for clustered jets
# get_thetarel_phirel_cluster - Combined relative angles for clustered jets as they use the same logic

"""
    get_erel_cluster(jets::Vector{T} where T,
                    jets_constituents::Vector{<:JetConstituents}) -> Vector{JetConstituentsData}

Calculate relative energy (E_const/E_jet) for each constituent particle in clustered jets.

# Arguments
- `jets::Vector{T} where T`: Vector of clustered jets (any jet type).
- `jets_constituents::Vector{<:JetConstituents}`: Vector of jet constituents corresponding to the jets.

# Returns
Vector containing relative energy values for each constituent in the jets.
"""
function get_erel_cluster(
    jets::Vector{T} where {T},
    jets_constituents::Vector{<:JetConstituents},
)
    erel = Vector{Vector{Float32}}(undef, length(jets))

    @inbounds for i in eachindex(jets)
        e_jet = jets[i].E
        energies = jets_constituents[i].energy

        if e_jet > 0.0f0
            # Multiply by the reciprocal once per jet instead of dividing per constituent.
            inv_e_jet = 1.0f0 / e_jet
            erel[i] = Float32[energies[j] * inv_e_jet for j in eachindex(energies)]
        else
            # Jets with non-positive energy get 0 for every constituent.
            erel[i] = zeros(Float32, length(energies))
        end
    end

    return erel
end

"""
    get_erel_log_cluster(jets::Vector{T} where T,
                        jets_constituents::Vector{<:JetConstituents}) -> Vector{JetConstituentsData}

Calculate the base-10 log of relative energy (log10(E_const/E_jet)) for each constituent particle in clustered jets.

# Arguments
- `jets::Vector{T} where T`: Vector of clustered jets (any jet type)
- `jets_constituents::Vector{<:JetConstituents}`: Vector of jet constituents corresponding to the jets

# Returns
Vector containing log10 of relative energy values for each constituent in the jets
(0 for every constituent of a jet whose energy is not positive)
"""
function get_erel_log_cluster(
    jets::Vector{T} where {T},
    jets_constituents::Vector{<:JetConstituents},
)
    log_erel = Vector{Vector{Float32}}(undef, length(jets))

    @inbounds for i in eachindex(jets)
        e_jet = jets[i].E
        energies = jets_constituents[i].energy

        if e_jet > 0.0f0
            inv_e_jet = 1.0f0 / e_jet
            log_erel[i] =
                Float32[log10(energies[j] * inv_e_jet) for j in eachindex(energies)]
        else
            log_erel[i] = zeros(Float32, length(energies))
        end
    end

    return log_erel
end

"""
    get_thetarel_cluster(jets::Vector{T} where T,
                        jets_constituents::Vector{<:JetConstituents}) -> Vector{JetConstituentsData}

Calculate relative theta angle between constituent particle and clustered jet axis.

# Arguments
- `jets::Vector{T} where T`: Vector of clustered jets (any jet type)
- `jets_constituents::Vector{<:JetConstituents}`: Vector of jet constituents corresponding to the jets

# Returns
Vector containing relative theta angle values for each constituent in the jets
"""
function get_thetarel_cluster(
    jets::Vector{T} where {T},
    jets_constituents::Vector{<:JetConstituents},
)
    thetarel = Vector{Vector{Float32}}(undef, length(jets))

    @inbounds for i in eachindex(jets)
        jet = jets[i]
        px, py, pz = jet.px, jet.py, jet.pz

        # Jet direction angles.
        pt_jet = sqrt(px^2 + py^2)
        theta_jet = atan(pt_jet, pz)
        phi_jet = atan(py, px)

        # Rotation coefficients for rotating by -phi_jet, then by -theta_jet.
        sin_phi, cos_phi = sincos(-phi_jet)
        sin_theta, cos_theta = sincos(-theta_jet)

        constituents = jets_constituents[i]
        mom_x = constituents.momentum.x
        mom_y = constituents.momentum.y
        mom_z = constituents.momentum.z

        jet_theta = Vector{Float32}(undef, length(mom_x))

        for j in eachindex(mom_x)
            # First rotation (x-y plane).
            p_rot_x = mom_x[j] * cos_phi - mom_y[j] * sin_phi
            p_rot_y = mom_x[j] * sin_phi + mom_y[j] * cos_phi

            # Second rotation (x-z plane).
            p_rot2_x = p_rot_x * cos_theta - mom_z[j] * sin_theta
            p_rot2_z = p_rot_x * sin_theta + mom_z[j] * cos_theta

            # Polar angle in the rotated frame.
            pt_rot = sqrt(p_rot2_x^2 + p_rot_y^2)
            jet_theta[j] = atan(pt_rot, p_rot2_z)
        end

        thetarel[i] = jet_theta
    end

    return thetarel
end

"""
    get_phirel_cluster(jets::Vector{T} where T,
                      jets_constituents::Vector{<:JetConstituents}) -> Vector{JetConstituentsData}

Calculate
relative phi angle between constituent particle and clustered jet axis.

# Arguments
- `jets::Vector{T} where T`: Vector of clustered jets (any jet type)
- `jets_constituents::Vector{<:JetConstituents}`: Vector of jet constituents corresponding to the jets

# Returns
Vector containing relative phi angle values for each constituent in the jets
"""
function get_phirel_cluster(
    jets::Vector{T} where {T},
    jets_constituents::Vector{<:JetConstituents},
)
    phirel = Vector{Vector{Float32}}(undef, length(jets))

    @inbounds for i in eachindex(jets)
        jet = jets[i]
        px, py, pz = jet.px, jet.py, jet.pz

        # Jet direction angles.
        pt_jet = sqrt(px^2 + py^2)
        theta_jet = atan(pt_jet, pz)
        phi_jet = atan(py, px)

        # Rotation coefficients for rotating by -phi_jet, then by -theta_jet.
        sin_phi, cos_phi = sincos(-phi_jet)
        sin_theta, cos_theta = sincos(-theta_jet)

        constituents = jets_constituents[i]
        mom_x = constituents.momentum.x
        mom_y = constituents.momentum.y
        mom_z = constituents.momentum.z

        jet_phi = Vector{Float32}(undef, length(mom_x))

        for j in eachindex(mom_x)
            # Rotation around the z-axis by -phi_jet.
            p_rot_x = mom_x[j] * cos_phi - mom_y[j] * sin_phi
            p_rot_y = mom_x[j] * sin_phi + mom_y[j] * cos_phi

            # Rotation around the y-axis by -theta_jet (only x is needed for phi).
            p_rot2_x = p_rot_x * cos_theta - mom_z[j] * sin_theta

            # Azimuthal angle in the rotated frame.
            jet_phi[j] = atan(p_rot_y, p_rot2_x)
        end

        phirel[i] = jet_phi
    end

    return phirel
end

"""
    get_thetarel_phirel_cluster(jets::Vector{T} where T,
                               jets_constituents::Vector{<:JetConstituents}) -> Tuple{Vector{JetConstituentsData}, Vector{JetConstituentsData}}

Calculate relative theta and phi angles between constituent particles and clustered jet axis.

# Arguments
- `jets::Vector{T} where T`: Vector of clustered jets (any jet type)
- `jets_constituents::Vector{<:JetConstituents}`: Vector of jet constituents corresponding to the jets

# Returns
Tuple `(theta, phi)` of Vectors containing relative theta and phi angle values for each constituent in the jets
"""
function get_thetarel_phirel_cluster(
    jets::Vector{T} where {T},
    jets_constituents::Vector{<:JetConstituents},
)
    all_theta = Vector{Vector{Float32}}(undef, length(jets))
    all_phi = Vector{Vector{Float32}}(undef, length(jets))

    @inbounds for i in eachindex(jets)
        jet = jets[i]
        px, py, pz = jet.px, jet.py, jet.pz

        # Jet direction angles.
        pt_jet = sqrt(px^2 + py^2)
        theta_jet = atan(pt_jet, pz)
        phi_jet = atan(py, px)

        # Rotation coefficients for rotating by -phi_jet, then by -theta_jet.
        sin_phi, cos_phi = sincos(-phi_jet)
        sin_theta, cos_theta = sincos(-theta_jet)

        constituents = jets_constituents[i]
        mom_x = constituents.momentum.x
        mom_y = constituents.momentum.y
        mom_z = constituents.momentum.z

        jet_theta = Vector{Float32}(undef, length(mom_x))
        jet_phi = Vector{Float32}(undef, length(mom_x))

        for j in eachindex(mom_x)
            # Rotation around the z-axis by -phi_jet.
            p_rot_x = mom_x[j] * cos_phi - mom_y[j] * sin_phi
            p_rot_y = mom_x[j] * sin_phi + mom_y[j] * cos_phi

            # Rotation around the y-axis by -theta_jet; y is unchanged by it.
            p_rot2_x = p_rot_x * cos_theta - mom_z[j] * sin_theta
            p_rot2_z = p_rot_x * sin_theta + mom_z[j] * cos_theta
            p_rot2_y = p_rot_y

            # Both angles come from the same rotated momentum.
            pt_rot = sqrt(p_rot2_x^2 + p_rot2_y^2)
            jet_theta[j] = atan(pt_rot, p_rot2_z)
            jet_phi[j] = atan(p_rot2_y, p_rot2_x)
        end

        all_theta[i] = jet_theta
        all_phi[i] = jet_phi
    end

    return (all_theta, all_phi)
end

### Special Measurements (2)

# get_dndx - dE/dx measurement (energy loss)
# get_mtof - Mass from time-of-flight measurement

"""
"""
    get_dndx(jets_constituents::Vector{<:JetConstituents},
             dNdx::StructVector{EDM4hep.Quantity},
             trackdata::StructVector{EDM4hep.Track},
             JetsConstituents_isChargedHad::Vector{Vector{Float32}}) -> Vector{Vector{Float32}}

Flag, for every jet constituent, whether a dN/dx measurement applies to it.

The measurement itself is not read: `dNdx` is accepted but never accessed. A
constituent that is flagged as a charged hadron and whose first track index lies
inside `trackdata` receives the placeholder `-1.0f0`; every other constituent
receives `0.0f0`.

# Arguments
- jets_constituents: Vector of jet constituents (each element contains particles for one jet)
- dNdx: StructVector of Quantity objects (currently unused)
- trackdata: StructVector of Track objects; only its length is used
- JetsConstituents_isChargedHad: per-jet vectors, `1.0f0` where the constituent is a charged hadron

# Returns
Vector of vectors (one vector per jet, one value per constituent) holding `-1.0f0` or `0.0f0`
"""
function get_dndx(
    jets_constituents::Vector{<:JetConstituents},
    dNdx::StructVector{EDM4hep.Quantity},
    trackdata::StructVector{EDM4hep.Track},
    JetsConstituents_isChargedHad::Vector{Vector{Float32}},
)
    n_jets = length(jets_constituents)
    result = Vector{Vector{Float32}}(undef, n_jets)
    tracks_len = length(trackdata)

    @inbounds for i = 1:n_jets
        jet = jets_constituents[i]
        tracks_first = jet.tracks
        isChargedHad = JetsConstituents_isChargedHad[i]
        n_constituents = length(jet)
        tmp = Vector{Float32}(undef, n_constituents)

        @simd ivdep for j = 1:n_constituents
            # `first` is used as a 0-based index, hence the +1 before comparing to the length
            has_valid_track = tracks_first[j].first + 1 <= tracks_len
            is_charged_had = isChargedHad[j] == 1.0f0
            tmp[j] = ifelse(has_valid_track & is_charged_had, -1.0f0, 0.0f0)
        end

        result[i] = tmp
    end

    return result
end

"""
    get_mtof(jets_constituents::Vector{<:JetConstituents},
             track_L::AbstractArray{Float32},
             trackdata::StructVector{EDM4hep.Track},
             trackerhits::StructVector{EDM4hep.TrackerHit},
             gammadata::StructVector{EDM4hep.Cluster},
             nhdata::StructVector{EDM4hep.Cluster},
             calohits::StructVector{EDM4hep.CalorimeterHit},
             V::LorentzVector) -> Vector{Vector{Float32}}

Estimate a mass for jet constituents from time-of-flight information.

For each constituent the function tries, in order:
1. Cluster-based: a `PDG_K_LONG` constituent gets `energy * sqrt(1 - beta^2)` with
   `beta` computed from the first calorimeter hit of its neutral-hadron cluster and the
   distance of that hit from `V`; a `PDG_PHOTON` constituent gets `0`.
2. Track-based (only while the mass is still negative): charged constituents whose stored
   mass matches the electron or muon mass within tolerance keep that mass; otherwise
   `p * sqrt(1 / beta^2 - 1)` is computed from the track length and the time of the
   track's last tracker hit relative to the vertex time.

# Arguments
- jets_constituents: Vector of jet constituents (one element per jet)
- track_L: track lengths, indexed like `trackdata`
- trackdata, trackerhits: tracks and the tracker hits they reference
- gammadata, nhdata: photon and neutral-hadron clusters; neutral-hadron indices are
  obtained by subtracting `length(gammadata)` from the constituent's cluster index
- calohits: calorimeter hits referenced by the clusters
- V: vertex; `x`, `y`, `z` are the reference position and `t` the reference time

# Returns
One vector per jet containing only the masses that came out non-negative, in constituent
order. Constituents without such a mass are dropped, so a vector can be shorter than its jet.
"""
function get_mtof(
    jets_constituents::Vector{<:JetConstituents},
    track_L::AbstractArray{T} where {T<:Float32},
    trackdata::StructVector{EDM4hep.Track},
    trackerhits::StructVector{EDM4hep.TrackerHit},
    gammadata::StructVector{EDM4hep.Cluster},
    nhdata::StructVector{EDM4hep.Cluster},
    calohits::StructVector{EDM4hep.CalorimeterHit},
    V::LorentzVector,
)
    n_jets = length(jets_constituents)
    result = Vector{Vector{Float32}}(undef, n_jets)

    # Pre-compute limits
    tracks_len = length(trackdata)
    gamma_len = length(gammadata)
    nh_len = length(nhdata)
    calohits_len = length(calohits)
    cluster_limit = nh_len + gamma_len

    # Pre-compute vertex values
    vx, vy, vz = V.x, V.y, V.z
    v_t_scaled = V.t * MM_TO_M * C_LIGHT_INV # Tin calculation

    @inbounds for i = 1:n_jets
        single_jet_constituents = jets_constituents[i]
        n_constituents = length(single_jet_constituents)

        # Pre-allocate for this jet; trimmed to the number of stored masses below
        tmp = Vector{Float32}(undef, n_constituents)
        result_idx = 0

        # Access fields once
        clusters_first = single_jet_constituents.clusters
        tracks_first = single_jet_constituents.tracks
        types = single_jet_constituents.type
        charges = single_jet_constituents.charge
        masses = single_jet_constituents.mass
        energies = single_jet_constituents.energy
        mom_x = single_jet_constituents.momentum.x
        mom_y = single_jet_constituents.momentum.y
        mom_z = single_jet_constituents.momentum.z

        for j = 1:n_constituents
            cluster_idx = clusters_first[j].first
            track_idx = tracks_first[j].first
            particle_type = types[j]

            mass_calculated = INVALID_MASS # Invalid marker

            # Handle cluster-based particles
            if cluster_idx < cluster_limit
                if particle_type == PDG_K_LONG # Neutral hadron
                    # Anything that prevents a physical beta counts as an invalid measurement
                    mass_calculated = INVALID_TOF_MASS

                    # cluster_idx < cluster_limit bounds nh_idx from above only; a cluster
                    # index inside the photon range would give nh_idx <= 0, which must not
                    # reach an @inbounds access.
                    nh_idx = cluster_idx + 1 - gamma_len
                    if nh_idx >= 1
                        hit_idx = nhdata[nh_idx].hits.first + 1
                        if 1 <= hit_idx <= calohits_len
                            hit = calohits[hit_idx]
                            tof = hit.time

                            # Distance from the vertex to the hit
                            dx = hit.position.x - vx
                            dy = hit.position.y - vy
                            dz = hit.position.z - vz
                            l = sqrt(dx^2 + dy^2 + dz^2) * MM_TO_M

                            beta = l / (tof * C_LIGHT)

                            if 0.0f0 < beta < 1.0f0
                                energy = energies[j]
                                mass_calculated = energy * sqrt(1.0f0 - beta^2)
                            end
                        end
                    end

                elseif particle_type == PDG_PHOTON # Photon
                    mass_calculated = 0.0f0
                end
            end

            # Handle track-based particles (only if not already calculated)
            if mass_calculated < 0.0f0 && track_idx < tracks_len
                charge = charges[j]
                if abs(charge) > 0.0f0
                    mass = masses[j]

                    # Check for known particles
                    if abs(mass - ELECTRON_MASS) < ELECTRON_TOLERANCE
                        mass_calculated = ELECTRON_MASS
                    elseif abs(mass - MUON_MASS) < MUON_TOLERANCE
                        mass_calculated = MUON_MASS
                    else
                        # Calculate from time of flight
                        track = trackdata[track_idx+1]
                        last_hit_idx = track.trackerHits.last
                        Tout = trackerhits[last_hit_idx].time
                        tof = Tout - v_t_scaled

                        l = track_L[track_idx+1] * MM_TO_M
                        beta = l / (tof * C_LIGHT)

                        if 0.0f0 < beta < 1.0f0
                            # Calculate momentum magnitude
                            p = sqrt(mom_x[j]^2 + mom_y[j]^2 + mom_z[j]^2)
                            mass_calculated = p * sqrt(1.0f0 / (beta^2) - 1.0f0)
                        else
                            mass_calculated = PION_MASS # Default
                        end
                    end
                end
            end

            # Store result if we calculated a mass
            if mass_calculated >= 0.0f0
                result_idx += 1
                tmp[result_idx] = mass_calculated
            end
        end

        # Resize to actual number of particles with calculated mass
        result[i] = resize!(tmp, result_idx)
    end

    return result
end
"""
    get_Sip2dVal_clusterV(jets::Vector{EEJet},
                          D0::Vector{JetConstituentsData},
                          phi0::Vector{JetConstituentsData},
                          Bz::Float32) -> Vector{JetConstituentsData}

Calculate the 2D signed impact parameter value for each particle relative to the jet axis.
This is a Julia implementation of the C++ function get_Sip2dVal_clusterV.

The magnitude is `abs(D0)`; the sign is the sign of the dot product between the jet's
transverse momentum `(px, py)` and the vector `(-D0 * sin(phi0), D0 * cos(phi0))`.

# Arguments
- jets: Vector of jets; only the `px` and `py` fields are read
- D0: Vector of vectors containing D0 values (transverse impact parameters)
- phi0: Vector of vectors containing phi0 values (azimuthal angles at impact point)
- Bz: The magnetic field in Tesla (not used in the calculation)

# Returns
Vector of vectors of 2D signed impact parameter values (one vector per jet). Entries whose
D0 equals `UNDEF_VAL` are returned as `UNDEF_VAL`.
"""
function get_Sip2dVal_clusterV(
    jets::Vector{T} where {T},
    D0::Vector{JetConstituentsData},
    phi0::Vector{JetConstituentsData},
    Bz::Float32,
)
    n_jets = length(jets)
    result = Vector{JetConstituentsData}(undef, n_jets)

    @inbounds for i = 1:n_jets
        px = Float32(jets[i].px)
        py = Float32(jets[i].py)
        d0_vals = D0[i]
        phi_vals = phi0[i]
        n_constituents = length(d0_vals)

        sip2d_values = Vector{Float32}(undef, n_constituents)

        for j = 1:n_constituents
            d0_val = d0_vals[j]
            sin_phi, cos_phi = sincos(phi_vals[j])

            # Impact-parameter vector in the transverse plane
            d0x = -d0_val * sin_phi
            d0y = d0_val * cos_phi

            signed_ip = sign(d0x * px + d0y * py) * abs(d0_val)

            # Select rather than blend arithmetically: `0 * NaN` is NaN, so a multiplicative
            # mask would leak NaN into entries that must be exactly UNDEF_VAL.
            sip2d_values[j] = ifelse(d0_val != UNDEF_VAL, signed_ip, UNDEF_VAL)
        end

        result[i] = sip2d_values
    end

    return result
end

"""
    get_btagSip2dVal(jets::Vector{EEJet},
                     pfcand_dxy::Vector{JetConstituentsData},
                     pfcand_phi0::Vector{JetConstituentsData},
                     Bz::Float32) -> Vector{JetConstituentsData}

Call the implementation function get_Sip2dVal_clusterV
"""
function get_btagSip2dVal(
    jets::Vector{T} where {T},
    pfcand_dxy::Vector{JetConstituentsData},
    pfcand_phi0::Vector{JetConstituentsData},
    Bz::Float32,
)
    # Simply call the implementation function
    return get_Sip2dVal_clusterV(jets, pfcand_dxy, pfcand_phi0, Bz)
end
"""
    get_Sip2dSig(Sip2dVals::Vector{JetConstituentsData},
                 err2_D0::Vector{JetConstituentsData}) -> Vector{JetConstituentsData}

Turn 2D signed impact parameter values into significances by dividing each value by
the square root of the matching squared D0 error.
This is a Julia implementation of the C++ function get_Sip2dSig.

# Arguments
- Sip2dVals: per-jet vectors of 2D signed impact parameter values
- err2_D0: per-jet vectors of squared D0 errors

# Returns
Per-jet vectors of significances. An entry is `UNDEF_VAL` wherever the squared error is
not strictly positive.
"""
function get_Sip2dSig(
    Sip2dVals::Vector{JetConstituentsData},
    err2_D0::Vector{JetConstituentsData},
)
    significances = Vector{JetConstituentsData}(undef, length(Sip2dVals))
    # Lower bound on the radicand so sqrt never sees a negative number
    min_err2 = eps(Float32)

    @inbounds for jet_idx in eachindex(Sip2dVals)
        ip_vals = Sip2dVals[jet_idx]
        err2_vals = err2_D0[jet_idx]
        out = Vector{Float32}(undef, length(ip_vals))

        @simd for k in eachindex(out)
            err2 = err2_vals[k]
            ratio = ip_vals[k] / sqrt(max(err2, min_err2))
            out[k] = ifelse(err2 > 0.0f0, ratio, UNDEF_VAL)
        end

        significances[jet_idx] = out
    end

    return significances
end

"""
    get_btagSip2dSig(pfcand_btagSip2dVal::Vector{JetConstituentsData},
                     pfcand_dxydxy::Vector{JetConstituentsData}) -> Vector{JetConstituentsData}

Thin alias that forwards to get_Sip2dSig.
"""
function get_btagSip2dSig(
    pfcand_btagSip2dVal::Vector{JetConstituentsData},
    pfcand_dxydxy::Vector{JetConstituentsData},
)
    return get_Sip2dSig(pfcand_btagSip2dVal, pfcand_dxydxy)
end
"""
    get_Sip3dVal_clusterV(jets::Vector{EEJet},
                          D0::Vector{JetConstituentsData},
                          Z0::Vector{JetConstituentsData},
                          phi0::Vector{JetConstituentsData},
                          Bz::Float32) -> Vector{JetConstituentsData}

Calculate the 3D signed impact parameter value for each particle relative to the jet axis.

The magnitude is `sqrt(D0^2 + Z0^2)`; the sign is the sign of the dot product between the
jet momentum `(px, py, pz)` and the vector `(-D0 * sin(phi0), D0 * cos(phi0), Z0)`.

# Arguments
- jets: Vector of jets; only the `px`, `py` and `pz` fields are read
- D0: Vector of vectors containing D0 values (transverse impact parameters)
- Z0: Vector of vectors containing Z0 values (longitudinal impact parameters)
- phi0: Vector of vectors containing phi0 values (azimuthal angles at impact point)
- Bz: The magnetic field in Tesla (not used in the calculation)

# Returns
Vector of vectors of 3D signed impact parameter values (one vector per jet). Entries whose
D0 equals `UNDEF_VAL` are returned as `UNDEF_VAL`.
"""
function get_Sip3dVal_clusterV(
    jets::Vector{T} where {T},
    D0::Vector{JetConstituentsData},
    Z0::Vector{JetConstituentsData},
    phi0::Vector{JetConstituentsData},
    Bz::Float32,
)
    n_jets = length(jets)
    result = Vector{JetConstituentsData}(undef, n_jets)

    @inbounds for i = 1:n_jets
        px = Float32(jets[i].px)
        py = Float32(jets[i].py)
        pz = Float32(jets[i].pz)
        d0_vals = D0[i]
        z0_vals = Z0[i]
        phi_vals = phi0[i]
        n_constituents = length(d0_vals)

        cprojs = Vector{Float32}(undef, n_constituents)

        for j = 1:n_constituents
            d0_val = d0_vals[j]
            z0_val = z0_vals[j]
            sin_phi, cos_phi = sincos(phi_vals[j])

            # Impact-parameter vector
            dx = -d0_val * sin_phi
            dy = d0_val * cos_phi
            dz = z0_val

            magnitude = sqrt(d0_val * d0_val + z0_val * z0_val)
            signed_ip = sign(dx * px + dy * py + dz * pz) * magnitude

            # Select rather than blend arithmetically: `0 * NaN` is NaN, so a multiplicative
            # mask would leak NaN into entries that must be exactly UNDEF_VAL.
            cprojs[j] = ifelse(d0_val != UNDEF_VAL, signed_ip, UNDEF_VAL)
        end

        result[i] = cprojs
    end

    return result
end

"""
    get_btagSip3dVal(jets::Vector{EEJet},
                     pfcand_dxy::Vector{JetConstituentsData},
                     pfcand_dz::Vector{JetConstituentsData},
                     pfcand_phi0::Vector{JetConstituentsData},
                     Bz::Float32) -> Vector{JetConstituentsData}

Call the implementation function get_Sip3dVal_clusterV
"""
function get_btagSip3dVal(
    jets::Vector{T} where {T},
    pfcand_dxy::Vector{JetConstituentsData},
    pfcand_dz::Vector{JetConstituentsData},
    pfcand_phi0::Vector{JetConstituentsData},
    Bz::Float32,
)
    # Simply call the implementation function
    return get_Sip3dVal_clusterV(jets, pfcand_dxy, pfcand_dz, pfcand_phi0, Bz)
end
"""
    get_Sip3dSig(Sip3dVals::Vector{JetConstituentsData},
                 err2_D0::Vector{JetConstituentsData},
                 err2_Z0::Vector{JetConstituentsData}) -> Vector{JetConstituentsData}

Turn 3D signed impact parameter values into significances by dividing each value by
`sqrt(err2_D0 + err2_Z0)`.

# Arguments
- Sip3dVals: per-jet vectors of 3D signed impact parameter values
- err2_D0: per-jet vectors of squared D0 errors
- err2_Z0: per-jet vectors of squared Z0 errors

# Returns
Per-jet vectors of significances. An entry is `UNDEF_VAL` wherever the squared D0 error is
not strictly positive.
"""
function get_Sip3dSig(
    Sip3dVals::Vector{JetConstituentsData},
    err2_D0::Vector{JetConstituentsData},
    err2_Z0::Vector{JetConstituentsData},
)
    significances = Vector{JetConstituentsData}(undef, length(Sip3dVals))
    # Lower bound on the radicand so sqrt never sees a negative number
    min_err2 = eps(Float32)

    @inbounds for jet_idx in eachindex(Sip3dVals)
        ip_vals = Sip3dVals[jet_idx]
        d0_err2 = err2_D0[jet_idx]
        z0_err2 = err2_Z0[jet_idx]
        out = Vector{Float32}(undef, length(ip_vals))

        @simd for k in eachindex(out)
            e_d0 = d0_err2[k]
            total_err2 = e_d0 + z0_err2[k]
            ratio = ip_vals[k] / sqrt(max(total_err2, min_err2))
            # Validity is decided by the D0 error alone
            out[k] = ifelse(e_d0 > 0.0f0, ratio, UNDEF_VAL)
        end

        significances[jet_idx] = out
    end

    return significances
end

"""
    get_btagSip3dSig(Sip3dVals::Vector{JetConstituentsData},
                     err2_D0::Vector{JetConstituentsData},
                     err2_Z0::Vector{JetConstituentsData}) -> Vector{JetConstituentsData}

Thin alias that forwards to get_Sip3dSig.
"""
function get_btagSip3dSig(
    Sip3dVals::Vector{JetConstituentsData},
    err2_D0::Vector{JetConstituentsData},
    err2_Z0::Vector{JetConstituentsData},
)
    return get_Sip3dSig(Sip3dVals, err2_D0, err2_Z0)
end
"""
    get_JetDistVal_clusterV(jets::Vector{EEJet},
                            jets_constituents::Vector{<:JetConstituents},
                            D0::Vector{JetConstituentsData},
                            Z0::Vector{JetConstituentsData},
                            phi0::Vector{JetConstituentsData},
                            Bz::Float32) -> Vector{JetConstituentsData}

Compute, for every constituent, the projection of its impact-parameter vector
`(-D0 * sin(phi0), D0 * cos(phi0), Z0)` onto the unit normal of the plane spanned by the
constituent momentum and the jet momentum.

# Arguments
- jets: Vector of jets; the `px`, `py` and `pz` fields are read
- jets_constituents: per-jet constituents; each constituent's `momentum` is read
- D0, Z0, phi0: per-jet vectors of track parameters
- Bz: magnetic field (not used in the calculation)

# Returns
Per-jet vectors of jet distance values. Entries whose D0 equals `UNDEF_VAL` are returned as
`UNDEF_VAL`.
"""
function get_JetDistVal_clusterV(
    jets::Vector{T} where {T},
    jets_constituents::Vector{<:JetConstituents},
    D0::Vector{JetConstituentsData},
    Z0::Vector{JetConstituentsData},
    phi0::Vector{JetConstituentsData},
    Bz::Float32,
)
    distances = Vector{JetConstituentsData}(undef, length(jets))

    for i in eachindex(jets)
        jet = jets[i]
        px_jet, py_jet, pz_jet = jet.px, jet.py, jet.pz
        constituents = jets_constituents[i]
        per_jet = Vector{Float32}(undef, length(constituents))

        for j in eachindex(per_jet)
            d0_val = D0[i][j]

            # Undefined track parameters propagate as the sentinel
            if d0_val == UNDEF_VAL
                per_jet[j] = UNDEF_VAL
                continue
            end

            z0_val = Z0[i][j]
            sin_phi, cos_phi = sincos(phi0[i][j])

            # Impact parameter vector
            dx = -d0_val * sin_phi
            dy = d0_val * cos_phi
            dz = z0_val

            # Constituent momentum
            mom = constituents[j].momentum
            px_ct, py_ct, pz_ct = mom.x, mom.y, mom.z

            # n = p_ct × p_jet
            nx = py_ct * pz_jet - pz_ct * py_jet
            ny = pz_ct * px_jet - px_ct * pz_jet
            nz = px_ct * py_jet - py_ct * px_jet

            # Scale to unit length; the floor avoids dividing by zero for parallel momenta
            scale = 1.0f0 / max(sqrt(nx^2 + ny^2 + nz^2), eps(Float32))
            ux = nx * scale
            uy = ny * scale
            uz = nz * scale

            # The jet is taken to pass through the origin, so the distance is just n·d
            per_jet[j] = ux * dx + uy * dy + uz * dz
        end

        distances[i] = per_jet
    end

    return distances
end

"""
    get_btagJetDistVal(jets, jets_constituents, D0, Z0, phi0, Bz) -> Vector{JetConstituentsData}

Thin alias that forwards to get_JetDistVal_clusterV.
"""
function get_btagJetDistVal(
    jets::Vector{T} where {T},
    jets_constituents::Vector{<:JetConstituents},
    D0::Vector{JetConstituentsData},
    Z0::Vector{JetConstituentsData},
    phi0::Vector{JetConstituentsData},
    Bz::Float32,
)
    return get_JetDistVal_clusterV(jets, jets_constituents, D0, Z0, phi0, Bz)
end
"""
    get_JetDistSig(JetDistVal::Vector{JetConstituentsData},
                   err2_D0::Vector{JetConstituentsData},
                   err2_Z0::Vector{JetConstituentsData}) -> Vector{JetConstituentsData}

Turn jet distance values into significances by dividing each value by
`sqrt(err2_D0 + err2_Z0)`.

# Arguments
- JetDistVal: per-jet vectors of jet distance values
- err2_D0: per-jet vectors of squared D0 errors
- err2_Z0: per-jet vectors of squared Z0 errors

# Returns
Per-jet vectors of significances. An entry is `UNDEF_VAL` wherever the squared D0 error is
not strictly positive.
"""
function get_JetDistSig(
    JetDistVal::Vector{JetConstituentsData},
    err2_D0::Vector{JetConstituentsData},
    err2_Z0::Vector{JetConstituentsData},
)
    significances = Vector{JetConstituentsData}(undef, length(JetDistVal))
    # Lower bound on the radicand so sqrt never sees a negative number
    min_err2 = eps(Float32)

    @inbounds for jet_idx in eachindex(JetDistVal)
        dist_vals = JetDistVal[jet_idx]
        d0_err2 = err2_D0[jet_idx]
        z0_err2 = err2_Z0[jet_idx]
        out = Vector{Float32}(undef, length(dist_vals))

        @simd for k in eachindex(out)
            e_d0 = d0_err2[k]
            total_err2 = e_d0 + z0_err2[k]
            ratio = dist_vals[k] / sqrt(max(total_err2, min_err2))
            # Validity is decided by the D0 error alone
            out[k] = ifelse(e_d0 > 0.0f0, ratio, UNDEF_VAL)
        end

        significances[jet_idx] = out
    end

    return significances
end

"""
    get_btagJetDistSig(JetDistVal::Vector{JetConstituentsData},
                       err2_D0::Vector{JetConstituentsData},
                       err2_Z0::Vector{JetConstituentsData}) -> Vector{JetConstituentsData}

Thin alias that forwards to get_JetDistSig.
"""
function get_btagJetDistSig(
    JetDistVal::Vector{JetConstituentsData},
    err2_D0::Vector{JetConstituentsData},
    err2_Z0::Vector{JetConstituentsData},
)
    return get_JetDistSig(JetDistVal, err2_D0, err2_Z0)
end
"""
    setup_onnx_runtime(onnx_path::AbstractString, json_path::AbstractString) -> (model, config)

Setup the ONNX model and preprocessing configuration for jet flavour tagging.

# Arguments
- `onnx_path`: Path to the ONNX model file
- `json_path`: Path to the JSON configuration file

# Returns
A tuple `(model, config)`: the object returned by `ONNXRunTime.load_inference(onnx_path)`
and the parsed contents of the JSON file.
"""
function setup_onnx_runtime(onnx_path::AbstractString, json_path::AbstractString)
    # Load JSON configuration
    config = JSON.parsefile(json_path)
    model = ONNXRunTime.load_inference(onnx_path)

    return model, config
end

"""
    normalize_feature(value::Float32, info::Dict) -> Float32

Normalize a feature value based on the preprocessing information.

The value is shifted by `info["median"]`, scaled by `info["norm_factor"]` and clamped to
`[info["lower_bound"], info["upper_bound"]]`. The missing-value marker `-9.0f0` bypasses
this and maps to `0.0f0`.

# Arguments
- `value`: Raw feature value
- `info`: Dictionary containing the keys `"median"`, `"norm_factor"`, `"lower_bound"` and
  `"upper_bound"`

# Returns
Normalized feature value, always as `Float32`
"""
function normalize_feature(value::Float32, info::Dict)
    if value == -9.0f0
        return 0.0f0 # Replace -9.0 (missing value) with 0
    end

    # Apply normalization using median and norm_factor
    normalized = (value - info["median"]) * info["norm_factor"]

    # Clamp to specified bounds. The parameters usually come from JSON and promote the
    # arithmetic to a wider type, so convert back to keep the return type stable.
    return Float32(clamp(normalized, info["lower_bound"], info["upper_bound"]))
end
"""
    prepare_input_tensor(jets_constituents::Vector{<:JetConstituents},
                         jets::Vector{EEJet},
                         config::Dict,
                         feature_data::Dict,
                         jet_index::Int=1) -> Dict{String, Array{Float32}}

Prepare input tensors for the neural network from jet constituents.

Tensors are allocated (zero-filled) for the entries `"pf_features"`, `"pf_vectors"` and
`"pf_mask"` of `config["input_names"]`, with shape `(1, n_vars, max_length)` where
`max_length = config["pf_points"]["var_length"]`. Constituents beyond `max_length` are
dropped; unused slots stay zero.

# Arguments
- `jets_constituents`: Vector of jet constituents
- `jets`: Vector of jets (EEJet)
- `config`: JSON configuration for preprocessing
- `feature_data`: Dictionary `category => (variable => per-jet value vectors)`
- `jet_index`: Index of the jet to process (default 1). It selects the entry of
  `jets_constituents` and of every per-jet vector in `feature_data`.

# Returns
Dictionary of input tensors. If `jet_index` is out of range the tensors are returned
zero-filled.
"""
function prepare_input_tensor(
    jets_constituents::Vector{<:JetConstituents},
    jets::Vector{EEJet},
    config::Dict,
    feature_data::Dict,
    jet_index::Int = 1,
)

    # Get input names and variable info
    input_names = config["input_names"]

    # Initialize input tensor dictionary
    input_tensors = Dict{String,Array{Float32}}()

    # Get max length for padding
    max_length = config["pf_points"]["var_length"]

    # Initialize tensors for single jet processing
    for input_name in input_names
        if input_name == "pf_features" || input_name == "pf_vectors"
            n_vars = length(config[input_name]["var_names"])
            input_tensors[input_name] = zeros(Float32, 1, n_vars, max_length)
        elseif input_name == "pf_mask"
            input_tensors[input_name] = zeros(Float32, 1, 1, max_length)
        end
    end

    # Note: This function prepares tensors for a single jet at a time
    # The caller can loop through jets and process them individually
    # A batch processing version is planned for next version under the inference() function.

    # Process the specified jet
    if 1 <= jet_index <= min(length(jets), length(jets_constituents))
        i = jet_index

        # Use the same jet for the constituent count as for the feature lookup below
        constituents = jets_constituents[i]
        num_constituents = min(length(constituents), max_length)

        # Fill mask (1 for valid constituents, 0 for padding)
        if haskey(feature_data, "pf_mask")
            for j = 1:num_constituents
                input_tensors["pf_mask"][1, 1, j] = 1.0f0
            end
        end

        # Fill points
        # NOTE(review): no "pf_points" tensor is allocated above, so this branch cannot
        # run as written — confirm whether the model expects a "pf_points" input.
        if haskey(feature_data, "pf_points") && haskey(input_tensors, "pf_points")
            for (var_idx, var_name) in enumerate(config["pf_points"]["var_names"])
                var_info = config["pf_points"]["var_infos"][var_name]

                for j = 1:num_constituents
                    if j <= length(feature_data["pf_points"][var_name][i])
                        raw_value = feature_data["pf_points"][var_name][i][j]
                        norm_value = normalize_feature(raw_value, var_info)
                        input_tensors["pf_points"][1, var_idx, j] = norm_value
                    end
                end
            end
        end

        # Fill features
        if haskey(feature_data, "pf_features") && haskey(input_tensors, "pf_features")
            for (var_idx, var_name) in enumerate(config["pf_features"]["var_names"])
                var_info = config["pf_features"]["var_infos"][var_name]

                for j = 1:num_constituents
                    if haskey(feature_data["pf_features"], var_name) &&
                       j <= length(feature_data["pf_features"][var_name][i])
                        raw_value = feature_data["pf_features"][var_name][i][j]
                        norm_value = normalize_feature(raw_value, var_info)
                        input_tensors["pf_features"][1, var_idx, j] = norm_value
                    end
                end
            end
        end

        # Fill vectors (energies, momenta); these are passed through without normalization
        if haskey(feature_data, "pf_vectors") && haskey(input_tensors, "pf_vectors")
            for (var_idx, var_name) in enumerate(config["pf_vectors"]["var_names"])
                for j = 1:num_constituents
                    if haskey(feature_data["pf_vectors"], var_name) &&
                       j <= length(feature_data["pf_vectors"][var_name][i])
                        input_tensors["pf_vectors"][1, var_idx, j] =
                            feature_data["pf_vectors"][var_name][i][j]
                    end
                end
            end
        end
    end

    return input_tensors
end

"""
    get_weights(slot::Int,
                vars::Dict{String,Dict{String,Vector{Vector{Float32}}}},
                jets::Vector{EEJet},
                jets_constituents::Vector{<:JetConstituents},
                json_config::Dict,
                model::ONNXRunTime.InferenceSession) -> Vector{Vector{Float32}}

Compute jet flavour probabilities for each jet.

Each jet is run through the model on its own: its features are sliced out of `vars`,
turned into tensors with `prepare_input_tensor`, and the first row of the model's
`"softmax"` output is stored.

# Arguments
- `slot`: Threading slot (not used)
- `vars`: Dictionary `category => (variable => per-jet value vectors)`
- `jets`: Vector of jets
- `jets_constituents`: Vector of jet constituents, indexed like `jets`
- `json_config`: JSON configuration for preprocessing
- `model`: ONNX inference session

# Returns
Vector of flavour probabilities for each jet
"""
function get_weights(
    slot::Int,
    vars::Dict{String,Dict{String,Vector{Vector{Float32}}}},
    jets::Vector{EEJet},
    jets_constituents::Vector{<:JetConstituents},
    json_config::Dict,
    model::ONNXRunTime.InferenceSession,
)

    # The model processes one jet at a time
    result = Vector{Vector{Float32}}()
    sizehint!(result, length(jets))

    # Process each jet individually
    for i in eachindex(jets)
        # Create single-jet arrays
        single_jet = [jets[i]]
        single_constituents = [jets_constituents[i]]

        # Create single-jet feature data by extracting only features for this jet
        single_jet_vars = Dict{String,Dict{String,Vector{Vector{Float32}}}}()
        for (category, features) in vars
            per_jet = Dict{String,Vector{Vector{Float32}}}()
            for (fname, fvalues) in features
                # A feature with no entry for this jet is represented by an empty vector
                per_jet[fname] = [i <= length(fvalues) ? fvalues[i] : Float32[]]
            end
            single_jet_vars[category] = per_jet
        end

        # Prepare input tensor for this single jet with extracted features
        input_tensors = prepare_input_tensor(
            single_constituents,
            single_jet,
            json_config,
            single_jet_vars,
            1,
        )

        # Run inference
        output = model(input_tensors)

        # Extract probabilities; row 1 because one jet is processed at a time
        probabilities = output["softmax"]
        num_classes = size(probabilities, 2)
        push!(result, Float32[probabilities[1, c] for c = 1:num_classes])
    end

    return result
end

"""
    get_weight(jet_weights::Vector{Vector{Float32}}, weight_idx::Int) -> Vector{Float32}

Extract a specific weight/score from the jet weights.

# Arguments
- `jet_weights`: Vector of weight vectors for each jet
- `weight_idx`: 0-based index of the weight to extract

# Returns
Vector of the specified weight for each jet

# Throws
An `ErrorException` if `weight_idx` is negative or not smaller than the number of weights
of some jet.
"""
function get_weight(jet_weights::Vector{Vector{Float32}}, weight_idx::Int)
    if weight_idx < 0
        error("Invalid index requested for jet flavour weight.")
    end

    result = Vector{Float32}()
    sizehint!(result, length(jet_weights))

    for jet_weight in jet_weights
        if weight_idx >= length(jet_weight)
            error("Flavour weight index exceeds the number of weights registered.")
        end

        push!(result, jet_weight[weight_idx+1]) # +1 for Julia's 1-based indexing
    end

    return result
end
"""
    inference(json_config_path::AbstractString, onnx_model_path::AbstractString,
              jets::Vector{EEJet}, jets_constituents::Vector{<:JetConstituents},
              feature_data::Dict) -> Dict{String,Vector{Float32}}

Run flavour tagging inference on a collection of jets.

# Arguments
- `json_config_path`: Path to the JSON configuration file
- `onnx_model_path`: Path to the ONNX model file
- `jets`: Vector of jets
- `jets_constituents`: Vector of jet constituents
- `feature_data`: Dictionary containing all extracted features

# Returns
Dictionary mapping each name in the configuration's `"output_names"` to a vector holding
that score for every jet.
"""
function inference(
    json_config_path::AbstractString,
    onnx_model_path::AbstractString,
    jets::Vector{EEJet},
    jets_constituents::Vector{<:JetConstituents},
    feature_data::Dict,
)

    # Setup model; this also parses the JSON configuration, so read the file only once
    model, config = setup_onnx_runtime(onnx_model_path, json_config_path)

    # Run inference
    weights = get_weights(0, feature_data, jets, jets_constituents, config, model)

    # Extract individual scores
    jet_scores = Dict{String,Vector{Float32}}()

    for (i, scorename) in enumerate(config["output_names"])
        jet_scores[scorename] = get_weight(weights, i - 1) # Adjust for 0-based indexing in get_weight
    end

    return jet_scores
end
"""
    extract_features(jets::Vector{EEJet}, jets_constituents::Vector{<:JetConstituents},
                     tracks::AbstractVector{EDM4hep.TrackState}, bz::Float32,
                     track_L::AbstractArray{T} where T <: AbstractFloat,
                     json_config::Dict,
                     trackdata::AbstractVector{EDM4hep.Track}=EDM4hep.Track[],
                     trackerhits::AbstractVector{EDM4hep.TrackerHit}=EDM4hep.TrackerHit[],
                     gammadata::AbstractVector{EDM4hep.Cluster}=EDM4hep.Cluster[],
                     nhdata::AbstractVector{EDM4hep.Cluster}=EDM4hep.Cluster[],
                     calohits::AbstractVector{EDM4hep.CalorimeterHit}=EDM4hep.CalorimeterHit[],
                     dNdx::AbstractVector{EDM4hep.Quantity}=EDM4hep.Quantity[],
                     mc_vertices::Union{Nothing,Vector{LorentzVector{Float32}}}=nothing) -> Dict

Extract features for jet flavour tagging based on JSON configuration.

For every name in `json_config["input_names"]` that has a `"var_names"` list, each variable
is looked up (after removing `"pfcand_"` from its name) with `get_feature_extractor`, its
dependencies from `get_feature_dependencies` are computed first, and the result is cached so
that a feature requested more than once is only computed once.

# Arguments
- `jets`: Vector of jets (EEJet)
- `jets_constituents`: Vector of jet constituents
- `tracks`: StructVector of track states
- `bz`: Magnetic field strength
- `track_L`: Array of track lengths
- `json_config`: JSON configuration dict specifying which features to extract (required)
- `trackdata`: Vector of track data (optional)
- `trackerhits`: Vector of tracker hits (optional)
- `gammadata`: Vector of gamma clusters (optional)
- `nhdata`: Vector of neutral hadron clusters (optional)
- `calohits`: Vector of calorimeter hits (optional)
- `dNdx`: Vector of dE/dx measurements (optional)
- `mc_vertices`: Vector of MC vertices for each reconstructed particle (optional). The first
  entry with a non-zero position is used as the primary vertex; without one the origin is used.

# Returns
Dictionary containing extracted features as specified in the JSON configuration. A variable
with no extractor, or whose extractor throws, is represented by one empty vector per jet.
"""
function extract_features(
    jets::Vector{EEJet},
    jets_constituents::Vector{<:JetConstituents},
    tracks::AbstractVector{EDM4hep.TrackState},
    bz::Float32,
    track_L::AbstractArray{T} where {T<:AbstractFloat},
    json_config::Dict,
    # Defaults are empty concrete vectors: an abstract type such as AbstractVector{...}
    # has no zero-argument constructor, so calling it would throw as soon as an optional
    # argument is omitted.
    trackdata::AbstractVector{EDM4hep.Track} = EDM4hep.Track[],
    trackerhits::AbstractVector{EDM4hep.TrackerHit} = EDM4hep.TrackerHit[],
    gammadata::AbstractVector{EDM4hep.Cluster} = EDM4hep.Cluster[],
    nhdata::AbstractVector{EDM4hep.Cluster} = EDM4hep.Cluster[],
    calohits::AbstractVector{EDM4hep.CalorimeterHit} = EDM4hep.CalorimeterHit[],
    dNdx::AbstractVector{EDM4hep.Quantity} = EDM4hep.Quantity[],
    mc_vertices::Union{Nothing,Vector{LorentzVector{Float32}}} = nothing,
)

    # Primary vertex for displacement calculations: the first MC vertex with a non-zero
    # position, or (0,0,0,0) when none is provided or found.
    v_in = LorentzVector(0.0, 0.0, 0.0, 0.0)
    if !isnothing(mc_vertices)
        for vertex in mc_vertices
            if vertex.x != 0.0 || vertex.y != 0.0 || vertex.z != 0.0
                v_in = vertex
                break
            end
        end
    end

    # Initialize feature containers
    features = Dict{String,Dict{String,Vector{Vector{Float32}}}}()

    # Cache for computed features to avoid redundant calculations
    computed_features = Dict{String,Vector{Vector{Float32}}}()

    # Create kwargs dict with all available data
    kwargs = Dict(
        :jets => jets,
        :tracks => tracks,
        :v_in => v_in,
        :bz => bz,
        :track_L => track_L,
        :trackdata => trackdata,
        :trackerhits => trackerhits,
        :gammadata => gammadata,
        :nhdata => nhdata,
        :calohits => calohits,
        :dNdx => dNdx,
    )

    # Process each input type specified in the JSON
    for input_name in json_config["input_names"]
        if !haskey(json_config, input_name) || !haskey(json_config[input_name], "var_names")
            continue
        end

        features[input_name] = Dict{String,Vector{Vector{Float32}}}()

        # Extract each requested variable
        for var_name in json_config[input_name]["var_names"]
            # Remove "pfcand_" prefix if present
            clean_name = replace(var_name, "pfcand_" => "")

            # Check if already computed
            if haskey(computed_features, clean_name)
                features[input_name][var_name] = computed_features[clean_name]
                continue
            end

            # Get the extractor function
            extractor = get_feature_extractor(clean_name)
            if extractor === nothing
                # Feature not implemented, create empty arrays
                features[input_name][var_name] = [Float32[] for _ = 1:length(jets)]
                continue
            end

            # Handle dependencies for certain features (resolved up to two levels deep)
            deps = get_feature_dependencies(clean_name)
            if !isempty(deps)
                dep_kwargs = Dict{Symbol,Any}()

                for (dep_key, dep_name) in deps
                    if !haskey(computed_features, dep_name)
                        dep_extractor = get_feature_extractor(dep_name)
                        if dep_extractor !== nothing
                            # Compute the dependency's own dependencies first
                            sub_dep_kwargs = Dict{Symbol,Any}()
                            for (sub_dep_key, sub_dep_name) in
                                get_feature_dependencies(dep_name)
                                if !haskey(computed_features, sub_dep_name)
                                    sub_dep_extractor = get_feature_extractor(sub_dep_name)
                                    if sub_dep_extractor !== nothing
                                        computed_features[sub_dep_name] =
                                            sub_dep_extractor(jets_constituents; kwargs...)
                                    end
                                end
                                if haskey(computed_features, sub_dep_name)
                                    sub_dep_kwargs[sub_dep_key] =
                                        computed_features[sub_dep_name]
                                end
                            end

                            # Sub-dependency results are visible to this call only
                            temp_kwargs = merge(kwargs, sub_dep_kwargs)
                            computed_features[dep_name] =
                                dep_extractor(jets_constituents; temp_kwargs...)
                        end
                    end
                    if haskey(computed_features, dep_name)
                        dep_kwargs[dep_key] = computed_features[dep_name]
                    end
                end

                # Merge dependency kwargs with main kwargs (kept for later features too)
                merge!(kwargs, dep_kwargs)
            end

            # Special handling for features that need isChargedHad
            if clean_name == "dndx" && !haskey(kwargs, :jets_constituents_isChargedHad)
                if !haskey(computed_features, "isChargedHad")
                    computed_features["isChargedHad"] =
                        JetConstituentUtils.get_is_charged_had(jets_constituents)
                end
                kwargs[:jets_constituents_isChargedHad] = computed_features["isChargedHad"]
            end

            # Extract the feature
            try
                feature_data = extractor(jets_constituents; kwargs...)
                computed_features[clean_name] = feature_data
                features[input_name][var_name] = feature_data
            catch e
                # If extraction fails, create empty arrays
                @warn "Failed to extract feature $var_name: $e"
                features[input_name][var_name] = [Float32[] for _ = 1:length(jets)]
            end
        end
    end

    return features
end
-> + JetConstituentUtils.get_charge(jets_constituents), + "type" => + (jets_constituents; kwargs...) -> + JetConstituentUtils.get_type(jets_constituents), + + # Relative kinematic features + "erel_log" => + (jets_constituents; jets = nothing, kwargs...) -> + JetConstituentUtils.get_erel_log_cluster(jets, jets_constituents), + "thetarel" => + (jets_constituents; jets = nothing, kwargs...) -> + JetConstituentUtils.get_thetarel_cluster(jets, jets_constituents), + "phirel" => + (jets_constituents; jets = nothing, kwargs...) -> + JetConstituentUtils.get_phirel_cluster(jets, jets_constituents), + + # Particle ID features + "isMu" => + (jets_constituents; kwargs...) -> + JetConstituentUtils.get_is_mu(jets_constituents), + "isEl" => + (jets_constituents; kwargs...) -> + JetConstituentUtils.get_is_el(jets_constituents), + "isChargedHad" => + (jets_constituents; kwargs...) -> + JetConstituentUtils.get_is_charged_had(jets_constituents), + "isGamma" => + (jets_constituents; kwargs...) -> + JetConstituentUtils.get_is_gamma(jets_constituents), + "isNeutralHad" => + (jets_constituents; kwargs...) -> + JetConstituentUtils.get_is_neutral_had(jets_constituents), + + # Track parameters + "dxy" => + ( + jets_constituents; + tracks = nothing, + v_in = nothing, + bz = nothing, + kwargs..., + ) -> JetConstituentUtils.get_dxy(jets_constituents, tracks, v_in, bz), + "dz" => + ( + jets_constituents; + tracks = nothing, + v_in = nothing, + bz = nothing, + kwargs..., + ) -> JetConstituentUtils.get_dz(jets_constituents, tracks, v_in, bz), + "phi0" => + ( + jets_constituents; + tracks = nothing, + v_in = nothing, + bz = nothing, + kwargs..., + ) -> JetConstituentUtils.get_phi0(jets_constituents, tracks, v_in, bz), + + # Covariance matrix elements + "dptdpt" => + (jets_constituents; tracks = nothing, kwargs...) -> + JetConstituentUtils.get_dptdpt(jets_constituents, tracks), + "detadeta" => + (jets_constituents; tracks = nothing, kwargs...) 
-> + JetConstituentUtils.get_detadeta(jets_constituents, tracks), + "dphidphi" => + (jets_constituents; tracks = nothing, kwargs...) -> + JetConstituentUtils.get_dphidphi(jets_constituents, tracks), + "dxydxy" => + (jets_constituents; tracks = nothing, kwargs...) -> + JetConstituentUtils.get_dxydxy(jets_constituents, tracks), + "dzdz" => + (jets_constituents; tracks = nothing, kwargs...) -> + JetConstituentUtils.get_dzdz(jets_constituents, tracks), + "dxydz" => + (jets_constituents; tracks = nothing, kwargs...) -> + JetConstituentUtils.get_dxydz(jets_constituents, tracks), + "dphidxy" => + (jets_constituents; tracks = nothing, kwargs...) -> + JetConstituentUtils.get_dphidxy(jets_constituents, tracks), + "dlambdadz" => + (jets_constituents; tracks = nothing, kwargs...) -> + JetConstituentUtils.get_dlambdadz(jets_constituents, tracks), + "dxyc" => + (jets_constituents; tracks = nothing, kwargs...) -> + JetConstituentUtils.get_dxyc(jets_constituents, tracks), + "dxyctgtheta" => + (jets_constituents; tracks = nothing, kwargs...) -> + JetConstituentUtils.get_dxyctgtheta(jets_constituents, tracks), + "phic" => + (jets_constituents; tracks = nothing, kwargs...) -> + JetConstituentUtils.get_phic(jets_constituents, tracks), + "phidz" => + (jets_constituents; tracks = nothing, kwargs...) -> + JetConstituentUtils.get_phidz(jets_constituents, tracks), + "phictgtheta" => + (jets_constituents; tracks = nothing, kwargs...) -> + JetConstituentUtils.get_phictgtheta(jets_constituents, tracks), + "cdz" => + (jets_constituents; tracks = nothing, kwargs...) -> + JetConstituentUtils.get_cdz(jets_constituents, tracks), + "cctgtheta" => + (jets_constituents; tracks = nothing, kwargs...) 
-> + JetConstituentUtils.get_cctgtheta(jets_constituents, tracks), + + # B-tagging features + "btagSip2dVal" => + ( + jets_constituents; + jets = nothing, + dxy = nothing, + phi0 = nothing, + bz = nothing, + kwargs..., + ) -> JetConstituentUtils.get_btagSip2dVal(jets, dxy, phi0, bz), + "btagSip2dSig" => + (jets_constituents; btagSip2dVal = nothing, dxydxy = nothing, kwargs...) -> + JetConstituentUtils.get_btagSip2dSig(btagSip2dVal, dxydxy), + "btagSip3dVal" => + ( + jets_constituents; + jets = nothing, + dxy = nothing, + dz = nothing, + phi0 = nothing, + bz = nothing, + kwargs..., + ) -> JetConstituentUtils.get_btagSip3dVal(jets, dxy, dz, phi0, bz), + "btagSip3dSig" => + ( + jets_constituents; + btagSip3dVal = nothing, + dxydxy = nothing, + dzdz = nothing, + kwargs..., + ) -> JetConstituentUtils.get_btagSip3dSig(btagSip3dVal, dxydxy, dzdz), + "btagJetDistVal" => + ( + jets_constituents; + jets = nothing, + dxy = nothing, + dz = nothing, + phi0 = nothing, + bz = nothing, + kwargs..., + ) -> JetConstituentUtils.get_btagJetDistVal( + jets, + jets_constituents, + dxy, + dz, + phi0, + bz, + ), + "btagJetDistSig" => + ( + jets_constituents; + btagJetDistVal = nothing, + dxydxy = nothing, + dzdz = nothing, + kwargs..., + ) -> JetConstituentUtils.get_btagJetDistSig(btagJetDistVal, dxydxy, dzdz), + + # Special features + "mtof" => + ( + jets_constituents; + track_L = nothing, + trackdata = nothing, + trackerhits = nothing, + gammadata = nothing, + nhdata = nothing, + calohits = nothing, + v_in = nothing, + kwargs..., + ) -> JetConstituentUtils.get_mtof( + jets_constituents, + track_L, + trackdata, + trackerhits, + gammadata, + nhdata, + calohits, + v_in, + ), + "dndx" => + ( + jets_constituents; + dNdx = nothing, + trackdata = nothing, + jets_constituents_isChargedHad = nothing, + kwargs..., + ) -> JetConstituentUtils.get_dndx( + jets_constituents, + dNdx, + trackdata, + jets_constituents_isChargedHad, + ), + + # Mask + "mask" => + (jets_constituents; kwargs...) 
-> [ + fill(1.0f0, length(constituents)) for constituents in jets_constituents + ], + ) + + return get(feature_map, feature_name, nothing) +end + +""" + get_feature_dependencies(feature_name::String) -> Dict{Symbol, String} + +Get the dependencies for features that require other computed features. + +# Arguments +- `feature_name`: Name of the feature + +# Returns +Dictionary mapping parameter names to feature names +""" +function get_feature_dependencies(feature_name::String) + deps = Dict{Symbol,String}() + + if feature_name == "btagSip2dSig" + deps[:btagSip2dVal] = "btagSip2dVal" + deps[:dxydxy] = "dxydxy" + elseif feature_name == "btagSip3dSig" + deps[:btagSip3dVal] = "btagSip3dVal" + deps[:dxydxy] = "dxydxy" + deps[:dzdz] = "dzdz" + elseif feature_name == "btagJetDistSig" + deps[:btagJetDistVal] = "btagJetDistVal" + deps[:dxydxy] = "dxydxy" + deps[:dzdz] = "dzdz" + elseif feature_name == "btagSip2dVal" + deps[:dxy] = "dxy" + deps[:phi0] = "phi0" + elseif feature_name == "btagSip3dVal" + deps[:dxy] = "dxy" + deps[:dz] = "dz" + deps[:phi0] = "phi0" + elseif feature_name == "btagJetDistVal" + deps[:dxy] = "dxy" + deps[:dz] = "dz" + deps[:phi0] = "phi0" + end + + return deps +end diff --git a/src/JetPhysicalConstants.jl b/src/JetPhysicalConstants.jl new file mode 100644 index 0000000..cfa726a --- /dev/null +++ b/src/JetPhysicalConstants.jl @@ -0,0 +1,56 @@ +""" +Physical constants and special values used in the JetFlavourTagging extension. + +This module contains all physical constants, particle properties, tolerance values, +and special markers used throughout the jet flavour tagging algorithms. 
+""" +module JetPhysicalConstants + +using PhysicalConstants.CODATA2018: c_0, m_e, ħ, e, k_B, N_A, R + +# Note: PhysicalConstants.jl provides values with units and uncertainties +# To extract just the numerical value, use .val +# Example: c_0.val gives the speed of light value without units + +# Physical Constants from PhysicalConstants.jl +# Extract the numerical value and convert to Float32 +const C_LIGHT = Float32(c_0.val) # Speed of light in m/s from CODATA2018 +const C_LIGHT_INV = 1.0f0 / C_LIGHT # Inverse speed of light + +# Particle Masses (in GeV/c²) +# Electron mass from CODATA2018 (convert from kg to GeV/c²) +# 1 GeV/c² = 1.78266192e-27 kg +const ELECTRON_MASS = Float32(m_e.val / 1.78266192e-27) # Electron mass in GeV/c² +# Note: Muon and pion masses are not in CODATA2018, using PDG values +const MUON_MASS = 0.105658f0 # Muon mass in GeV/c² (PDG value) +const PION_MASS = 0.13957039f0 # Charged pion mass (π±) in GeV/c² (PDG value) + +# Mass Comparison Tolerances +const ELECTRON_TOLERANCE = 1.0f-5 # Tolerance for electron mass comparison +const MUON_TOLERANCE = 1.0f-3 # Tolerance for muon mass comparison +const PION_TOLERANCE = 1.0f-3 # Tolerance for pion mass comparison + +# PDG Particle ID Codes +const PDG_PHOTON = 22 # Photon (γ) +const PDG_K_LONG = 130 # K⁰_L (K-long neutral hadron) + +# Special/Undefined Values +const UNDEF_VAL = -9.0f0 # Sentinel value for missing/invalid data +const INVALID_TOF_MASS = 9.0f0 # Invalid mass from time-of-flight +const INVALID_MASS = -1.0f0 # Invalid mass calculation marker + +# Unit Conversion Factors +const MM_TO_M = 0.001f0 # Millimeter to meter conversion +const NS_TO_S = 1.0f-9 # Nanosecond to second conversion +const PS_TO_S = 1.0f-12 # Picosecond to second conversion +const FS_TO_S = 1.0f-15 # Femtosecond to second conversion + +# Export all constants +export C_LIGHT, C_LIGHT_INV +export ELECTRON_MASS, MUON_MASS, PION_MASS +export ELECTRON_TOLERANCE, MUON_TOLERANCE, PION_TOLERANCE +export PDG_PHOTON, PDG_K_LONG 
+export UNDEF_VAL, INVALID_TOF_MASS, INVALID_MASS +export MM_TO_M, NS_TO_S, PS_TO_S, FS_TO_S + +end # module JetPhysicalConstants diff --git a/src/JetTaggingFCC.jl b/src/JetTaggingFCC.jl index ada99ef..1dd7aae 100644 --- a/src/JetTaggingFCC.jl +++ b/src/JetTaggingFCC.jl @@ -1,5 +1,36 @@ module JetTaggingFCC -# Write your package code here. +using JetReconstruction +using JetReconstruction: EEJet +using EDM4hep +using JSON +using ONNXRunTime +using StructArrays: StructVector +using LorentzVectorHEP -end +const JetConstituents = StructVector{ReconstructedParticle,<:Any} +const JetConstituentsData = Vector{Float32} + +# Include constituent utilities (as a module for now) +include("JetConstituentUtils.jl") +using .JetConstituentUtils + +# Include jet constituent builder functions +include("JetConstituentBuilder.jl") + +# Include flavour helper functions +include("JetFlavourHelper.jl") + +# Export all public functions +export build_constituents_cluster +export extract_features +export setup_onnx_runtime +export prepare_input_tensor +export get_weights +export get_weight + + +# Export types +export JetConstituents, JetConstituentsData + +end # module JetTaggingFCC