Initial implementation of row/column blocked MPIMatrixMult #136
Merged
Changes from 2 commits
30 commits
7fcc2cf
Added impl, test and example
astroC86 6a9d382
Merge branch 'PyLops:main' into astroC86-SUMMA
astroC86 f72fce6
Addressed some comments
astroC86 c607283
Example formating
astroC86 de1a173
Rename MatrixMultiply file to MatrixMult
astroC86 82b7e34
Addressed more issues
astroC86 9e1a49f
Addressed comments
astroC86 22cde7b
Addressing changes
astroC86 740030d
Minor cosmetic changes
astroC86 a88dec3
More minor changes
astroC86 66f1770
Example shape dims general
astroC86 7ac593d
Added comments to example
astroC86 8a56096
I donot know why I thought I needed to batch
astroC86 42452a1
Inital docstring for matrix mult
astroC86 a110ff8
minor: cleanup of docstrings and updated example
mrava87 bd9ad37
minor: fix mistake in plot_matrixmult
mrava87 18db078
removed now useless bcast and fixed mask in test
astroC86 ef3c283
changed tests
astroC86 4e39068
Fixed tests and moved checks to root
astroC86 ed3b585
Fix internal check for MPIMatrixMult
astroC86 7b76f96
Fixed Notation
astroC86 3e9659e
Skipping test if number of procs is not square for now
astroC86 dd9b43c
Merge branch 'main' into astroC86-SUMMA
astroC86 a85e75a
Fixed Doc error
astroC86 b7e6702
Renamed layer and group as to row and col respectively
astroC86 ae5661b
minor: small improvements to text
mrava87 053e52d
minor: fix flake8
mrava87 9aedd7c
MatrixMul works with non-square prcs by creating square subcommunicator
astroC86 4c662d6
minor: stylistic fixes
mrava87 0c34b78
minor: fix flake8
mrava87
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| import sys | ||
|
||
| import math | ||
| import numpy as np | ||
| from mpi4py import MPI | ||
|
|
||
| from pylops_mpi import DistributedArray, Partition | ||
| from pylops_mpi.basicoperators.MatrixMultiply import SUMMAMatrixMult | ||
|
|
||
| np.random.seed(42) | ||
|
|
||
| comm = MPI.COMM_WORLD | ||
| rank = comm.Get_rank() | ||
| nProcs = comm.Get_size() | ||
|
||
|
|
||
|
|
||
| P_prime = int(math.ceil(math.sqrt(nProcs))) | ||
| C = int(math.ceil(nProcs / P_prime)) | ||
| assert P_prime * C >= nProcs | ||
|
||
|
|
||
| # matrix dims | ||
| M = 32 # any M | ||
|
||
| K = 32 # any K | ||
|
||
| N = 35 # any N | ||
|
|
||
| blk_rows = int(math.ceil(M / P_prime)) | ||
| blk_cols = int(math.ceil(N / P_prime)) | ||
|
|
||
| my_group = rank % P_prime | ||
|
||
| my_layer = rank // P_prime | ||
|
|
||
| # sub-communicators | ||
| layer_comm = comm.Split(color=my_layer, key=my_group) # all procs in same layer | ||
|
||
| group_comm = comm.Split(color=my_group, key=my_layer) # all procs in same group | ||
|
|
||
| # Each rank will end up with: | ||
|
||
| # A_p: shape (my_own_rows, K) | ||
| # B_p: shape (K, my_own_cols) | ||
| # where | ||
| row_start = my_group * blk_rows | ||
| row_end = min(M, row_start + blk_rows) | ||
| my_own_rows = row_end - row_start | ||
|
|
||
| col_start = my_group * blk_cols # note: same my_group index on cols | ||
| col_end = min(N, col_start + blk_cols) | ||
| my_own_cols = col_end - col_start | ||
|
|
||
| # ============ DISTRIBUTING THE SLICES (point-to-point Send/Recv) ============ | ||
|
||
| if rank == 0: | ||
| A = np.arange(M*K, dtype=np.float32).reshape(M, K) | ||
| B = np.arange(K*N, dtype=np.float32).reshape(K, N) | ||
| for dest in range(nProcs): | ||
| pg = dest % P_prime | ||
|
||
| rs = pg * blk_rows | ||
| re = min(M, rs + blk_rows) | ||
| cs = pg * blk_cols | ||
| ce = min(N, cs + blk_cols) | ||
| a_block, b_block = A[rs:re, :].copy(), B[:, cs:ce].copy() | ||
|
||
| if dest == 0: | ||
| A_p, B_p = a_block, b_block | ||
| else: | ||
| comm.Send(a_block, dest=dest, tag=100+dest) | ||
| comm.Send(b_block, dest=dest, tag=200+dest) | ||
| else: | ||
| A_p = np.empty((my_own_rows, K), dtype=np.float32) | ||
| B_p = np.empty((K, my_own_cols), dtype=np.float32) | ||
| comm.Recv(A_p, source=0, tag=100+rank) | ||
| comm.Recv(B_p, source=0, tag=200+rank) | ||
|
|
||
| comm.Barrier() | ||
|
|
||
| Aop = SUMMAMatrixMult(A_p, N) | ||
|
||
| col_lens = comm.allgather(my_own_cols) | ||
| total_cols = sum(col_lens) | ||
|
||
| x = DistributedArray(global_shape=K * total_cols, | ||
| local_shapes=[K * col_len for col_len in col_lens], | ||
| partition=Partition.SCATTER, | ||
| mask=[i % P_prime for i in range(comm.Get_size())], | ||
| dtype=np.float32) | ||
| x[:] = B_p.flatten() | ||
| y = Aop @ x | ||
|
|
||
| # ======================= VERIFICATION ======================= | ||
| A = np.arange(M*K).reshape(M, K).astype(np.float32) | ||
| B = np.arange(K*N).reshape(K, N).astype(np.float32) | ||
| C_true = A @ B | ||
| Z_true = (A.T.dot(C_true.conj())).conj() | ||
|
||
|
|
||
|
|
||
| col_start = my_layer * blk_cols # note: output columns are indexed by my_layer, not my_group | ||
|
||
| col_end = min(N, col_start + blk_cols) | ||
| my_own_cols = col_end - col_start | ||
| expected_y = C_true[:,col_start:col_end].flatten() | ||
|
|
||
| if not np.allclose(y.local_array, expected_y, atol=1e-6, rtol=1e-14): | ||
| print(f"RANK {rank}: FORWARD VERIFICATION FAILED") | ||
| print(f'{rank} local: {y.local_array}, expected: {C_true[:,col_start:col_end]}') | ||
| else: | ||
| print(f"RANK {rank}: FORWARD VERIFICATION PASSED") | ||
|
|
||
| z = Aop.H @ y | ||
|
||
| expected_z = Z_true[:,col_start:col_end].flatten() | ||
| if not np.allclose(z.local_array, expected_z, atol=1e-6, rtol=1e-14): | ||
| print(f"RANK {rank}: ADJOINT VERIFICATION FAILED") | ||
| print(f'{rank} local: {z.local_array}, expected: {Z_true[:,col_start:col_end]}') | ||
| else: | ||
| print(f"RANK {rank}: ADJOINT VERIFICATION PASSED") | ||
|
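For reviewers who want to check the grid and block arithmetic used in the example without launching MPI, the following is a minimal sketch in plain Python. The helper names `grid_layout`, `block_range`, and `describe` are hypothetical and not part of this PR. One deliberate difference from the diff: `block_range` clamps the start index to `total`, an assumption made here so that an empty trailing block has length zero instead of a negative length.

```python
import math


def grid_layout(n_procs):
    """Return (P_prime, C) with P_prime * C >= n_procs, as computed in the example."""
    p_prime = math.ceil(math.sqrt(n_procs))
    c = math.ceil(n_procs / p_prime)
    return p_prime, c


def block_range(index, total, p_prime):
    """Half-open [start, end) range of block `index` when `total` items are
    cut into blocks of size ceil(total / p_prime).

    Assumption (not in the PR code): start is clamped to `total`.
    """
    blk = math.ceil(total / p_prime)
    start = min(total, index * blk)
    end = min(total, start + blk)
    return start, end


def describe(n_procs, M, N):
    """List (rank, group, layer, row_range, col_range) for every rank."""
    p_prime, _ = grid_layout(n_procs)
    layout = []
    for rank in range(n_procs):
        group = rank % p_prime   # my_group in the example
        layer = rank // p_prime  # my_layer in the example
        layout.append((rank, group, layer,
                       block_range(group, M, p_prime),
                       block_range(group, N, p_prime)))
    return layout


if __name__ == "__main__":
    for entry in describe(4, 32, 35):
        print(entry)
```

With the example's `M = 32`, `N = 35` and four ranks, each rank owns 16 rows of `A`, and the two column blocks of `B` have 18 and 17 columns.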
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,142 @@ | ||
| import numpy as np | ||
| import math | ||
| from mpi4py import MPI | ||
| from pylops.utils.backend import get_module | ||
| from pylops.utils.typing import DTypeLike, NDArray | ||
|
|
||
| from pylops_mpi import ( | ||
| DistributedArray, | ||
| MPILinearOperator, | ||
| Partition | ||
| ) | ||
|
|
||
|
|
||
| class SUMMAMatrixMult(MPILinearOperator): | ||
| def __init__( | ||
|
||
| self, | ||
| A: NDArray, | ||
| N: int, | ||
| base_comm: MPI.Comm = MPI.COMM_WORLD, | ||
| dtype: DTypeLike = "float64", | ||
| ) -> None: | ||
| rank = base_comm.Get_rank() | ||
| size = base_comm.Get_size() | ||
|
|
||
| # Determine grid dimensions (P_prime × C) such that P_prime * C ≥ size | ||
| self._P_prime = int(math.ceil(math.sqrt(size))) | ||
| self._C = int(math.ceil(size / self._P_prime)) | ||
| assert self._P_prime * self._C >= size | ||
|
||
|
|
||
| # Compute this process's group and layer indices | ||
| self._group_id = rank % self._P_prime | ||
| self._layer_id = rank // self._P_prime | ||
|
|
||
| # Split communicators by layer (rows) and by group (columns) | ||
| self.base_comm = base_comm | ||
| self._layer_comm = base_comm.Split(color=self._layer_id, key=self._group_id) | ||
| self._group_comm = base_comm.Split(color=self._group_id, key=self._layer_id) | ||
|
|
||
| self.dtype = np.dtype(dtype) | ||
|
||
| self.A = np.asarray(A, dtype=self.dtype) # avoids copy=False, which raises in NumPy 2 when a copy is required | ||
|
||
|
|
||
| self.M = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM) | ||
| self.K = A.shape[1] | ||
| self.N = N | ||
|
|
||
| # Determine how many columns each group holds | ||
| block_cols = int(math.ceil(self.N / self._P_prime)) | ||
| local_col_start = self._group_id * block_cols | ||
| local_col_end = min(self.N, local_col_start + block_cols) | ||
| local_ncols = local_col_end - local_col_start | ||
|
|
||
| # Sum up the total number of input columns across all processes | ||
| total_ncols = base_comm.allreduce(local_ncols, op=MPI.SUM) | ||
| self.dims = (self.K, total_ncols) | ||
|
|
||
| # Recompute how many output columns each layer holds | ||
| layer_col_start = self._layer_id * block_cols | ||
| layer_col_end = min(self.N, layer_col_start + block_cols) | ||
| layer_ncols = layer_col_end - layer_col_start | ||
| total_layer_cols = self.base_comm.allreduce(layer_ncols, op=MPI.SUM) | ||
|
|
||
| self.dimsd = (self.M, total_layer_cols) | ||
| shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims))) | ||
|
|
||
| super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm) | ||
|
|
||
| def _matvec(self, x: DistributedArray) -> DistributedArray: | ||
|
||
| ncp = get_module(x.engine) | ||
| if x.partition != Partition.SCATTER: | ||
| raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.") | ||
| blk_cols = int(math.ceil(self.N / self._P_prime)) | ||
|
||
| col_start = self._group_id * blk_cols | ||
| col_end = min(self.N, col_start + blk_cols) | ||
| my_own_cols = max(0, col_end - col_start) | ||
| x = x.local_array.reshape((self.dims[0], my_own_cols)) | ||
| x = x.astype(self.dtype, copy=False) | ||
| B_block = self._layer_comm.bcast(x if self._group_id == self._layer_id else None, root=self._layer_id) | ||
| C_local = ncp.vstack( | ||
|
||
| self._layer_comm.allgather( | ||
|
||
| ncp.matmul(self.A, B_block) | ||
| ) | ||
| ) | ||
|
|
||
| layer_col_start = self._layer_id * blk_cols | ||
| layer_col_end = min(self.N, layer_col_start + blk_cols) | ||
| layer_ncols = max(0, layer_col_end - layer_col_start) | ||
| layer_col_lens = self.base_comm.allgather(layer_ncols) | ||
| mask = [i // self._P_prime for i in range(self.size)] | ||
|
||
|
|
||
| y = DistributedArray(global_shape=(self.M * self.dimsd[1]), | ||
|
||
| local_shapes=[(self.M * c) for c in layer_col_lens], | ||
| mask=mask, | ||
| #axis=1, | ||
| partition=Partition.SCATTER, | ||
| dtype=self.dtype) | ||
| y[:] = C_local.flatten() | ||
| return y | ||
|
|
||
| def _rmatvec(self, x: DistributedArray) -> DistributedArray: | ||
| ncp = get_module(x.engine) | ||
| if x.partition != Partition.SCATTER: | ||
| raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.") | ||
|
|
||
| # Determine local column block for this layer | ||
|
||
| blk_cols = int(math.ceil(self.N / self._P_prime)) | ||
| layer_col_start = self._layer_id * blk_cols | ||
| layer_col_end = min(self.N, layer_col_start + blk_cols) | ||
| layer_ncols = layer_col_end - layer_col_start | ||
| layer_col_lens = self.base_comm.allgather(layer_ncols) | ||
| x = x.local_array.reshape((self.M, layer_ncols)) | ||
|
|
||
| # Determine local row block for this process group | ||
| blk_rows = int(math.ceil(self.M / self._P_prime)) | ||
| row_start = self._group_id * blk_rows | ||
| row_end = min(self.M, row_start + blk_rows) | ||
|
|
||
| B_tile = x[row_start:row_end, :].astype(self.dtype, copy=False) | ||
|
||
| A_local = self.A.T.conj() | ||
|
|
||
| m, b = A_local.shape | ||
| pad = (-m) % self._P_prime | ||
| r = (m + pad) // self._P_prime | ||
| A_pad = ncp.pad(A_local, ((0, pad), (0, 0)), mode='constant', constant_values=0) # use the array's own backend | ||
| A_batch = A_pad.reshape(self._P_prime, r, b) | ||
|
|
||
| # Perform local matmul and unpad | ||
| Y_batch = ncp.matmul(A_batch, B_tile) | ||
| Y_pad = Y_batch.reshape(r * self._P_prime, -1) | ||
| y_local = Y_pad[:m, :] | ||
| y_layer = self._layer_comm.allreduce(y_local, op=MPI.SUM) | ||
|
|
||
| mask = [i // self._P_prime for i in range(self.size)] | ||
| y = DistributedArray( | ||
| global_shape=(self.K * self.dimsd[1]), | ||
| local_shapes=[self.K * c for c in layer_col_lens], | ||
| mask=mask, | ||
| #axis=1 | ||
|
||
| partition=Partition.SCATTER, | ||
| dtype=self.dtype, | ||
| ) | ||
| y[:] = y_layer.flatten() | ||
| return y | ||
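To make the data flow of `_matvec` easier to follow, here is a serial sketch that emulates one layer of the process grid with nested lists instead of MPI. It is an illustration under stated assumptions, not the operator itself: `matmul`, `block_range`, and `forward_for_layer` are hypothetical helpers, the loop over `group` stands in for the ranks of one layer, and extending `C_local` stands in for the `allgather` followed by `vstack`.

```python
import math


def matmul(X, Y):
    """Plain nested-list matrix product (stand-in for ncp.matmul)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)]
            for row in X]


def block_range(index, total, p_prime):
    """Half-open [start, end) range of block `index`, block size ceil(total / p_prime)."""
    blk = math.ceil(total / p_prime)
    start = min(total, index * blk)
    return start, min(total, start + blk)


def forward_for_layer(A, B, layer, p_prime):
    """Emulate _matvec for one layer of the grid.

    Every group multiplies its row block of A by the column block of B that
    the layer root broadcasts; stacking the partial products row-wise gives
    the layer's column block of A @ B.
    """
    M, N = len(A), len(B[0])
    cs, ce = block_range(layer, N, p_prime)
    B_block = [row[cs:ce] for row in B]   # what bcast delivers to the layer
    C_local = []
    for group in range(p_prime):          # one iteration per rank in the layer
        rs, re_ = block_range(group, M, p_prime)
        C_local.extend(matmul(A[rs:re_], B_block))
    return C_local


if __name__ == "__main__":
    A = [[1, 2], [3, 4], [5, 6]]
    B = [[1, 0, 2], [0, 1, 3]]
    print(forward_for_layer(A, B, layer=0, p_prime=2))
    print(forward_for_layer(A, B, layer=1, p_prime=2))
```

Comparing each layer's result against the matching column slice of the full product is the same check the example script performs per rank with `np.allclose`.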