Skip to content

Commit 039c4fb

Browse files
committed
break spatial functions into separate gridders
1 parent 5bc07dd commit 039c4fb

File tree

2 files changed

+131
-59
lines changed

2 files changed

+131
-59
lines changed

src/py21cmwedge/uvgridder.py

Lines changed: 130 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,24 @@
1111
from . import dft
1212

1313

14-
@njit(numba.float64[:, :](numba.float64[:, :], numba.float64[:, :]))
15-
def norm(x, y):
16-
out_array = np.zeros_like(x, dtype=x.dtype)
17-
_out = out_array.ravel()
18-
for cnt, (_x, _y) in enumerate(zip(x.flat, y.flat)):
19-
_out[cnt] += np.sqrt(_x**2 + _y**2)
20-
return out_array
14+
@njit(
15+
numba.void(
16+
numba.float64[:, :], numba.float64[:, :], numba.float64[:, :], numba.float64
17+
)
18+
)
19+
def norm(x, y, out, scale_factor):
20+
for cnt1 in range(out.shape[0]):
21+
for cnt2 in range(out.shape[1]):
22+
out[cnt1, cnt2] = max(
23+
np.float64(1.0)
24+
- np.sqrt(x[cnt1, cnt2] ** 2 + y[cnt1, cnt2] ** 2) / scale_factor,
25+
np.float64(0),
26+
)
27+
28+
out_norm = out.sum()
29+
30+
if out_norm > 0:
31+
out /= out_norm
2132

2233

