Skip to content

Commit 1e96449

Browse files
committed
Make PDFs more explicit
One of the criticistms of simweights is that it is hard to find the weighting functions. This was because the weighting was a list of objects and the weights were calculating by taking the product of them. This PR refactors it so that each type of simulation has a class that is derived from GenerationSurface that overrides get_epdf(). This is a clear function that is 1/weight. However since it makes use of Spatial and PowerLaw it is not as explicit as it could be but DRY. As a side effecth Spatial disturcitions and PowerLaw no longer have column names as a member variable which i thought was just weird.
1 parent 2282385 commit 1e96449

21 files changed

+510
-526
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ repos:
2828
- id: blacken-docs
2929
args: [-l 100]
3030
- repo: https://github.com/pre-commit/mirrors-mypy
31-
rev: v1.16.1
31+
rev: v1.17.1
3232
hooks:
3333
- id: mypy
3434
files: simweights
@@ -42,7 +42,7 @@ repos:
4242
exclude: ^contrib/
4343
additional_dependencies: [numpy, pandas]
4444
- repo: https://github.com/astral-sh/ruff-pre-commit
45-
rev: v0.12.2
45+
rev: v0.12.7
4646
hooks:
4747
- id: ruff
4848
args: [--fix, --show-fixes]

pyproject.toml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ max-line-length = 128
7070

7171
[tool.mypy]
7272
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
73-
plugins = "numpy.typing.mypy_plugin"
7473
strict = true
7574
warn_unreachable = true
7675

@@ -84,7 +83,7 @@ disable = "C0114,R0902,R0913,R0917,R0914,R0911"
8483
addopts = ["-ra", "--strict-config", "--strict-markers", "--cov=simweights", "-W ignore"]
8584
filterwarnings = ["error"]
8685
log_cli_level = "INFO"
87-
minversion = 7.0
86+
minversion = "7.0"
8887
testpaths = ["tests"]
8988
xfail_strict = true
9089

@@ -93,7 +92,7 @@ line-length = 128
9392
namespace-packages = ["examples", "contrib", "docs"]
9493

9594
[tool.ruff.lint]
96-
fixable = ["I"]
95+
fixable = ["I", "Q"]
9796
ignore = [
9897
"ANN401", # any-type
9998
"S101", # assert-used

src/simweights/__init__.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
"GaisserH4a_IT",
2525
"GaisserHillas",
2626
"GenerationSurface",
27+
"GenieSurface",
2728
"GenieWeighter",
2829
"GlobalFitGST",
2930
"GlobalFitGST_IT",
@@ -34,15 +35,16 @@
3435
"Hoerandel5",
3536
"Hoerandel_IT",
3637
"Honda2004",
38+
"IceTopSurface",
3739
"IceTopWeighter",
3840
"NaturalRateCylinder",
41+
"NuGenSurface",
3942
"NuGenWeighter",
4043
"PDGCode",
4144
"PowerLaw",
4245
"UniformSolidAngleCylinder",
4346
"Weighter",
4447
"corsika_to_pdg",
45-
"generation_surface",
4648
]
4749

4850
from ._corsika_weighter import CorsikaWeighter
@@ -63,10 +65,10 @@
6365
Hoerandel_IT,
6466
Honda2004,
6567
)
66-
from ._generation_surface import GenerationSurface, generation_surface
67-
from ._genie_weighter import GenieWeighter
68-
from ._icetop_weighter import IceTopWeighter
69-
from ._nugen_weighter import NuGenWeighter
68+
from ._generation_surface import GenerationSurface
69+
from ._genie_weighter import GenieSurface, GenieWeighter
70+
from ._icetop_weighter import IceTopSurface, IceTopWeighter
71+
from ._nugen_weighter import NuGenSurface, NuGenWeighter
7072
from ._pdgcode import PDGCode
7173
from ._powerlaw import PowerLaw
7274
from ._spatial import CircleInjector, NaturalRateCylinder, UniformSolidAngleCylinder

src/simweights/_corsika_weighter.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,40 @@
66

77
import numbers
88
import warnings
9-
from typing import Any
9+
from typing import TYPE_CHECKING, Any, Mapping
1010

1111
import numpy as np
1212

13-
from ._generation_surface import GenerationSurface, generation_surface
13+
from ._generation_surface import CompositeSurface, GenerationSurface
1414
from ._powerlaw import PowerLaw
1515
from ._spatial import NaturalRateCylinder
16-
from ._utils import Column, constcol, get_column, get_table, has_column, has_table
16+
from ._utils import constcol, get_column, get_table, has_column, has_table
1717
from ._weighter import Weighter
1818

19+
if TYPE_CHECKING:
20+
from numpy.typing import NDArray
1921

