-
Notifications
You must be signed in to change notification settings - Fork 4
updates to incorporate the matrix chunking method #79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
2b398d1
64b5220
2a4a574
e3690c3
89e52f9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -50,3 +50,38 @@ def compute(self, z: np.ndarray, out: np.ndarray) -> np.ndarray: | |
| out[i] = z[ai].conj().dot(z[aj].T) | ||
|
|
||
| return out | ||
|
|
||
|
|
||
| class CPUMatChunk(MatProd): | ||
steven-murray marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """Loop over a small set of sub-matrix products which collectively contain all nont-redundant pairs.""" | ||
|
|
||
| def compute(self, z: np.ndarray, out: np.ndarray) -> np.ndarray: | ||
| """Perform the source-summing operation for a single time and chunk. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| z | ||
| Complex integrand. Shape=(Nfeed, Nant, Nax, Nsrc). | ||
| out | ||
| Output array, shaped as (Nfeed, Nfeed, Npairs). | ||
| """ | ||
| z = z.reshape((self.nant, self.nfeed, -1)) | ||
|
|
||
| mat_product = np.zeros( | ||
| (self.nant, self.nant, self.nfeed, self.nfeed), dtype=z.dtype | ||
| ) | ||
|
Comment on lines
+70
to
+72
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it not be faster to just save Nsets arrays in a list, and then index from them? This way, you have to put the results into a big array so they're very separated in memory space
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I mean, setup something like this: mat_product = [
np.zeros(len(ai), len(aj), self.nfeed, self.nfeed) for ai, aj in self.matsets
]
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The reason I didn't do it this way is because I wanted the final array that holds all the matrix products to be indexable with the elements of antpairs. That way, I can easily grab the non-redundant pairs at the end of return an array with the same format that you'd get using the vector-vector method. Otherwise, I have to re-calculate the mapping the mapping to the antpairs indices for each sub-matrix. I can change it if you think it would help with memory management to do it the other way though. |
||
|
|
||
| # Chris 12/20/23: instead we will use matsets | ||
| for j in range(self.nfeed): | ||
| for k in range(self.nfeed): | ||
| for i, (ai, aj) in enumerate(self.matsets): | ||
| AI, AJ = np.meshgrid(ai, aj) | ||
| mat_product[AI, AJ, j, k] = z[ai[:], j].conj().dot(z[aj[:], k].T).T | ||
|
|
||
| # Now, we need to identify the non-redundant pairs and put them into the final output array | ||
| for j in range(self.nfeed): | ||
| for k in range(self.nfeed): | ||
| for i, (ai, aj) in enumerate(self.antpairs): | ||
| out[i, j, k] = mat_product[ai, aj, j, k] | ||
|
|
||
| return out | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| """Module containing several options for computing sub-matrices for the MatChunk method.""" | ||
| import numpy as np | ||
|
|
||
|
|
||
| def get_matrix_sets(bls, ndecimals: int = 2): | ||
| """Find redundant baselines.""" | ||
| uvbins = set() | ||
| msets = [] | ||
|
|
||
| # Everything here is in wavelengths | ||
| bls = np.round(bls, decimals=ndecimals) | ||
| nant = bls.shape[0] | ||
|
|
||
| # group redundant baselines | ||
| for i in range(nant): | ||
| for j in range(i + 1, nant): | ||
| u, v = bls[i, j] | ||
| if (u, v) not in uvbins and (-u, -v) not in uvbins: | ||
| uvbins.add((u, v)) | ||
| msets.append([np.array([i]), np.array([j])]) | ||
|
|
||
| return msets |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you could simply do this as
antpairs_set = set(self.antpairs)