2334
# 2-d adaptation of https://stackoverflow.com/a/70614173
@@ -293,22 +304,26 @@ def to_str(arr):
293304
numba.float64[:],
294305
numba.float64[:],
295306
numba.int32,
296-
numba.complex128[:, :, :],
307+
numba.float64[:, :, :],
297308
numba.int32,
298309
numba.float64,
299310
numba.float64,
300-
numba.types.string,
311+
numba.float64[:, :],
312+
numba.float64[:, :],
313+
numba.float64[:, :],
301314
)
302315
)
303-
def uv_weights(
316+
def _weights_tri(
304317
u,
305318
v,
306319
nbls,
307320
uvf_cube,
308321
uv_size,
309322
uv_delta,
310323
wavelength_scale,
311-
spatial_function="triangle",
324+
x,
325+
y,
326+
weights,
312327
):
313328
"""Compute weights for arbitrary baseline on a gridded UV plane.
314329
@@ -324,46 +339,64 @@ def uv_weights(
324339
triangle performs simple distance based weighting of uv-bins based
325340
on self.wavelength_scale slope
326341
"""
327-
match spatial_function.casefold():
328-
case "triangle":
329-
_range = np.arange(uv_size) - (uv_size - 1) / 2.0
330-
_range *= uv_delta
331-
x, y = meshgrid(_range, _range)
332-
for freq_cnt, (_u, _v) in enumerate(zip(u, v)):
333-
_x = _u - x
334-
_y = _v - y
335-
dists = norm(_x, _y)
336-
weights = 1.0 - dists / wavelength_scale
337-
_w = weights.ravel()
338-
for cnt, __w in enumerate(_w):
339-
if __w < 0:
340-
_w[cnt] = 0
341-
# weights[weights <= 0] = 0
342-
# weights = np.ma.masked_less_equal(weights, 0).filled(0)
343-
weights /= weights.sum()
344-
345-
uvf_cube[freq_cnt] += weights * nbls
346342

347-
case "nearest":
348-
_range = np.arange(uv_size) - (uv_size - 1) / 2.0
349-
_range *= uv_delta
350-
x, y = _range, _range
351-
x = np.expand_dims(x, -1)
352-
y = np.expand_dims(y, -1)
353-
for freq_cnt, (_u, _v) in enumerate(zip(u, v)):
354-
_x = _u - x
355-
_y = _v - y
343+
for freq_cnt, (_u, _v) in enumerate(zip(u, v)):
344+
_x = _u - x[:, :]
345+
_y = _v - y[:, :]
356346

357-
u_index = np.abs(_x).argmin()
358-
v_index = np.abs(_y).argmin()
347+
norm(_x, _y, weights, wavelength_scale)
359348

360-
# v,u indexing because y is the outer dimension in memory
361-
uvf_cube[freq_cnt, v_index, u_index] += 1.0 * nbls
349+
# _w = weights.ravel()
350+
# for cnt, __w in enumerate(_w):
351+
# if __w < 0:
352+
# _w[cnt] = 0
353+
# # weights[weights <= 0] = 0
354+
# # weights = np.ma.masked_less_equal(weights, 0).filled(0)
355+
# weights /= weights.sum()
362356

363-
case _:
364-
raise ValueError(
365-
f"Unknown value for 'spatial_function': {spatial_function}"
366-
)
357+
uvf_cube[freq_cnt, :, :] += weights[:, :] * nbls
358+
359+
@staticmethod
360+
@numba.njit(
361+
numba.void(
362+
numba.float64[:],
363+
numba.float64[:],
364+
numba.int32,
365+
numba.float64[:, :, :],
366+
numba.int32,
367+
numba.float64,
368+
numba.float64[:],
369+
),
370+
)
371+
def _weights_nearest(
372+
u,
373+
v,
374+
nbls,
375+
uvf_cube,
376+
uv_size,
377+
uv_delta,
378+
x,
379+
):
380+
"""Compute weights for arbitrary baseline on a gridded UV plane.
381+
382+
uv must be in units of pixels.
383+
384+
Parameters
385+
----------
386+
convolve_beam: bool
387+
when set to true, perform an FFT convolution with the supplied beam
388+
spatial_function: string
389+
must be one of ["nearest", "triangle"].
390+
Nearest modes performs delta function like assignment into a uv-bin
391+
triangle performs simple distance based weighting of uv-bins based
392+
on self.wavelength_scale slope
393+
"""
394+
for freq_cnt, (_u, _v) in enumerate(zip(u, v)):
395+
u_index = np.abs(_u - x[:]).argmin()
396+
v_index = np.abs(_v - x[:]).argmin()
397+
398+
# v,u indexing because y is the outer dimension in memory
399+
uvf_cube[freq_cnt, v_index, u_index] += nbls
367400

368401
def __sum_uv__(self, uv_key, spatial_function="triangle"):
369402
"""Convert uvbin dictionary to a UV-plane.
@@ -381,16 +414,35 @@ def __sum_uv__(self, uv_key, spatial_function="triangle"):
381414
u /= self.wavelength
382415
v /= self.wavelength
383416
# Create interpolation weights based on grid size and sampling
384-
UVGridder.uv_weights(
385-
u,
386-
v,
387-
nbls,
388-
self.uvf_cube,
389-
self.uv_size,
390-
self.uv_delta,
391-
self.wavelength_scale,
392-
spatial_function=spatial_function,
393-
)
417+
match spatial_function:
418+
case "triangle":
419+
UVGridder._weights_tri(
420+
u,
421+
v,
422+
nbls,
423+
self.uvf_cube,
424+
self.uv_size,
425+
self.uv_delta,
426+
self.wavelength_scale,
427+
self.x,
428+
self.y,
429+
self.weights,
430+
)
431+
case "nearest":
432+
UVGridder._weights_nearest(
433+
u,
434+
v,
435+
nbls,
436+
self.uvf_cube,
437+
self.uv_size,
438+
self.uv_delta,
439+
self.x,
440+
)
441+
case _:
442+
raise ValueError(
443+
f"Unknown spatial_function ({spatial_function}), must be "
444+
'"nearest" or "triangle"'
445+
)
394446

395447
def grid_uvw(self, convolve_beam=True, spatial_function="triangle"):
396448
"""Create UV coverage from object data.
@@ -410,13 +462,33 @@ def grid_uvw(self, convolve_beam=True, spatial_function="triangle"):
410462
* 2
411463
+ 5
412464
)
465+
466+
match spatial_function.casefold():
467+
case "triangle":
468+
_range = np.arange(self.uv_size) - (self.uv_size - 1) / 2.0
469+
_range *= self.uv_delta
470+
self.x, self.y = meshgrid(_range, _range)
471+
self.weights = np.zeros_like(self.x)
472+
473+
case "nearest":
474+
_range = np.arange(self.uv_size) - (self.uv_size - 1) / 2.0
475+
_range *= self.uv_delta
476+
self.x = _range
477+
case _:
478+
raise ValueError(
479+
f"Unknown value for 'spatial_function': {spatial_function}"
480+
)
481+
413482
self.uvf_cube = np.zeros(
414-
(self.freqs.size, self.uv_size, self.uv_size), dtype=complex
483+
(self.freqs.size, self.uv_size, self.uv_size), dtype=np.float64
415484
)
485+
416486
for uv_key in tqdm(self.uvbins.keys(), unit="Baseline"):
417487
self.__sum_uv__(uv_key, spatial_function=spatial_function)
418488

419489
if convolve_beam:
490+
self.uvf_cube = self.uvf_cube.astype(np.complex128)
491+
420492
beam_array = self.get_uv_beam()
421493
# if only one beam was given, use that beam for all freqs
422494
if np.shape(beam_array)[0] < self.freqs.size:

tests/test_uvgridder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ def test_weights_sum():
314314
test_obj = UVGridder()
315315
test_obj.set_uv_delta(0.5)
316316
test_obj.set_freqs(150e6)
317-
test_obj.uv_delta = 0.1
317+
test_obj.uv_delta = 0.5
318318
test_uvw = np.array([[14.6], [0], [0]])
319319
test_obj.set_uvw_array(test_uvw)
320320
test_obj.calc_all(convolve_beam=False)

0 commit comments

Comments
 (0)