Skip to content

Commit d9dae76

Browse files
committed
TST: Add exhaustive test for einsum specialized loops
This hopefully tests things well enough, at least some/most of the paths get triggered and led to errors without the previous float16 typing fixes. I manually confirmed that all paths that were *modified* in the previous commit actually get hit with float16 specialized loops. NOTE: This test may be a bit fragile with floating point roundoff errors, and can in parts be relaxed if this happens.
1 parent a4daaf5 commit d9dae76

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

numpy/core/tests/test_einsum.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import itertools
22

3+
import pytest
4+
35
import numpy as np
46
from numpy.testing import (
57
assert_, assert_equal, assert_array_equal, assert_almost_equal,
@@ -744,6 +746,52 @@ def test_einsum_all_contig_non_contig_output(self):
744746
np.einsum('ij,jk->ik', x, x, out=out)
745747
assert_array_equal(out.base, correct_base)
746748

749+
@pytest.mark.parametrize("dtype",
750+
np.typecodes["AllFloat"] + np.typecodes["AllInteger"])
751+
def test_different_paths(self, dtype):
752+
# Test originally added to cover broken float16 path: gh-20305
753+
# Likely most are covered elsewhere, at least partially.
754+
dtype = np.dtype(dtype)
755+
# Simple test, designed to excersize most specialized code paths,
756+
# note the +0.5 for floats. This makes sure we use a float value
757+
# where the results must be exact.
758+
arr = (np.arange(7) + 0.5).astype(dtype)
759+
scalar = np.array(2, dtype=dtype)
760+
761+
# contig -> scalar:
762+
res = np.einsum('i->', arr)
763+
assert res == arr.sum()
764+
# contig, contig -> contig:
765+
res = np.einsum('i,i->i', arr, arr)
766+
assert_array_equal(res, arr * arr)
767+
# noncontig, noncontig -> contig:
768+
res = np.einsum('i,i->i', arr.repeat(2)[::2], arr.repeat(2)[::2])
769+
assert_array_equal(res, arr * arr)
770+
# contig + contig -> scalar
771+
assert np.einsum('i,i->', arr, arr) == (arr * arr).sum()
772+
# contig + scalar -> contig (with out)
773+
out = np.ones(7, dtype=dtype)
774+
res = np.einsum('i,->i', arr, dtype.type(2), out=out)
775+
assert_array_equal(res, arr * dtype.type(2))
776+
# scalar + contig -> contig (with out)
777+
res = np.einsum(',i->i', scalar, arr)
778+
assert_array_equal(res, arr * dtype.type(2))
779+
# scalar + contig -> scalar
780+
res = np.einsum(',i->', scalar, arr)
781+
# Use einsum to compare to not have difference due to sum round-offs:
782+
assert res == np.einsum('i->', scalar * arr)
783+
# contig + scalar -> scalar
784+
res = np.einsum('i,->', arr, scalar)
785+
# Use einsum to compare to not have difference due to sum round-offs:
786+
assert res == np.einsum('i->', scalar * arr)
787+
# contig + contig + contig -> scalar
788+
arr = np.array([0.5, 0.5, 0.25, 4.5, 3.], dtype=dtype)
789+
res = np.einsum('i,i,i->', arr, arr, arr)
790+
assert_array_equal(res, (arr * arr * arr).sum())
791+
# four arrays:
792+
res = np.einsum('i,i,i,i->', arr, arr, arr, arr)
793+
assert_array_equal(res, (arr * arr * arr * arr).sum())
794+
747795
def test_small_boolean_arrays(self):
748796
# See gh-5946.
749797
# Use array of True embedded in False.

0 commit comments

Comments
 (0)