20-
def sframe_corsika_surface(table: Any) -> GenerationSurface:
22+
23+
class CorsikaSurface(GenerationSurface):
24+
"""Represents a surface on which CORSIKA simulation was generated on."""
25+
26+
def get_epdf(self: CorsikaSurface, weight_cols: Mapping[str, NDArray[np.float64]]) -> NDArray[np.float64]:
27+
"""Get the extended pdf of a sample of CORSIKA."""
28+
return (
29+
self.nevents
30+
/ weight_cols["event_weight"]
31+
* self.power_law.pdf(weight_cols["energy"])
32+
* self.spatial.pdf(weight_cols["cos_zen"])
33+
)
34+
35+
36+
def sframe_corsika_surface(table: Any) -> CompositeSurface:
2137
"""Inspect the rows of a CORSIKA S-Frame table object to generate a surface object.
2238
2339
This function works on files generated with either triggered CORSIKA or corsika-reader because
2440
`I3PrimaryInjectorInfo` and `I3CorsikaInfo` use exactly the same names for quantities.
2541
"""
26-
surfaces = []
42+
surfaces = CompositeSurface()
2743
cylinder_height = get_column(table, "cylinder_height")
2844
cylinder_radius = get_column(table, "cylinder_radius")
2945
max_zenith = get_column(table, "max_zenith")
@@ -39,32 +55,26 @@ def sframe_corsika_surface(table: Any) -> GenerationSurface:
3955
cylinder_radius[i],
4056
np.cos(max_zenith[i]),
4157
np.cos(min_zenith[i]),
42-
"cos_zen",
4358
)
4459
spectrum = PowerLaw(
4560
power_law_index[i],
4661
min_energy[i],
4762
max_energy[i],
48-
"energy",
4963
)
5064
oversampling_val = get_column(table, "oversampling")[i] if has_column(table, "oversampling") else 1
5165
pdgid = int(get_column(table, "primary_type")[i])
52-
surfaces.append(
53-
n_events[i] * oversampling_val * generation_surface(pdgid, Column("event_weight"), spectrum, spatial),
54-
)
55-
retval = sum(surfaces)
56-
assert isinstance(retval, GenerationSurface)
57-
return retval
66+
surfaces.insert(CorsikaSurface(pdgid, n_events[i] * oversampling_val, spectrum, spatial))
67+
return surfaces
5868

5969

60-
def weight_map_corsika_surface(table: Any) -> GenerationSurface:
70+
def weight_map_corsika_surface(table: Any) -> CompositeSurface:
6171
"""Inspect the `CorsikaWeightMap` table object of a corsika file to generate a surface object."""
6272
pdgids = sorted(np.unique(get_column(table, "ParticleType").astype(int)))
6373

6474
if len(pdgids) == 0:
6575
msg = "`CorsikaWeightMap` is empty. SimWeights cannot process this file"
6676
raise RuntimeError(msg)
67-
surface: int | GenerationSurface = 0
77+
surface = CompositeSurface()
6878
for pdgid in pdgids:
6979
mask = pdgid == get_column(table, "ParticleType")
7080

@@ -73,7 +83,6 @@ def weight_map_corsika_surface(table: Any) -> GenerationSurface:
7383
constcol(table, "CylinderRadius", mask),
7484
np.cos(constcol(table, "ThetaMax", mask)),
7585
np.cos(constcol(table, "ThetaMin", mask)),
76-
"cos_zen",
7786
)
7887

7988
primary_spectral_index = round(constcol(table, "PrimarySpectralIndex", mask), 6)
@@ -83,11 +92,9 @@ def weight_map_corsika_surface(table: Any) -> GenerationSurface:
8392
primary_spectral_index,
8493
constcol(table, "EnergyPrimaryMin", mask),
8594
constcol(table, "EnergyPrimaryMax", mask),
86-
"energy",
8795
)
8896
nevents = constcol(table, "OverSampling", mask) * constcol(table, "NEvents", mask)
89-
surface += nevents * generation_surface(pdgid, spectrum, spatial)
90-
assert isinstance(surface, GenerationSurface)
97+
surface.insert(CorsikaSurface(pdgid, nevents, spectrum, spatial))
9198
return surface
9299

93100

@@ -145,7 +152,8 @@ def CorsikaWeighter(file_obj: Any, nfiles: float | None = None) -> Weighter: #
145152
)
146153

147154
table = get_table(file_obj, "CorsikaWeightMap")
148-
surface = nfiles * weight_map_corsika_surface(table)
155+
surface = weight_map_corsika_surface(table)
156+
surface.scale(nfiles)
149157
triggered = False
150158

151159
weighter = Weighter([file_obj], surface)

0 commit comments

Comments
 (0)