Skip to content

Commit 491564d

Browse files
authored
Merge pull request numpy#20310 from seberg/float16-einsum-fix
BUG: Fix float16 einsum fastpaths using wrong tempvar
2 parents 1995e2c + d9dae76 commit 491564d

File tree

2 files changed

+63
-15
lines changed

2 files changed

+63
-15
lines changed

numpy/core/src/multiarray/einsum_sumprod.c.src

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -337,13 +337,13 @@ static NPY_GCC_OPT_3 void
337337
/**begin repeat2
338338
* #i = 0, 1, 2, 3#
339339
*/
340-
const @type@ b@i@ = @from@(data[@i@]);
341-
const @type@ c@i@ = @from@(data_out[@i@]);
340+
const @temptype@ b@i@ = @from@(data[@i@]);
341+
const @temptype@ c@i@ = @from@(data_out[@i@]);
342342
/**end repeat2**/
343343
/**begin repeat2
344344
* #i = 0, 1, 2, 3#
345345
*/
346-
const @type@ abc@i@ = scalar * b@i@ + c@i@;
346+
const @temptype@ abc@i@ = scalar * b@i@ + c@i@;
347347
/**end repeat2**/
348348
/**begin repeat2
349349
* #i = 0, 1, 2, 3#
@@ -353,8 +353,8 @@ static NPY_GCC_OPT_3 void
353353
}
354354
#endif // !NPY_DISABLE_OPTIMIZATION
355355
for (; count > 0; --count, ++data, ++data_out) {
356-
const @type@ b = @from@(*data);
357-
const @type@ c = @from@(*data_out);
356+
const @temptype@ b = @from@(*data);
357+
const @temptype@ c = @from@(*data_out);
358358
*data_out = @to@(scalar * b + c);
359359
}
360360
#endif // NPYV check for @type@
@@ -417,14 +417,14 @@ static void
417417
/**begin repeat2
418418
* #i = 0, 1, 2, 3#
419419
*/
420-
const @type@ a@i@ = @from@(data0[@i@]);
421-
const @type@ b@i@ = @from@(data1[@i@]);
422-
const @type@ c@i@ = @from@(data_out[@i@]);
420+
const @temptype@ a@i@ = @from@(data0[@i@]);
421+
const @temptype@ b@i@ = @from@(data1[@i@]);
422+
const @temptype@ c@i@ = @from@(data_out[@i@]);
423423
/**end repeat2**/
424424
/**begin repeat2
425425
* #i = 0, 1, 2, 3#
426426
*/
427-
const @type@ abc@i@ = a@i@ * b@i@ + c@i@;
427+
const @temptype@ abc@i@ = a@i@ * b@i@ + c@i@;
428428
/**end repeat2**/
429429
/**begin repeat2
430430
* #i = 0, 1, 2, 3#
@@ -434,9 +434,9 @@ static void
434434
}
435435
#endif // !NPY_DISABLE_OPTIMIZATION
436436
for (; count > 0; --count, ++data0, ++data1, ++data_out) {
437-
const @type@ a = @from@(*data0);
438-
const @type@ b = @from@(*data1);
439-
const @type@ c = @from@(*data_out);
437+
const @temptype@ a = @from@(*data0);
438+
const @temptype@ b = @from@(*data1);
439+
const @temptype@ c = @from@(*data_out);
440440
*data_out = @to@(a * b + c);
441441
}
442442
#endif // NPYV check for @type@
@@ -521,14 +521,14 @@ static NPY_GCC_OPT_3 void
521521
/**begin repeat2
522522
* #i = 0, 1, 2, 3#
523523
*/
524-
const @type@ ab@i@ = @from@(data0[@i@]) * @from@(data1[@i@]);
524+
const @temptype@ ab@i@ = @from@(data0[@i@]) * @from@(data1[@i@]);
525525
/**end repeat2**/
526526
accum += ab0 + ab1 + ab2 + ab3;
527527
}
528528
#endif // !NPY_DISABLE_OPTIMIZATION
529529
for (; count > 0; --count, ++data0, ++data1) {
530-
const @type@ a = @from@(*data0);
531-
const @type@ b = @from@(*data1);
530+
const @temptype@ a = @from@(*data0);
531+
const @temptype@ b = @from@(*data1);
532532
accum += a * b;
533533
}
534534
#endif // NPYV check for @type@

numpy/core/tests/test_einsum.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import itertools
22

3+
import pytest
4+
35
import numpy as np
46
from numpy.testing import (
57
assert_, assert_equal, assert_array_equal, assert_almost_equal,
@@ -744,6 +746,52 @@ def test_einsum_all_contig_non_contig_output(self):
744746
np.einsum('ij,jk->ik', x, x, out=out)
745747
assert_array_equal(out.base, correct_base)
746748

749+
@pytest.mark.parametrize("dtype",
750+
np.typecodes["AllFloat"] + np.typecodes["AllInteger"])
751+
def test_different_paths(self, dtype):
752+
# Test originally added to cover broken float16 path: gh-20305
753+
# Likely most are covered elsewhere, at least partially.
754+
dtype = np.dtype(dtype)
755+
# Simple test, designed to exercise most specialized code paths,
756+
# note the +0.5 for floats. This makes sure we use a float value
757+
# where the results must be exact.
758+
arr = (np.arange(7) + 0.5).astype(dtype)
759+
scalar = np.array(2, dtype=dtype)
760+
761+
# contig -> scalar:
762+
res = np.einsum('i->', arr)
763+
assert res == arr.sum()
764+
# contig, contig -> contig:
765+
res = np.einsum('i,i->i', arr, arr)
766+
assert_array_equal(res, arr * arr)
767+
# noncontig, noncontig -> contig:
768+
res = np.einsum('i,i->i', arr.repeat(2)[::2], arr.repeat(2)[::2])
769+
assert_array_equal(res, arr * arr)
770+
# contig + contig -> scalar
771+
assert np.einsum('i,i->', arr, arr) == (arr * arr).sum()
772+
# contig + scalar -> contig (with out)
773+
out = np.ones(7, dtype=dtype)
774+
res = np.einsum('i,->i', arr, dtype.type(2), out=out)
775+
assert_array_equal(res, arr * dtype.type(2))
776+
# scalar + contig -> contig (no out argument passed)
777+
res = np.einsum(',i->i', scalar, arr)
778+
assert_array_equal(res, arr * dtype.type(2))
779+
# scalar + contig -> scalar
780+
res = np.einsum(',i->', scalar, arr)
781+
# Compare against einsum itself to avoid differences due to sum round-off:
782+
assert res == np.einsum('i->', scalar * arr)
783+
# contig + scalar -> scalar
784+
res = np.einsum('i,->', arr, scalar)
785+
# Compare against einsum itself to avoid differences due to sum round-off:
786+
assert res == np.einsum('i->', scalar * arr)
787+
# contig + contig + contig -> scalar
788+
arr = np.array([0.5, 0.5, 0.25, 4.5, 3.], dtype=dtype)
789+
res = np.einsum('i,i,i->', arr, arr, arr)
790+
assert_array_equal(res, (arr * arr * arr).sum())
791+
# four arrays:
792+
res = np.einsum('i,i,i,i->', arr, arr, arr, arr)
793+
assert_array_equal(res, (arr * arr * arr * arr).sum())
794+
747795
def test_small_boolean_arrays(self):
748796
# See gh-5946.
749797
# Use array of True embedded in False.

0 commit comments

Comments
 (0)