2121logger = logging .getLogger (__name__ )
2222
2323
24- def weights_gaussian (
25- distances : np .ndarray ,
26- radius : float ,
27- exponent : float = 2.0 ,
28- normalize : bool = True ,
29- ) -> np .ndarray :
30- weights = np .exp (- ((distances / radius ) ** exponent ))
31- if normalize :
32- weights /= weights .max ()
33- return weights
34-
35-
3624def weights_exponential (
3725 distances : np .ndarray ,
3826 radius : float ,
@@ -46,10 +34,9 @@ def weights_exponential(
4634 return weights
4735
4836
49- def weights_gaussian_min_stations (
37+ def weights_gaussian (
5038 distances : np .ndarray ,
51- radius : float ,
52- exponent : float = 2.0 ,
39+ distance_taper : float ,
5340 required_stations : int = 4 ,
5441 waterlevel : float = 0.0 ,
5542) -> np .ndarray :
@@ -60,40 +47,35 @@ def weights_gaussian_min_stations(
6047 sorted_distances = np .sort (distances , axis = 1 )
6148 threshold_distance = sorted_distances [:, required_stations - 1 , np .newaxis ]
6249
63- weights = np .exp (- (((distances - threshold_distance ) / radius ) ** exponent ))
50+ weights = np .exp (
51+ - (((distances - threshold_distance ) ** 2 ) / (2 * (distance_taper / 2 ) ** 2 ))
52+ )
6453 weights [distances <= threshold_distance ] = 1.0
6554 if waterlevel > 0.0 :
6655 weights = (1 - waterlevel ) * weights + waterlevel
6756 return weights
6857
6958
7059class DistanceWeights (BaseModel ):
71- radius_meters : PositiveFloat | Literal ["mean_interstation" ] = Field (
60+ distance_taper : PositiveFloat | Literal ["mean_interstation" ] = Field (
7261 default = "mean_interstation" ,
73- description = "Cutoff distance for the spatial decay function in meters. "
74- " 'mean_interstation' uses the mean interstation distance for the radius. "
75- " Default is 'mean_interstation'." ,
62+ description = "Taper distance for Gaussian the distance weighting function"
63+ " in meters. 'mean_interstation' uses twice the mean interstation distance for"
64+ " the radius. Default is 'mean_interstation'." ,
7665 )
77- min_required_stations : PositiveInt = Field (
66+ required_closest_stations : PositiveInt = Field (
7867 default = 4 ,
79- description = "Minimum number of stations to assign full weight in the"
80- " exponential decay function. Default is 4." ,
81- )
82- exponent : float = Field (
83- default = 2.0 ,
84- description = "Exponent of the spatial decay function. For 'gaussian' decay an"
85- " exponent of 0.5 is recommended. Default is 2." ,
86- ge = 0.0 ,
68+ description = "Number of stations to assign full weight in the"
69+ " spatial weighting function, only more distant stations are tapered with a"
70+ " Gaussian decay. This ensures that the closest _N_ stations have an equal and"
71+ " the highest contribution to the detection and localization. Default is 4." ,
8772 )
8873 waterlevel : float = Field (
8974 default = 0.0 ,
9075 ge = 0.0 ,
9176 le = 1.0 ,
92- description = "Waterlevel for the exponential decay function. Default is 0.0." ,
93- )
94- normalize : bool = Field (
95- default = True ,
96- description = "Normalize the weights to the range [0, 1]. Default is True." ,
77+ description = "Stations outside the taper are lifted by this fraction. "
78+ "Default is 0.0." ,
9779 )
9880
9981 _node_lut : ArrayLRUCache [bytes ] = PrivateAttr ()
@@ -110,14 +92,17 @@ def get_distances(self, nodes: Sequence[Node]) -> np.ndarray:
11092 )
11193
11294 def prepare (self , stations : StationInventory , octree : Octree ) -> None :
113- logger .info ("preparing distance weights" )
114-
115- if self .radius_meters == "mean_interstation" :
116- self .radius_meters = stations .mean_interstation_distance ()
95+ if self .distance_taper == "mean_interstation" :
96+ self .distance_taper = 2 * stations .mean_interstation_distance ()
11797 logger .info (
118- "using mean interstation distance as radius : %g m" ,
119- self .radius_meters ,
98+ "using 2x mean interstation distance as distance taper : %g m" ,
99+ self .distance_taper ,
120100 )
101+ logger .info (
102+ "distance weighting uses %d closest stations and a taper of %g m" ,
103+ self .required_closest_stations ,
104+ self .distance_taper ,
105+ )
121106
122107 self ._stations = StationList .from_inventory (stations )
123108 self ._node_lut = ArrayLRUCache (name = "distance_weights" , short_name = "DW" )
@@ -144,11 +129,10 @@ def get_node_weights(self, node: Node, stations: list[Station]) -> np.ndarray:
144129 self .fill_lut ([node ])
145130 return self .get_node_weights (node , stations )
146131
147- return weights_gaussian_min_stations (
132+ return weights_gaussian (
148133 distances ,
149- required_stations = self .min_required_stations ,
150- radius = self .radius_meters ,
151- exponent = self .exponent ,
134+ required_stations = self .required_closest_stations ,
135+ distance_taper = self .distance_taper ,
152136 waterlevel = self .waterlevel ,
153137 )
154138
@@ -163,11 +147,10 @@ async def get_weights(
163147
164148 try :
165149 distances = [node_lut [node .hash ][station_indices ] for node in nodes ]
166- return weights_gaussian_min_stations (
150+ return weights_gaussian (
167151 np .array (distances ),
168- required_stations = self .min_required_stations ,
169- radius = self .radius_meters ,
170- exponent = self .exponent ,
152+ required_stations = self .required_closest_stations ,
153+ distance_taper = self .distance_taper ,
171154 waterlevel = self .waterlevel ,
172155 )
173156 except KeyError :
0 commit comments