Skip to content

Commit 2c248d1

Browse files
authored
Merge pull request #217 from HERA-Team/refactoring_exact_norm
Refactoring exact norm
2 parents 7708f92 + 69c0d78 commit 2c248d1

File tree

2 files changed

+107
-45
lines changed

2 files changed

+107
-45
lines changed

hera_pspec/pspecdata.py

Lines changed: 78 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,10 +1100,12 @@ def q_hat(self, key1, key2, allow_fft=False, exact_norm = False, pol=False):
11001100
q = []
11011101
del_tau = np.median(np.diff(self.delays()))*1e-9 #Get del_eta in Eq.11(a) (HERA memo #44) (seconds)
11021102
Q_matrix_all_delays = np.zeros((self.spw_Ndlys,self.spw_Nfreqs,self.spw_Nfreqs), dtype='complex128')
1103+
integral_beam = self.get_integral_beam(pol) # This result does not depend on delay modes. We can remove it from the for loop to avoid its repeated computation
1104+
11031105
for i in range(self.spw_Ndlys):
11041106
# Ideally, del_tau should be part of get_Q. We use it here to
11051107
# avoid its repeated computation
1106-
Q = del_tau * self.get_Q(i, pol)
1108+
Q = del_tau * self.get_Q(i, pol) * integral_beam
11071109
Q_matrix_all_delays[i] = Q
11081110
QRx2 = np.dot(Q, Rx2)
11091111

@@ -1604,52 +1606,80 @@ def get_Q_alt(self, mode, allow_fft=True):
16041606

