Skip to content

Commit ab8fa2a

Browse files
committed
try to get around np.abs not being implemented in numba yet
1 parent 039c4fb commit ab8fa2a

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

src/py21cmwedge/uvgridder.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def norm(x, y, out, scale_factor):
3131
out /= out_norm
3232

3333

34-
# 2-d adaptation of https://stackoverflow.com/a/70614173
3534
@njit(numba.types.UniTuple(numba.float64[:, :], 2)(numba.float64[:], numba.float64[:]))
3635
def meshgrid(x, y):
3736
xx = np.empty(shape=(x.size, y.size), dtype=x.dtype)
@@ -43,6 +42,16 @@ def meshgrid(x, y):
4342
return xx, yy
4443

4544

45+
@njit(numba.int32(numba.float64, numba.float64[:]))
46+
def get_nearest(u, u_bins):
47+
dists = u - u_bins
48+
for cnt in range(dists.size):
49+
if dists[cnt] < 0:
50+
dists[cnt] = abs(dists[cnt])
51+
52+
return dists.argmin()
53+
54+
4655
class UVGridder(object):
4756
"""Base uvgridder object."""
4857

@@ -392,8 +401,8 @@ def _weights_nearest(
392401
on self.wavelength_scale slope
393402
"""
394403
for freq_cnt, (_u, _v) in enumerate(zip(u, v)):
395-
u_index = np.abs(_u - x[:]).argmin()
396-
v_index = np.abs(_v - x[:]).argmin()
404+
u_index = get_nearest(_u, x)
405+
v_index = get_nearest(_v, x)
397406

398407
# v,u indexing because y is the outer dimension in memory
399408
uvf_cube[freq_cnt, v_index, u_index] += nbls

0 commit comments

Comments
 (0)