Commit 31fbc72

Merge pull request #174 from mrava87/patch-ncclinit
feat: modify initialize_nccl_comm to handle nodes with more gpus than ranks
2 parents: 27f3687 + 6276d4e

File tree

4 files changed (+16 −2)


pylops_mpi/basicoperators/Laplacian.py

Lines changed: 2 additions & 0 deletions

@@ -113,12 +113,14 @@ def _calc_l2op(self):
             if ax == 0:
                 l2op += weight * MPISecondDerivative(dims=self.dims,
                                                      sampling=samp,
+                                                     kind=self.kind,
                                                      edge=self.edge,
                                                      dtype=self.dtype)
             else:
                 l2op += weight * MPIBlockDiag(ops=[SecondDerivative(dims=local_dims,
                                                                     axis=ax,
                                                                     sampling=samp,
+                                                                    kind=self.kind,
                                                                     edge=self.edge,
                                                                     dtype=self.dtype)])
         return l2op
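
For context, a hedged usage sketch (not part of this PR): with this change, the kind stored on the Laplacian as self.kind is forwarded to the underlying second-derivative operators instead of being silently ignored. This assumes MPILaplacian is exposed at the package level and accepts a kind argument:

import numpy as np
import pylops_mpi

# The "forward" stencil now propagates to the inner MPISecondDerivative /
# SecondDerivative operators; before this fix they always fell back to
# their default kind regardless of what was passed here
lap = pylops_mpi.MPILaplacian(dims=(32, 32), kind="forward", dtype=np.float64)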

pylops_mpi/optimization/cls_basic.py

Lines changed: 6 additions & 0 deletions

@@ -51,6 +51,9 @@ def _print_step(self, x: Union[DistributedArray, StackedDistributedArray]) -> None:
             print(msg)
             sys.stdout.flush()
 
+    def memory_usage(self) -> None:
+        pass
+
     def setup(
         self,
         y: Union[DistributedArray, StackedDistributedArray],
@@ -299,6 +302,9 @@ def _print_step(self, x: Union[DistributedArray, StackedDistributedArray]) -> None:
             print(msg)
             sys.stdout.flush()
 
+    def memory_usage(self) -> None:
+        pass
+
     def setup(self,
               y: Union[DistributedArray, StackedDistributedArray],
               x0: Union[DistributedArray, StackedDistributedArray],
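
The new memory_usage method is a deliberate no-op hook. A hedged sketch of how it could be used, assuming CG is one of the solver classes defined in this file (the CuPy-based body is illustrative only, not part of the PR):

import cupy as cp
from pylops_mpi.optimization.cls_basic import CG  # assumed solver class

class MonitoredCG(CG):
    def memory_usage(self) -> None:
        # Report this rank's CuPy memory-pool footprint whenever called
        pool = cp.get_default_memory_pool()
        print(f"GPU memory: {pool.used_bytes()} / {pool.total_bytes()} bytes")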

pylops_mpi/utils/_nccl.py

Lines changed: 6 additions & 1 deletion

@@ -107,9 +107,14 @@ def initialize_nccl_comm() -> nccl.NcclCommunicator:
     comm = MPI.COMM_WORLD
     rank = comm.Get_rank()
     size = comm.Get_size()
+
+    # Create a communicator for ranks on the same node
+    node_comm = comm.Split_type(MPI.COMM_TYPE_SHARED)
+    size_node = node_comm.Get_size()
+
     device_id = int(
         os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
-        or rank % cp.cuda.runtime.getDeviceCount()
+        or (rank % size_node) % cp.cuda.runtime.getDeviceCount()
     )
     cp.cuda.Device(device_id).use()
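
To see what the node-local mapping changes, a small worked example with hypothetical numbers (2 nodes, 2 ranks per node, 4 GPUs per node; pick_device is a stand-in for the fallback expression above):

# Stand-in for the fallback device selection in initialize_nccl_comm
def pick_device(rank: int, size_node: int, ngpus: int) -> int:
    return (rank % size_node) % ngpus

ngpus, size_node = 4, 2  # 4 GPUs per node, 2 ranks per node
for rank in range(4):    # global ranks 0-3 spread across 2 nodes
    old = rank % ngpus   # previous logic keyed on the global rank
    new = pick_device(rank, size_node, ngpus)
    print(f"rank {rank}: old -> GPU {old}, new -> GPU {new}")
# Ranks 2 and 3 (the second node) previously selected GPUs 2 and 3 while
# GPUs 0 and 1 on their node sat idle; keyed on the node-local rank, they
# now map to GPUs 0 and 1.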

tests/test_matrixmult.py

Lines changed: 2 additions & 1 deletion

@@ -39,7 +39,8 @@
 test_params = [
     pytest.param(37, 37, 37, "float64", id="f32_37_37_37"),
     pytest.param(50, 30, 40, "float64", id="f64_50_30_40"),
-    pytest.param(22, 20, 16, "complex64", id="c64_22_20_16"),
+    # temporarily removed as it sometimes crashed CI... to be investigated
+    # pytest.param(22, 20, 16, "complex64", id="c64_22_20_16"),
     pytest.param(3, 4, 5, "float32", id="f32_3_4_5"),
     pytest.param(1, 2, 1, "float64", id="f64_1_2_1",),
     pytest.param(2, 1, 3, "float32", id="f32_2_1_3",),