16051607
Q_alt = np.einsum('i,j', m.conj(), m) # dot it with its conjugate
16061608
return Q_alt
1609+
1610+
def get_integral_beam(self, pol=False):
1611+
"""
1612+
Computes the integral containing the spectral beam and tapering
1613+
function in Q_alpha(i,j).
1614+
1615+
Parameters
1616+
----------
1617+
1618+
pol : str/int/bool, optional
1619+
Which beam polarization to use. If the specified polarization
1620+
doesn't exist, a uniform isotropic beam (with integral 4pi for all
1621+
frequencies) is assumed. Default: False (uniform beam).
16071622
1623+
Return
1624+
-------
1625+
integral_beam : array_like
1626+
integral containing the spectral beam and tapering.
1627+
"""
1628+
nu = self.freqs[self.spw_range[0]:self.spw_range[1]] # in Hz
1629+
1630+
try:
1631+
# Get beam response in (frequency, pixel), beam area(freq) and
1632+
# Nside, used in computing dtheta
1633+
beam_res, beam_omega, N = \
1634+
self.primary_beam.beam_normalized_response(pol, nu)
1635+
prod = 1. / beam_omega
1636+
beam_prod = beam_res * prod[:, np.newaxis]
1637+
1638+
# beam_prod has omega subsumed, but taper is still part of R matrix
1639+
# The nside term is dtheta^2, where dtheta is the resolution in
1640+
# healpix map
1641+
integral_beam = np.pi/(3.*N*N) * np.dot(beam_prod, beam_prod.T)
1642+
1643+
except(AttributeError):
1644+
warnings.warn("The beam response could not be calculated. "
1645+
"PS will not be normalized!")
1646+
integral_beam = np.ones((len(nu), len(nu)))
1647+
1648+
return integral_beam
1649+
1650+
16081651
def get_Q(self, mode, pol=False):
1609-
'''
1610-
Computes Q_alpha(i,j), which is the response of the data covariance to the bandpower (dC/dP_alpha).
1611-
This includes contributions from primary beam.
1652+
"""
1653+
Computes Q_alpha(i,j), which is the response of the data covariance to
1654+
the bandpower (dC/dP_alpha). This includes contributions from primary
1655+
beam.
16121656
16131657
Parameters
16141658
----------
16151659
mode : int
16161660
Central wavenumber (index) of the bandpower, p_alpha.
16171661
16181662
pol : str/int/bool, optional
1619-
Which beam polarization to use. If the specified polarization doesn't exist,
1620-
a uniform isotropic beam (with integral 4pi for all frequencies) is assumed.
1621-
Default: False (uniform beam).
1663+
Which beam polarization to use. If the specified polarization
1664+
doesn't exist, a uniform isotropic beam (with integral 4pi for all
1665+
frequencies) is assumed. Default: False (uniform beam).
16221666
16231667
Return
16241668
-------
1625-
Q : array_like
1626-
Response matrix for bandpower p_alpha.
1627-
'''
1628-
1669+
Q_alt : array_like
1670+
Exponential part of Q (HERA memo #44, Eq. 11).
1671+
"""
16291672
if self.spw_Ndlys == None:
16301673
self.set_Ndlys()
16311674
if mode >= self.spw_Ndlys:
16321675
raise IndexError("Cannot compute Q matrix for a mode outside"
16331676
"of allowed range of delay modes.")
16341677
tau = self.delays()[int(mode)] * 1.0e-9 # delay in seconds
16351678
nu = self.freqs[self.spw_range[0]:self.spw_range[1]] # in Hz
1636-
1637-
try:
1638-
beam_res, beam_omega, N = self.primary_beam.beam_normalized_response(pol, nu)
1639-
#Get beam response in (frequency, pixel), beam area(freq) and Nside, used in computing dtheta.
1640-
prod = (1.0/beam_omega)
1641-
beam_prod = beam_res * prod[:, np.newaxis]
1642-
integral_beam = (np.pi/(3.0*(N)**2))* \
1643-
np.dot(beam_prod, beam_prod.T) #beam_prod has omega subsumed, but taper is still part of R matrix
1644-
# the nside terms is dtheta^2, where dtheta is the resolution in healpix map
1645-
except(AttributeError):
1646-
warnings.warn('The beam response could not be calculated. PS will not be normalized!')
1647-
integral_beam = np.ones((len(nu), len(nu)))
1648-
1649-
eta_int = np.exp(-2j * np.pi * tau * nu) #exponential part of the expression
1650-
Q_alt = np.einsum('i,j', eta_int.conj(), eta_int) # dot it with its conjugate
1651-
Q = Q_alt * integral_beam
1652-
return Q
1679+
1680+
eta_int = np.exp(-2j * np.pi * tau * nu) # exponential part
1681+
Q_alt = np.einsum('i,j', eta_int.conj(), eta_int) # dot with conjugate
1682+
return Q_alt
16531683

16541684
def p_hat(self, M, q):
16551685
"""
@@ -1673,7 +1703,8 @@ def p_hat(self, M, q):
16731703
def cov_p_hat(self, M, q_cov):
16741704
"""
16751705
Covariance estimate between two different band powers p_alpha and p_beta
1676-
given by M_{alpha i} M^*_{beta,j} C_q^{ij} where C_q^{ij} is the q-covariance
1706+
given by M_{alpha i} M^*_{beta,j} C_q^{ij} where C_q^{ij} is the
1707+
q-covariance.
16771708
16781709
Parameters
16791710
----------
@@ -1688,7 +1719,8 @@ def cov_p_hat(self, M, q_cov):
16881719
p_cov[tnum] = np.einsum('ab,cd,bd->ac', M, M, q_cov[tnum])
16891720
return p_cov
16901721

1691-
def broadcast_dset_flags(self, spw_ranges=None, time_thresh=0.2, unflag=False):
1722+
def broadcast_dset_flags(self, spw_ranges=None, time_thresh=0.2,
1723+
unflag=False):
16921724
"""
16931725
For each dataset in self.dset, update the flag_array such that
16941726
the flagging patterns are time-independent for each baseline given
@@ -1701,8 +1733,9 @@ def broadcast_dset_flags(self, spw_ranges=None, time_thresh=0.2, unflag=False):
17011733
17021734
Additionally, one can also unflag the flag_array entirely if desired.
17031735
1704-
Note: although technically allowed, this function may give unexpected results
1705-
if multiple spectral windows in spw_ranges have frequency overlap.
1736+
Note: although technically allowed, this function may give unexpected
1737+
results if multiple spectral windows in spw_ranges have frequency
1738+
overlap.
17061739
17071740
Note: it is generally not recommended to set time_thresh > 0.5, which
17081741
could lead to substantial amounts of data being flagged.
@@ -1712,11 +1745,12 @@ def broadcast_dset_flags(self, spw_ranges=None, time_thresh=0.2, unflag=False):
17121745
spw_ranges : list of tuples
17131746
list of len-2 spectral window tuples, specifying the start (inclusive)
17141747
and stop (exclusive) index of the frequency array for each spw.
1715-
Default is to use the whole band
1748+
Default is to use the whole band.
17161749
17171750
time_thresh : float
1718-
Fractional threshold of flagged pixels across time needed to flag all times
1719-
per freq channel. It is not recommend to set this greater than 0.5
1751+
Fractional threshold of flagged pixels across time needed to flag
1752+
all times per freq channel. It is not recommend to set this greater
1753+
than 0.5.
17201754
17211755
unflag : bool
17221756
If True, unflag all data in the spectral window.
@@ -1730,7 +1764,8 @@ def broadcast_dset_flags(self, spw_ranges=None, time_thresh=0.2, unflag=False):
17301764
# spw type check
17311765
if spw_ranges is None:
17321766
spw_ranges = [(0, self.Nfreqs)]
1733-
assert isinstance(spw_ranges, list), "spw_ranges must be fed as a list of tuples"
1767+
assert isinstance(spw_ranges, list), \
1768+
"spw_ranges must be fed as a list of tuples"
17341769

17351770
# iterate over datasets
17361771
for dset in self.dsets:
@@ -1740,7 +1775,7 @@ def broadcast_dset_flags(self, spw_ranges=None, time_thresh=0.2, unflag=False):
17401775
# unflag
17411776
if unflag:
17421777
# unflag for all times
1743-
dset.flag_array[:, :, self.spw_range[0]:self.spw_range[1], :] = False
1778+
dset.flag_array[:,:,self.spw_range[0]:self.spw_range[1],:] = False
17441779
continue
17451780
# enact time threshold on flag waterfalls
17461781
# iterate over polarizations
@@ -2123,8 +2158,12 @@ def pspec(self, bls1, bls2, dsets, pols, n_dlys=None,
21232158
exact_norm : bool, optional
21242159
If True, estimates power spectrum using Q instead of Q_alt
21252160
(HERA memo #44). The default options is False. Beware that
2126-
turning this True would take ~ 7 sec for computing
2127-
power spectrum for 100 channels per time sample per baseline.
2161+
turning this True would take ~ 0.2 sec for computing
2162+
power spectrum for 100 channels per time sample per baseline.
2163+
If False, computing a power spectrum for 100 channels would
2164+
take ~ 0.04 sec per time sample per baseline. This means
2165+
that computing a power spectrum when exact_norm is set to
2166+
False runs five times faster than setting it to True.
21282167
21292168
history : str, optional
21302169
history string to attach to UVPSpec object

hera_pspec/tests/test_pspecdata.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,8 @@ def test_get_Q_alt(self):
335335

336336
# Check for error handling
337337
nt.assert_raises(ValueError, self.ds.set_Ndlys, vect_length+100)
338-
338+
339+
339340
def test_get_Q(self):
340341
"""
341342
Test the Q = dC_ij/dp function.
@@ -361,9 +362,6 @@ def test_get_Q(self):
361362
key2 = (1, 24, 38)
362363
uvd = copy.deepcopy(self.uvd)
363364
ds_t = pspecdata.PSpecData(dsets=[uvd, uvd])
364-
with warnings.catch_warnings(record=True) as w:
365-
ds_t.get_Q(0, pol)
366-
assert len(w) > 0
367365

368366
for i in range(vect_length):
369367
try:
@@ -373,7 +371,7 @@ def test_get_Q(self):
373371
self.assertEqual(self.ds.spw_Ndlys, self.ds.spw_Nfreqs)
374372
except IndexError:
375373
Q_matrix = np.ones((vect_length, vect_length))
376-
374+
377375
xQy = np.dot(np.conjugate(x_vect), np.dot(Q_matrix, y_vect))
378376
yQx = np.dot(np.conjugate(y_vect), np.dot(Q_matrix, x_vect))
379377
xQx = np.dot(np.conjugate(x_vect), np.dot(Q_matrix, x_vect))
@@ -428,6 +426,31 @@ def test_get_Q(self):
428426
# of the range of delay bins
429427
nt.assert_raises(IndexError, self.ds.get_Q, vect_length-1, pol)
430428

429+
def test_get_integral_beam(self):
430+
"""
431+
Test the integral of the beam and tapering function in Q.
432+
"""
433+
pol = 'xx'
434+
#Test if there is a warning if user does not pass the beam
435+
uvd = copy.deepcopy(self.uvd)
436+
ds_t = pspecdata.PSpecData(dsets=[uvd, uvd])
437+
ds = pspecdata.PSpecData(dsets=[uvd, uvd], beam=self.bm)
438+
439+
with warnings.catch_warnings(record=True) as w:
440+
ds_t.get_integral_beam(pol)
441+
assert len(w) > 0
442+
443+
try:
444+
integral_matrix = ds.get_integral_beam(pol)
445+
# Test that if the number of delay bins hasn't been set
446+
# the code defaults to putting that equal to Nfreqs
447+
self.assertEqual(ds.spw_Ndlys, ds.spw_Nfreqs)
448+
except IndexError:
449+
integral_matrix = np.ones((ds.spw_Ndlys, ds.spw_Ndlys))
450+
451+
# Test that integral matrix has the right shape
452+
self.assertEqual(integral_matrix.shape, (ds.spw_Nfreqs, ds.spw_Nfreqs))
453+
431454
def test_get_unnormed_E(self):
432455
"""
433456
Test the E function
@@ -1180,7 +1203,7 @@ def test_pspec(self):
11801203
bls_Q = [(24, 25)]
11811204
uvp = ds_Q.pspec(bls_Q, bls_Q, (0, 1), [('xx', 'xx')], input_data_weight='identity',
11821205
norm='I', taper='none', verbose=True, exact_norm=False)
1183-
Q_sample = ds_Q.get_Q((ds_Q.spw_range[1] - ds_Q.spw_range[0])/2, 'xx') #Get Q matrix for 0th delay mode
1206+
Q_sample = ds_Q.get_integral_beam('xx') #Get integral beam for pol 'xx'
11841207

11851208
nt.assert_equal(np.shape(Q_sample), (ds_Q.spw_range[1] - ds_Q.spw_range[0],\
11861209
ds_Q.spw_range[1] - ds_Q.spw_range[0])) #Check for the right shape

0 commit comments

Comments
 (0)