Skip to content

Commit c695999

Browse files
Implementation of corrcoef
1 parent cb801da commit c695999

File tree

4 files changed

+115
-9
lines changed

4 files changed

+115
-9
lines changed

dpnp/dpnp_iface_statistics.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
3838
"""
3939

40+
import warnings
41+
4042
import dpctl.tensor as dpt
4143
import numpy
4244
from dpctl.tensor._numpy_helper import (
@@ -65,6 +67,7 @@
6567
"amin",
6668
"average",
6769
"bincount",
70+
"corrcoef",
6871
"correlate",
6972
"cov",
7073
"max",
@@ -403,6 +406,69 @@ def correlate(x1, x2, mode="valid"):
403406
return call_origin(numpy.correlate, x1, x2, mode=mode)
404407

405408

409+
def corrcoef(x, y=None, rowvar=True, bias=None, ddof=None, *, dtype=None):
    """
    Return Pearson product-moment correlation coefficients.

    The covariance matrix obtained from :obj:`dpnp.cov` is normalized by the
    square roots of its diagonal entries, and the result is clipped to the
    interval ``[-1, 1]``.

    For full documentation refer to :obj:`numpy.corrcoef`.

    Parameters
    ----------
    x : {dpnp.ndarray, usm_ndarray}
        A 1-D or 2-D array containing multiple variables and observations.
        Each row of `x` represents a variable, and each column a single
        observation of all those variables. Also see `rowvar` below.
    y : {dpnp.ndarray, usm_ndarray}, optional
        An additional set of variables and observations. `y` has the same
        shape as `x`.
    rowvar : {bool}, optional
        If `rowvar` is True (default), then each row represents a
        variable, with observations in the columns. Otherwise, the relationship
        is transposed: each column represents a variable, while the rows
        contain observations.
    bias : {None}, optional
        Has no effect, do not use.
    ddof : {None}, optional
        Has no effect, do not use.
    dtype : data-type, optional
        Data-type of the result. It is forwarded unchanged to :obj:`dpnp.cov`.

    Returns
    -------
    R : dpnp.ndarray
        The correlation coefficient matrix of the variables.

    Warns
    -----
    DeprecationWarning
        If `bias` or `ddof` is anything other than ``None``.

    See Also
    --------
    :obj:`dpnp.cov` : Covariance matrix.

    """
    # Both keywords are accepted only for signature compatibility; any
    # explicit value triggers the deprecation warning.
    if not (bias is None and ddof is None):
        warnings.warn(
            "bias and ddof have no effect and are deprecated",
            DeprecationWarning,
            stacklevel=2,
        )

    corr = dpnp.cov(x, y, rowvar, dtype=dtype)
    try:
        variances = dpnp.diag(corr)
    except ValueError:
        # No diagonal could be extracted (presumably the covariance is a
        # 0-d array for a single variable), so divide the value by itself.
        return corr / corr

    # Scale rows and then columns in place by the standard deviations.
    sigma = dpnp.sqrt(variances.real)
    corr /= sigma[:, None]
    corr /= sigma[None, :]

    # Clamp the real part, and the imaginary part of complex results, to
    # [-1, 1]. For complex arrays this does not bound abs(corr[i, j]) by 1,
    # but doing better would require considerably more work.
    dpnp.clip(corr.real, -1, 1, out=corr.real)
    if dpnp.iscomplexobj(corr):
        dpnp.clip(corr.imag, -1, 1, out=corr.imag)

    return corr
470+
471+
406472
def cov(
407473
m,
408474
y=None,

tests/skipped_tests_gpu.tbl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -320,11 +320,6 @@ tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_bo
320320
tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_goodness_of_fit
321321
tests/third_party/cupy/random_tests/test_sample.py::TestRandomIntegers2::test_goodness_of_fit_2
322322

323-
tests/third_party/cupy/statistics_tests/test_correlation.py::TestCorrcoef::test_corrcoef
324-
tests/third_party/cupy/statistics_tests/test_correlation.py::TestCorrcoef::test_corrcoef_diag_exception
325-
tests/third_party/cupy/statistics_tests/test_correlation.py::TestCorrcoef::test_corrcoef_rowvar
326-
tests/third_party/cupy/statistics_tests/test_correlation.py::TestCorrcoef::test_corrcoef_y
327-
328323
tests/third_party/cupy/statistics_tests/test_order.py::TestOrder::test_percentile_defaults[linear]
329324
tests/third_party/cupy/statistics_tests/test_order.py::TestOrder::test_percentile_defaults[lower]
330325
tests/third_party/cupy/statistics_tests/test_order.py::TestOrder::test_percentile_defaults[higher]

tests/test_statistics.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,51 @@ def test_std_error(self):
573573
dpnp.std(ia, ddof="1")
574574

575575

576+
class TestCorrcoef:
    """Compare :obj:`dpnp.corrcoef` against :obj:`numpy.corrcoef`."""

    @pytest.mark.usefixtures("suppress_divide_invalid_numpy_warnings")
    @pytest.mark.parametrize("dtype", get_all_dtypes())
    @pytest.mark.parametrize("rowvar", [True, False])
    def test_corrcoef(self, dtype, rowvar):
        """Check a 2x3 input for each dtype and both `rowvar` settings."""
        ia = dpnp.array([[0, 1, 2], [3, 4, 0]], dtype=dtype)
        a = dpnp.asnumpy(ia)

        result = dpnp.corrcoef(ia, rowvar=rowvar)
        expected = numpy.corrcoef(a, rowvar=rowvar)

        assert_dtype_allclose(result, expected)

    @pytest.mark.usefixtures("suppress_divide_invalid_numpy_warnings")
    @pytest.mark.parametrize("shape", [(2, 0), (0, 2)])
    def test_corrcoef_empty(self, shape):
        """Check inputs that have a zero-length axis."""
        ia = dpnp.empty(shape, dtype=dpnp.int64)
        a = dpnp.asnumpy(ia)

        expected = numpy.corrcoef(a)
        result = dpnp.corrcoef(ia)
        assert_dtype_allclose(result, expected)

    @pytest.mark.usefixtures("suppress_complex_warning")
    @pytest.mark.parametrize("dt_in", get_all_dtypes(no_bool=True))
    @pytest.mark.parametrize("dt_out", get_float_complex_dtypes())
    def test_corrcoef_dtype(self, dt_in, dt_out):
        """Check that the `dtype` keyword sets the result data type."""
        ia = dpnp.array([[0, 1, 2], [3, 4, 0]], dtype=dt_in)
        a = dpnp.asnumpy(ia)

        result = dpnp.corrcoef(ia, dtype=dt_out)
        expected = numpy.corrcoef(a, dtype=dt_out)
        assert result.dtype == expected.dtype
        assert_allclose(result, expected, rtol=1e-6)

    @pytest.mark.usefixtures("suppress_divide_invalid_numpy_warnings")
    def test_corrcoef_scalar(self):
        """Check a 0-d (scalar) input."""
        ia = dpnp.array(5)
        a = dpnp.asnumpy(ia)

        expected = numpy.corrcoef(a)
        result = dpnp.corrcoef(ia)
        assert_dtype_allclose(result, expected)
619+
620+
576621
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
577622
class TestBincount:
578623
@pytest.mark.parametrize(

tests/third_party/cupy/statistics_tests/test_correlation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,26 @@
1212

1313
class TestCorrcoef(unittest.TestCase):
1414
@testing.for_all_dtypes()
15-
@testing.numpy_cupy_allclose()
15+
@testing.numpy_cupy_allclose(type_check=False)
1616
def test_corrcoef(self, xp, dtype):
1717
a = testing.shaped_arange((2, 3), xp, dtype)
1818
return xp.corrcoef(a)
1919

2020
@testing.for_all_dtypes()
21-
@testing.numpy_cupy_allclose()
21+
@testing.numpy_cupy_allclose(type_check=False)
2222
def test_corrcoef_diag_exception(self, xp, dtype):
2323
a = testing.shaped_arange((1, 3), xp, dtype)
2424
return xp.corrcoef(a)
2525

2626
@testing.for_all_dtypes()
27-
@testing.numpy_cupy_allclose()
27+
@testing.numpy_cupy_allclose(type_check=False)
2828
def test_corrcoef_y(self, xp, dtype):
2929
a = testing.shaped_arange((2, 3), xp, dtype)
3030
y = testing.shaped_arange((2, 3), xp, dtype)
3131
return xp.corrcoef(a, y=y)
3232

3333
@testing.for_all_dtypes()
34-
@testing.numpy_cupy_allclose()
34+
@testing.numpy_cupy_allclose(type_check=False)
3535
def test_corrcoef_rowvar(self, xp, dtype):
3636
a = testing.shaped_arange((2, 3), xp, dtype)
3737
y = testing.shaped_arange((2, 3), xp, dtype)

0 commit comments

Comments
 (0)