Commit 181f870

add oneMKL for complex loop of add, subtract, multiply, divide (#102)
1 parent 6bab871 commit 181f870

File tree: 2 files changed (+102, -71 lines)

CHANGELOG.md

Lines changed: 6 additions & 5 deletions
```diff
@@ -4,14 +4,15 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
-## [dev] (MM/DD/YYYY)
+## [dev] - YYYY-MM-DD
 
 ### Added
 * Added mkl implementation for floating point data-types of `exp2`, `log2`, `fabs`, `copysign`, `nextafter`, `fmax`, `fmin` and `remainder` functions [gh-81](https://github.com/IntelPython/mkl_umath/pull/81)
 * Added mkl implementation for complex data-types of `conjugate` and `absolute` functions [gh-86](https://github.com/IntelPython/mkl_umath/pull/86)
 * Enabled support of Python 3.13 [gh-101](https://github.com/IntelPython/mkl_umath/pull/101)
+* Added mkl implementation for complex data-types of `add`, `subtract`, `multiply` and `divide` functions [gh-88](https://github.com/IntelPython/mkl_umath/pull/88)
 
-## [0.2.0] (06/03/2025)
+## [0.2.0] - 2025-06-03
 This release updates `mkl_umath` to be aligned with both numpy-1.26.x and numpy-2.x.x.
 
 ### Added
@@ -26,20 +27,20 @@ This release updates `mkl_umath` to be aligned with both numpy-1.26.x and numpy-
 * Fixed a bug for `mkl_umath.is_patched` function [gh-66](https://github.com/IntelPython/mkl_umath/pull/66)
 
 
-## [0.1.5] (04/09/2025)
+## [0.1.5] - 2025-04-09
 
 ### Fixed
 * Fixed failures to import `mkl_umath` from virtual environment on Linux
 
-## [0.1.4] (04/09/2025)
+## [0.1.4] - 2025-04-09
 
 ### Added
 * Added support for `mkl_umath` out-of-the-box in virtual environments on Windows
 
 ### Fixed
 * Fixed a bug in in-place addition with negative zeros
 
-## [0.1.2] (10/11/2024)
+## [0.1.2] - 2024-10-11
 
 ### Added
 * Added support for building with NumPy 2.0 and older
```

mkl_umath/src/mkl_umath_loops.c.src

Lines changed: 96 additions & 66 deletions
```diff
@@ -318,10 +318,11 @@ pairwise_sum_@TYPE@(char *a, npy_intp n, npy_intp stride)
 void
 mkl_umath_@TYPE@_@kind@(char **args, const npy_intp *dimensions, const npy_intp *steps, void *NPY_UNUSED(func))
 {
+    const int disjoint_or_same1 = DISJOINT_OR_SAME(args[0], args[2], dimensions[0], sizeof(@type@));
+    const int disjoint_or_same2 = DISJOINT_OR_SAME(args[1], args[2], dimensions[0], sizeof(@type@));
+
     if (IS_BINARY_CONT(@type@, @type@)) {
-        if (dimensions[0] > VML_ASM_THRESHOLD &&
-            DISJOINT_OR_SAME(args[0], args[2], dimensions[0], sizeof(@type@)) &&
-            DISJOINT_OR_SAME(args[1], args[2], dimensions[0], sizeof(@type@))) {
+        if (dimensions[0] > VML_ASM_THRESHOLD && disjoint_or_same1 && disjoint_or_same2) {
             CHUNKED_VML_CALL3(v@s@@VML@, dimensions[0], @type@, args[0], args[1], args[2]);
             /* v@s@@VML@(dimensions[0], (@type@*) args[0], (@type@*) args[1], (@type@*) args[2]); */
         }
```
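The real-valued loops now compute the two aliasing checks once per call and reuse them in every branch. `DISJOINT_OR_SAME` is defined elsewhere in the sources; the sketch below only illustrates the kind of check it performs (hypothetical helper name and layout, not the project's macro): the output buffer must either be exactly an input buffer (in-place) or not overlap it at all, since the VML routine reads whole input arrays while writing the output.

```c
/* Hypothetical sketch of an aliasing check in the spirit of DISJOINT_OR_SAME.
 * Names and layout are illustrative, not the project's actual macro. */
#include <stddef.h>

static int
disjoint_or_same(const char *a, const char *b, ptrdiff_t n, size_t elsize)
{
    const char *a_end = a + n * (ptrdiff_t)elsize;
    const char *b_end = b + n * (ptrdiff_t)elsize;
    if (a == b) {
        return 1;                          /* same buffer: in-place is fine */
    }
    return (a_end <= b) || (b_end <= a);   /* fully disjoint ranges */
}
```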
```diff
@@ -371,8 +372,7 @@ mkl_umath_@TYPE@_@kind@(char **args, const npy_intp *dimensions, const npy_intp
         }
     }
     else if (IS_BINARY_CONT_S1(@type@, @type@)) {
-        if (dimensions[0] > VML_ASM_THRESHOLD &&
-            DISJOINT_OR_SAME(args[1], args[2], dimensions[0], sizeof(@type@))) {
+        if (dimensions[0] > VML_ASM_THRESHOLD && disjoint_or_same2) {
             CHUNKED_VML_LINEARFRAC_CALL(v@s@LinearFrac, dimensions[0], @type@, args[1], args[2], @OP@1.0, *(@type@*)args[0], 0.0, 1.0);
             /* v@s@LinearFrac(dimensions[0], (@type@*) args[1], (@type@*) args[1], @OP@1.0, *(@type@*)args[0], 0.0, 1.0, (@type@*) args[2]); */
         }
@@ -412,8 +412,7 @@ mkl_umath_@TYPE@_@kind@(char **args, const npy_intp *dimensions, const npy_intp
         }
     }
     else if (IS_BINARY_CONT_S2(@type@, @type@)) {
-        if (dimensions[0] > VML_ASM_THRESHOLD &&
-            DISJOINT_OR_SAME(args[0], args[2], dimensions[0], sizeof(@type@))) {
+        if (dimensions[0] > VML_ASM_THRESHOLD && disjoint_or_same1) {
             CHUNKED_VML_LINEARFRAC_CALL(v@s@LinearFrac, dimensions[0], @type@, args[0], args[2], 1.0, @OP@(*(@type@*)args[1]), 0.0, 1.0);
             /* v@s@LinearFrac(dimensions[0], (@type@*) args[0], (@type@*) args[0], 1.0, @OP@(*(@type@*)args[1]), 0.0, 1.0, (@type@*) args[2]); */
         }
@@ -478,10 +477,11 @@ mkl_umath_@TYPE@_@kind@(char **args, const npy_intp *dimensions, const npy_intp
 void
 mkl_umath_@TYPE@_multiply(char **args, const npy_intp *dimensions, const npy_intp *steps, void *NPY_UNUSED(func))
 {
+    const int disjoint_or_same1 = DISJOINT_OR_SAME(args[0], args[2], dimensions[0], sizeof(@type@));
+    const int disjoint_or_same2 = DISJOINT_OR_SAME(args[1], args[2], dimensions[0], sizeof(@type@));
+
     if (IS_BINARY_CONT(@type@, @type@)) {
-        if (dimensions[0] > VML_ASM_THRESHOLD &&
-            DISJOINT_OR_SAME(args[0], args[2], dimensions[0], sizeof(@type@)) &&
-            DISJOINT_OR_SAME(args[1], args[2], dimensions[0], sizeof(@type@))) {
+        if (dimensions[0] > VML_ASM_THRESHOLD && disjoint_or_same1 && disjoint_or_same2) {
             CHUNKED_VML_CALL3(v@s@Mul, dimensions[0], @type@, args[0], args[1], args[2]);
             /* v@s@Mul(dimensions[0], (@type@*) args[0], (@type@*) args[1], (@type@*) args[2]); */
         }
@@ -531,8 +531,7 @@ mkl_umath_@TYPE@_multiply(char **args, const npy_intp *dimensions, const npy_int
         }
     }
     else if (IS_BINARY_CONT_S1(@type@, @type@)) {
-        if (dimensions[0] > VML_ASM_THRESHOLD &&
-            DISJOINT_OR_SAME(args[1], args[2], dimensions[0], sizeof(@type@))) {
+        if (dimensions[0] > VML_ASM_THRESHOLD && disjoint_or_same2) {
             CHUNKED_VML_LINEARFRAC_CALL(v@s@LinearFrac, dimensions[0], @type@, args[1], args[2], *(@type@*)args[0], 0.0, 0.0, 1.0);
             /* v@s@LinearFrac(dimensions[0], (@type@*) args[1], (@type@*) args[1], *(@type@*)args[0], 0.0, 0.0, 1.0, (@type@*) args[2]); */
         }
@@ -572,8 +571,7 @@ mkl_umath_@TYPE@_multiply(char **args, const npy_intp *dimensions, const npy_int
         }
     }
     else if (IS_BINARY_CONT_S2(@type@, @type@)) {
-        if (dimensions[0] > VML_ASM_THRESHOLD &&
-            DISJOINT_OR_SAME(args[0], args[2], dimensions[0], sizeof(@type@))) {
+        if (dimensions[0] > VML_ASM_THRESHOLD && disjoint_or_same1) {
             CHUNKED_VML_LINEARFRAC_CALL(v@s@LinearFrac, dimensions[0], @type@, args[0], args[2], *(@type@*)args[1], 0.0, 0.0, 1.0);
             /* v@s@LinearFrac(dimensions[0], (@type@*) args[0], (@type@*) args[0], *(@type@*)args[1], 0.0, 0.0, 1.0, (@type@*) args[2]); */
         }
@@ -630,10 +628,11 @@ mkl_umath_@TYPE@_multiply(char **args, const npy_intp *dimensions, const npy_int
 void
 mkl_umath_@TYPE@_divide(char **args, const npy_intp *dimensions, const npy_intp *steps, void *NPY_UNUSED(func))
 {
+    const int disjoint_or_same1 = DISJOINT_OR_SAME(args[0], args[2], dimensions[0], sizeof(@type@));
+    const int disjoint_or_same2 = DISJOINT_OR_SAME(args[1], args[2], dimensions[0], sizeof(@type@));
+
     if (IS_BINARY_CONT(@type@, @type@)) {
-        if (dimensions[0] > VML_D_THRESHOLD &&
-            DISJOINT_OR_SAME(args[0], args[2], dimensions[0], sizeof(@type@)) &&
-            DISJOINT_OR_SAME(args[1], args[2], dimensions[0], sizeof(@type@))) {
+        if (dimensions[0] > VML_D_THRESHOLD && disjoint_or_same1 && disjoint_or_same2) {
             CHUNKED_VML_CALL3(v@s@Div, dimensions[0], @type@, args[0], args[1], args[2]);
             /* v@s@Div(dimensions[0], (@type@*) args[0], (@type@*) args[1], (@type@*) args[2]); */
         }
```
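Every fast path above funnels into `CHUNKED_VML_CALL3`, which hands contiguous data to a oneMKL vector-math routine (for example `vdAdd(n, a, b, y)` for double-precision addition). The macro's actual definition and chunking policy live elsewhere in the repository; the sketch below only illustrates the general pattern, with a made-up chunk size and helper name:

```c
/* Illustrative chunked dispatch to a VML binary routine. vdAdd is the real
 * oneMKL function; CHUNK_SKETCH and this helper are made up for illustration
 * and are not the project's CHUNKED_VML_CALL3 macro. */
#include <mkl.h>
#include <stddef.h>

#define CHUNK_SKETCH 8192   /* arbitrary chunk length */

static void
chunked_vdadd_sketch(ptrdiff_t n, const double *a, const double *b, double *y)
{
    ptrdiff_t done = 0;
    while (done < n) {
        ptrdiff_t left = n - done;
        MKL_INT len = (MKL_INT)(left > CHUNK_SKETCH ? CHUNK_SKETCH : left);
        vdAdd(len, a + done, b + done, y + done);   /* y[i] = a[i] + b[i] */
        done += len;
    }
}
```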
```diff
@@ -1365,83 +1364,114 @@ pairwise_sum_@TYPE@(@ftype@ *rr, @ftype@ * ri, char * a, npy_intp n, npy_intp st
     }
 }
 
-/* TODO: USE MKL */
 /**begin repeat1
  * #kind = add, subtract#
  * #OP = +, -#
  * #PW = 1, 0#
+ * #VML = Add, Sub#
  */
 void
 mkl_umath_@TYPE@_@kind@(char **args, const npy_intp *dimensions, const npy_intp *steps, void *NPY_UNUSED(func))
 {
-    if (IS_BINARY_REDUCE && @PW@) {
-        npy_intp n = dimensions[0];
-        @ftype@ * or = ((@ftype@ *)args[0]);
-        @ftype@ * oi = ((@ftype@ *)args[0]) + 1;
-        @ftype@ rr, ri;
+    const int contig = IS_BINARY_CONT(@type@, @type@);
+    const int disjoint_or_same1 = DISJOINT_OR_SAME(args[0], args[2], dimensions[0], sizeof(@type@));
+    const int disjoint_or_same2 = DISJOINT_OR_SAME(args[1], args[2], dimensions[0], sizeof(@type@));
+    const int can_vectorize = contig && disjoint_or_same1 && disjoint_or_same2;
 
-        pairwise_sum_@TYPE@(&rr, &ri, args[1], n * 2, steps[1] / 2);
-        *or @OP@= rr;
-        *oi @OP@= ri;
-        return;
+    if (can_vectorize && dimensions[0] > VML_ASM_THRESHOLD) {
+        CHUNKED_VML_CALL3(v@s@@VML@, dimensions[0], @type@, args[0], args[1], args[2]);
+        /* v@s@@VML@(dimensions[0], (@type@*) args[0], (@type@*) args[1], (@type@*) args[2]); */
     }
-    else {
-        BINARY_LOOP {
-            const @ftype@ in1r = ((@ftype@ *)ip1)[0];
-            const @ftype@ in1i = ((@ftype@ *)ip1)[1];
-            const @ftype@ in2r = ((@ftype@ *)ip2)[0];
-            const @ftype@ in2i = ((@ftype@ *)ip2)[1];
-            ((@ftype@ *)op1)[0] = in1r @OP@ in2r;
-            ((@ftype@ *)op1)[1] = in1i @OP@ in2i;
+    else {
+        if (IS_BINARY_REDUCE && @PW@) {
+            npy_intp n = dimensions[0];
+            @ftype@ * or = ((@ftype@ *)args[0]);
+            @ftype@ * oi = ((@ftype@ *)args[0]) + 1;
+            @ftype@ rr, ri;
+
+            pairwise_sum_@TYPE@(&rr, &ri, args[1], n * 2, steps[1] / 2);
+            *or @OP@= rr;
+            *oi @OP@= ri;
+            return;
+        }
+        else {
+            BINARY_LOOP {
+                const @ftype@ in1r = ((@ftype@ *)ip1)[0];
+                const @ftype@ in1i = ((@ftype@ *)ip1)[1];
+                const @ftype@ in2r = ((@ftype@ *)ip2)[0];
+                const @ftype@ in2i = ((@ftype@ *)ip2)[1];
+                ((@ftype@ *)op1)[0] = in1r @OP@ in2r;
+                ((@ftype@ *)op1)[1] = in1i @OP@ in2i;
+            }
         }
     }
 }
 /**end repeat1**/
```
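For the complex loops, the templated fast path above expands (in the double-complex case) to a call into oneMKL's `vzAdd`/`vzSub`, with the existing pairwise-reduction and strided scalar code kept as the fallback. The hand-written expansion below is only a sketch: the threshold value and helper name are invented, the chunked macro is replaced by a direct `vzAdd` call, and the aliasing and reduction handling of the real loop is omitted for brevity.

```c
/* Hand-expanded sketch of the double-complex add fast path. vzAdd is the
 * real oneMKL VM entry point for double-complex addition; the threshold,
 * the helper name, and the direct call (instead of CHUNKED_VML_CALL3) are
 * illustrative simplifications. */
#include <mkl.h>
#include <stddef.h>

#define VML_THRESHOLD_SKETCH 100000   /* made-up value */

static void
cdouble_add_sketch(char **args, const ptrdiff_t *dimensions, const ptrdiff_t *steps)
{
    const ptrdiff_t n = dimensions[0];
    const int contig = (steps[0] == (ptrdiff_t)sizeof(MKL_Complex16) &&
                        steps[1] == (ptrdiff_t)sizeof(MKL_Complex16) &&
                        steps[2] == (ptrdiff_t)sizeof(MKL_Complex16));

    if (contig && n > VML_THRESHOLD_SKETCH) {
        /* whole-array complex add inside MKL: out[i] = in1[i] + in2[i] */
        vzAdd((MKL_INT)n, (MKL_Complex16 *)args[0],
              (MKL_Complex16 *)args[1], (MKL_Complex16 *)args[2]);
        return;
    }
    for (ptrdiff_t i = 0; i < n; i++) {   /* strided scalar fallback */
        const double *in1 = (const double *)(args[0] + i * steps[0]);
        const double *in2 = (const double *)(args[1] + i * steps[1]);
        double *out = (double *)(args[2] + i * steps[2]);
        out[0] = in1[0] + in2[0];   /* real part */
        out[1] = in1[1] + in2[1];   /* imaginary part */
    }
}
```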

```diff
-/* TODO: USE MKL */
 void
 mkl_umath_@TYPE@_multiply(char **args, const npy_intp *dimensions, const npy_intp *steps, void *NPY_UNUSED(func))
 {
-    BINARY_LOOP {
-        const @ftype@ in1r = ((@ftype@ *)ip1)[0];
-        const @ftype@ in1i = ((@ftype@ *)ip1)[1];
-        const @ftype@ in2r = ((@ftype@ *)ip2)[0];
-        const @ftype@ in2i = ((@ftype@ *)ip2)[1];
-        ((@ftype@ *)op1)[0] = in1r*in2r - in1i*in2i;
-        ((@ftype@ *)op1)[1] = in1r*in2i + in1i*in2r;
+    const int contig = IS_BINARY_CONT(@type@, @type@);
+    const int disjoint_or_same1 = DISJOINT_OR_SAME(args[0], args[2], dimensions[0], sizeof(@type@));
+    const int disjoint_or_same2 = DISJOINT_OR_SAME(args[1], args[2], dimensions[0], sizeof(@type@));
+    const int can_vectorize = contig && disjoint_or_same1 && disjoint_or_same2;
+
+    if (can_vectorize && dimensions[0] > VML_ASM_THRESHOLD) {
+        CHUNKED_VML_CALL3(v@s@Mul, dimensions[0], @type@, args[0], args[1], args[2]);
+        /* v@s@Mul(dimensions[0], (@type@*) args[0], (@type@*) args[1], (@type@*) args[2]); */
+    }
+    else {
+        BINARY_LOOP {
+            const @ftype@ in1r = ((@ftype@ *)ip1)[0];
+            const @ftype@ in1i = ((@ftype@ *)ip1)[1];
+            const @ftype@ in2r = ((@ftype@ *)ip2)[0];
+            const @ftype@ in2i = ((@ftype@ *)ip2)[1];
+            ((@ftype@ *)op1)[0] = in1r*in2r - in1i*in2i;
+            ((@ftype@ *)op1)[1] = in1r*in2i + in1i*in2r;
+        }
     }
 }
```
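The multiply fallback keeps the textbook formula (a+bi)(c+di) = (ac − bd) + (ad + bc)i, while the fast path hands the whole array to `vzMul`. A small standalone program like the following can confirm the two agree; it is illustrative only (file name, array size, tolerance, and build line are arbitrary) and is not part of the project's test suite:

```c
/* Minimal standalone check: vzMul versus the scalar complex-multiply
 * formula used in the fallback BINARY_LOOP. Build with an MKL-enabled
 * toolchain, e.g.: icx check_mul.c -qmkl && ./a.out */
#include <mkl.h>
#include <math.h>
#include <stdio.h>

int main(void)
{
    enum { N = 1000 };                     /* arbitrary size */
    static MKL_Complex16 a[N], b[N], y[N];

    for (int i = 0; i < N; i++) {
        a[i].real = sin((double)i);  a[i].imag = cos((double)i);
        b[i].real = 0.5 * i;         b[i].imag = -0.25 * i;
    }

    vzMul(N, a, b, y);                     /* y[i] = a[i] * b[i] */

    double max_err = 0.0;
    for (int i = 0; i < N; i++) {
        const double rr = a[i].real * b[i].real - a[i].imag * b[i].imag;
        const double ri = a[i].real * b[i].imag + a[i].imag * b[i].real;
        const double err = fabs(y[i].real - rr) + fabs(y[i].imag - ri);
        if (err > max_err) max_err = err;
    }
    printf("max |vzMul - scalar| = %g\n", max_err);
    return 0;
}
```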

```diff
-/* TODO: USE MKL */
 void
 mkl_umath_@TYPE@_divide(char **args, const npy_intp *dimensions, const npy_intp *steps, void *NPY_UNUSED(func))
 {
-    BINARY_LOOP {
-        const @ftype@ in1r = ((@ftype@ *)ip1)[0];
-        const @ftype@ in1i = ((@ftype@ *)ip1)[1];
-        const @ftype@ in2r = ((@ftype@ *)ip2)[0];
-        const @ftype@ in2i = ((@ftype@ *)ip2)[1];
-        const @ftype@ in2r_abs = fabs@c@(in2r);
-        const @ftype@ in2i_abs = fabs@c@(in2i);
-        if (in2r_abs >= in2i_abs) {
-            if (in2r_abs == 0 && in2i_abs == 0) {
-                /* divide by zero should yield a complex inf or nan */
-                ((@ftype@ *)op1)[0] = in1r/in2r_abs;
-                ((@ftype@ *)op1)[1] = in1i/in2i_abs;
+    const int contig = IS_BINARY_CONT(@type@, @type@);
+    const int disjoint_or_same1 = DISJOINT_OR_SAME(args[0], args[2], dimensions[0], sizeof(@type@));
+    const int disjoint_or_same2 = DISJOINT_OR_SAME(args[1], args[2], dimensions[0], sizeof(@type@));
+    const int can_vectorize = contig && disjoint_or_same1 && disjoint_or_same2;
+
+    if (can_vectorize && dimensions[0] > VML_D_THRESHOLD) {
+        CHUNKED_VML_CALL3(v@s@Div, dimensions[0], @type@, args[0], args[1], args[2]);
+        /* v@s@Div(dimensions[0], (@type@*) args[0], (@type@*) args[1], (@type@*) args[2]); */
+    }
+    else {
+        BINARY_LOOP {
+            const @ftype@ in1r = ((@ftype@ *)ip1)[0];
+            const @ftype@ in1i = ((@ftype@ *)ip1)[1];
+            const @ftype@ in2r = ((@ftype@ *)ip2)[0];
+            const @ftype@ in2i = ((@ftype@ *)ip2)[1];
+            const @ftype@ in2r_abs = fabs@c@(in2r);
+            const @ftype@ in2i_abs = fabs@c@(in2i);
+            if (in2r_abs >= in2i_abs) {
+                if (in2r_abs == 0 && in2i_abs == 0) {
+                    /* divide by zero should yield a complex inf or nan */
+                    ((@ftype@ *)op1)[0] = in1r/in2r_abs;
+                    ((@ftype@ *)op1)[1] = in1i/in2i_abs;
+                }
+                else {
+                    const @ftype@ rat = in2i/in2r;
+                    const @ftype@ scl = 1.0@c@/(in2r + in2i*rat);
+                    ((@ftype@ *)op1)[0] = (in1r + in1i*rat)*scl;
+                    ((@ftype@ *)op1)[1] = (in1i - in1r*rat)*scl;
+                }
             }
             else {
-                const @ftype@ rat = in2i/in2r;
-                const @ftype@ scl = 1.0@c@/(in2r + in2i*rat);
-                ((@ftype@ *)op1)[0] = (in1r + in1i*rat)*scl;
-                ((@ftype@ *)op1)[1] = (in1i - in1r*rat)*scl;
+                const @ftype@ rat = in2r/in2i;
+                const @ftype@ scl = 1.0@c@/(in2i + in2r*rat);
+                ((@ftype@ *)op1)[0] = (in1r*rat + in1i)*scl;
+                ((@ftype@ *)op1)[1] = (in1i*rat - in1r)*scl;
             }
         }
-        else {
-            const @ftype@ rat = in2r/in2i;
-            const @ftype@ scl = 1.0@c@/(in2i + in2r*rat);
-            ((@ftype@ *)op1)[0] = (in1r*rat + in1i)*scl;
-            ((@ftype@ *)op1)[1] = (in1i*rat - in1r)*scl;
-        }
     }
 }
```
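The scalar division fallback retained above is the usual scaled form: divide through by whichever of the denominator's components has the larger magnitude, so the ratio `rat` stays at or below one in magnitude and the intermediate products cannot overflow even when c + di is huge. Restated without the template machinery (hypothetical names, double precision only), the same arithmetic reads:

```c
/* Non-templated restatement of the fallback division above (scaled,
 * Smith-style). Illustrative only; the templated loop in
 * mkl_umath_loops.c.src is the source of truth. */
#include <math.h>

typedef struct { double re, im; } cplx_sketch;   /* illustrative type */

static cplx_sketch
cdiv_sketch(cplx_sketch x, cplx_sketch y)        /* returns x / y */
{
    cplx_sketch out;
    if (fabs(y.re) >= fabs(y.im)) {
        if (y.re == 0.0 && y.im == 0.0) {
            /* divide by zero should yield a complex inf or nan */
            out.re = x.re / fabs(y.re);
            out.im = x.im / fabs(y.im);
        }
        else {
            const double rat = y.im / y.re;            /* |rat| <= 1 */
            const double scl = 1.0 / (y.re + y.im * rat);
            out.re = (x.re + x.im * rat) * scl;
            out.im = (x.im - x.re * rat) * scl;
        }
    }
    else {
        const double rat = y.re / y.im;                /* |rat| < 1 */
        const double scl = 1.0 / (y.im + y.re * rat);
        out.re = (x.re * rat + x.im) * scl;
        out.im = (x.im * rat - x.re) * scl;
    }
    return out;
}
```

For instance, dividing by 1e300 + 1e300i gives rat = 1.0 and scl = 5e-301, whereas forming the naive denominator c*c + d*d would overflow to infinity.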
