88from __future__ import annotations
99
1010from dataclasses import dataclass
11+ import importlib
1112from typing import ClassVar , TYPE_CHECKING
1213
1314import numpy as np
14- from numpy .polynomial .polynomial import Polynomial
1515
1616from .model_base import InstrumentModel , ModelData
1717from .mixins import SimpleBroaden1DMixin
18- from ..instrument import INSTRUMENT_DATA_PATH
1918
2019if TYPE_CHECKING :
2120 from jaxtyping import Float
@@ -67,16 +66,14 @@ class ScaledTabulatedModel(SimpleBroaden1DMixin, InstrumentModel):
6766 """
6867 input = ('energy_transfer' ,)
6968
70- data_class : ClassVar [type [ScaledTabultedModelData ]] = ScaledTabulatedModelData
69+ data_class : ClassVar [type [ScaledTabulatedModelData ]] = ScaledTabulatedModelData
7170
7271 def __init__ (self , model_data : ScaledTabulatedModelData , ** _ ):
73- from pathlib import Path
74-
7572 from numpy .polynomial import Polynomial
7673 from scipy .interpolate import RegularGridInterpolator
7774
7875 super ().__init__ (model_data )
79- self .data = np .load (Path ( INSTRUMENT_DATA_PATH ) / model_data .npz )
76+ self .data = np .load (importlib . resources . files ( "resins.instrument_data" ) / model_data .npz )
8077
8178 self .polynomial = Polynomial (coef = self .data ["coef" ],
8279 domain = self .data ["domain" ],
@@ -110,7 +107,7 @@ def get_characteristics(self, omega_q: Float[np.ndarray, 'energy_transfer dimens
110107 return {'sigma' : self .polynomial (omega_q [:, 0 ])}
111108
112109 def get_kernel (self ,
113- omega_q : Float [np .ndarray , 'sample dimension=1' ],
110+ points : Float [np .ndarray , 'sample dimension=1' ],
114111 mesh : Float [np .ndarray , 'mesh' ],
115112 ) -> Float [np .ndarray , 'sample mesh' ]:
116113
@@ -130,14 +127,14 @@ def get_kernel(self,
130127 return interp_kernels
131128
132129 def get_peak (self ,
133- omega_q : Float [np .ndarray , 'sample dimension=1' ],
130+ points : Float [np .ndarray , 'sample dimension=1' ],
134131 mesh : Float [np .ndarray , 'mesh' ],
135132 ) -> Float [np .ndarray , 'sample mesh' ]:
136133 shifted_meshes = [mesh - energy for energy in omega_q [:, 0 ]]
137134
138135 shifted_kernels = [
139- self .get_kernel (np .array ([omega_q ]), shifted_mesh )
140- for omega_q , shifted_mesh in zip (omega_q , shifted_meshes )
136+ self .get_kernel (np .array ([point ]), shifted_mesh )
137+ for point , shifted_mesh in zip (points , shifted_meshes )
141138 ]
142139
143140 return np .array (np .vstack (shifted_kernels ))
0 commit comments