diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index b858077f..431ddb6e 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -818,21 +818,22 @@ def norm(self, ord: Optional[int] = None): ord : :obj:`int`, optional Order of the norm. """ - norms = np.array([distarray.norm(ord) for distarray in self.distarrays]) + ncp = get_module(self.distarrays[0].engine) + norms = ncp.array([distarray.norm(ord) for distarray in self.distarrays]) ord = 2 if ord is None else ord if ord in ['fro', 'nuc']: raise ValueError(f"norm-{ord} not possible for vectors") elif ord == 0: # Count non-zero then sum reduction - norm = np.sum(norms) - elif ord == np.inf: + norm = ncp.sum(norms) + elif ord == ncp.inf: # Calculate max followed by max reduction - norm = np.max(norms) - elif ord == -np.inf: + norm = ncp.max(norms) + elif ord == -ncp.inf: # Calculate min followed by max reduction - norm = np.min(norms) + norm = ncp.min(norms) else: - norm = np.power(np.sum(np.power(norms, ord)), 1. / ord) + norm = ncp.power(ncp.sum(ncp.power(norms, ord)), 1. / ord) return norm def conj(self):