@@ -344,6 +344,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
344344 """
345345 code = """
346346
347+ bool is_float;
347348 int elemsize;
348349 float fbeta;
349350 double dbeta;
@@ -361,11 +362,23 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
361362 %(fail)s;
362363 }
363364
364- if (PyArray_DESCR(%(y)s)->type_num == NPY_DOUBLE) { elemsize = 8; }
365- else if (PyArray_DESCR(%(y)s)->type_num == NPY_FLOAT) { elemsize = 4;}
365+ if ((PyArray_DESCR(%(y)s)->type_num != PyArray_DESCR(%(x)s)->type_num)
366+ || (PyArray_DESCR(%(y)s)->type_num != PyArray_DESCR(%(A)s)->type_num))
367+ {
368+ PyErr_SetString(PyExc_TypeError, "GEMV: dtypes of A, x, y do not match");
369+ %(fail)s;
370+ }
371+ if (PyArray_DESCR(%(y)s)->type_num == NPY_DOUBLE) {
372+ is_float = 0;
373+ elemsize = 8;
374+ }
375+ else if (PyArray_DESCR(%(y)s)->type_num == NPY_FLOAT) {
376+ elemsize = 4;
377+ is_float = 1;
378+ }
366379 else {
367- PyErr_SetString(PyExc_NotImplementedError, "complex Gemv");
380+ PyErr_SetString(PyExc_NotImplementedError, "GEMV: Inputs must be float or double"); // set the exception first, as the other error paths do, so the fail code that follows reports it
368381 %(fail)s;
369382 }
370383
371384 fbeta = dbeta = ((dtype_%(beta)s*)PyArray_DATA(%(beta)s))[0];
@@ -408,37 +421,40 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
408421 Py_INCREF(%(z)s);
409422 }
410423 }
424+
411425 {
412- char TRANS = 'T';
413- char NOTRANS = 'N';
414426 int NA0 = PyArray_DIMS(%(A)s)[0];
415427 int NA1 = PyArray_DIMS(%(A)s)[1];
416- /* This formula is needed in the case where A is actually a row or
417- * column matrix, because BLAS sometimes insists that the strides:
418- * - are not smaller than the number of elements in the array
419- * - are not 0.
420- */
421- int SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1);
422- int SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1);
423- int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
424- int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
425-
426- dtype_%(A)s* A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
427- dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
428- dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
429- // gemv expects pointers to the beginning of memory arrays,
430- // but numpy provides a pointer to the first element,
431- // so when the stride is negative, we need to get the last one.
432- if (Sx < 0)
433- x_data += (NA1 - 1) * Sx;
434- if (Sz < 0)
435- z_data += (NA0 - 1) * Sz;
436428
437429 if (NA0 * NA1)
438430 {
431+ // Non-empty A matrix
432+
433+ /* In the case where A is actually a row or column matrix,
434+ * the strides corresponding to the dummy dimension don't matter,
435+ * but BLAS requires these to be no smaller than the number of elements in the array.
436+ */
437+ int SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : NA1;
438+ int SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : NA0;
439+ int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
440+ int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
441+
442+ dtype_%(A)s* A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
443+ dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
444+ dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
445+
446+ // gemv expects pointers to the beginning of memory arrays,
447+ // but numpy provides a pointer to the first element,
448+ // so when the stride is negative, we need to get the last one.
449+ if (Sx < 0)
450+ x_data += (NA1 - 1) * Sx;
451+ if (Sz < 0)
452+ z_data += (NA0 - 1) * Sz;
453+
439454 if ( ((SA0 < 0) || (SA1 < 0)) && (abs(SA0) == 1 || (abs(SA1) == 1)) )
440455 {
441456 // We can treat the array A as C-or F-contiguous by changing the order of iteration
457+ // printf("GEMV: Iterating in reverse NA0=%%d, NA1=%%d, SA0=%%d, SA1=%%d\\ n", NA0, NA1, SA0, SA1);
442458 if (SA0 < 0){
443459 A_data += (NA0 -1) * SA0; // Jump to first row
444460 SA0 = -SA0; // Iterate over rows in reverse
@@ -452,27 +468,45 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
452468 } else if ((SA0 < 0) || (SA1 < 0) || ((SA0 != 1) && (SA1 != 1)))
453469 {
454470 // Array isn't contiguous, we have to make a copy
455- // - if the copy is too long, maybe call vector/vector dot on
456- // each row instead
457- // printf("GEMV: Making a copy SA0=%%d, SA1=%%d\\ n", SA0, SA1);
471+ // - if the copy is too long, maybe call vector/vector dot on each row instead
472+ // printf("GEMV: Making a copy NA0=%%d, NA1=%%d, SA0=%%d, SA1=%%d\\ n", NA0, NA1, SA0, SA1);
458473 npy_intp dims[2];
459474 dims[0] = NA0;
460475 dims[1] = NA1;
461-
462- PyArrayObject * A_copy = (PyArrayObject *) PyArray_Copy(
463- %(A)s);
476+ PyArrayObject * A_copy = (PyArrayObject *) PyArray_Copy(%(A)s);
464477 if (!A_copy)
465478 %(fail)s
466479 Py_XDECREF(%(A)s);
467480 %(A)s = A_copy;
468- SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : ( NA1 + 1) ;
469- SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : ( NA0 + 1) ;
481+ SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : NA1;
482+ SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : NA0;
470483 A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
471484 }
485+ //else {printf("GEMV: Using the original array NA0=%%d, NA1=%%d, SA0=%%d, SA1=%%d\\ n", NA0, NA1, SA0, SA1);}
472486
473- if (SA0 == 1)
487+ if (NA0 == 1)
488+ {
489+ // A has a single row: the product is a vector-vector dot, and it seems faster to avoid GEMV
490+ dtype_%(alpha)s alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
491+
492+ if (is_float)
493+ {
494+ z_data[0] = (fbeta != 0.f) ? fbeta * z_data[0] : 0.f; // beta == 0 overwrites z: never multiply stale (possibly NaN/Inf) contents
495+ z_data[0] += alpha * sdot_(&NA1, (float*)(A_data), &SA1,
496+ (float*)x_data, &Sx);
497+ }
498+ else
499+ {
500+ z_data[0] = (dbeta != 0.) ? dbeta * z_data[0] : 0.; // same beta == 0 guard as the float path
501+ z_data[0] += alpha * ddot_(&NA1, (double*)(A_data), &SA1,
502+ (double*)x_data, &Sx);
503+ }
504+ }
505+ else if (SA0 == 1)
474506 {
475- if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
507+ // F-contiguous
508+ char NOTRANS = 'N';
509+ if (is_float)
476510 {
477511 float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
478512 sgemv_(&NOTRANS, &NA0, &NA1,
@@ -482,7 +516,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
482516 &fbeta,
483517 (float*)z_data, &Sz);
484518 }
485- else if (PyArray_DESCR(%(A)s)->type_num == NPY_DOUBLE)
519+ else
486520 {
487521 double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
488522 dgemv_(&NOTRANS, &NA0, &NA1,
@@ -492,97 +526,39 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
492526 &dbeta,
493527 (double*)z_data, &Sz);
494528 }
495- else
496- {
497- PyErr_SetString(PyExc_AssertionError,
498- "neither float nor double dtype");
499- %(fail)s
500- }
501529 }
502530 else if (SA1 == 1)
503531 {
504- if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
532+ // C-contiguous
533+ char TRANS = 'T';
534+ if (is_float)
505535 {
506536 float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
507-
508- // Check for vector-vector dot (NA0 == 1). The code may work
509- // for SA1 != 1 as well, but has not been tested for this case,
510- // so SA1 == 1 is required for safety.
511- if (NA0 == 1 && SA1 == 1)
512- {
513- if (fbeta != 0.f) {
514- z_data[0] = fbeta*z_data[0];
515- } else {
516- z_data[0] = 0.f;
517- }
518- z_data[0] += alpha*sdot_(&NA1,
519- (float*)(A_data), &SA1,
520- (float*)x_data, &Sx);
521- }
522- else
523- {
524- sgemv_(&TRANS, &NA1, &NA0,
525- &alpha,
526- (float*)(A_data), &SA0,
527- (float*)x_data, &Sx,
528- &fbeta,
529- (float*)z_data, &Sz);
530- }
531- }
532- else if (PyArray_DESCR(%(A)s)->type_num == NPY_DOUBLE)
533- {
534- double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
535-
536- // Check for vector-vector dot (NA0 == 1). The code may work
537- // for SA1 != 1 as well, but has not been tested for this case,
538- // so SA1 == 1 is required for safety.
539- if (NA0 == 1 && SA1 == 1)
540- {
541- if (dbeta != 0.) {
542- z_data[0] = dbeta*z_data[0];
543- } else {
544- z_data[0] = 0.;
545- }
546- z_data[0] += alpha*ddot_(&NA1,
547- (double*)(A_data), &SA1,
548- (double*)x_data, &Sx);
549- }
550- else
551- {
552- dgemv_(&TRANS, &NA1, &NA0,
553- &alpha,
554- (double*)(A_data), &SA0,
555- (double*)x_data, &Sx,
556- &dbeta,
557- (double*)z_data, &Sz);
558- }
537+ sgemv_(&TRANS, &NA1, &NA0,
538+ &alpha,
539+ (float*)(A_data), &SA0,
540+ (float*)x_data, &Sx,
541+ &fbeta,
542+ (float*)z_data, &Sz);
559543 }
560544 else
561545 {
562- PyErr_SetString(PyExc_AssertionError,
563- "neither float nor double dtype");
564- %(fail)s
546+ double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
547+ dgemv_(&TRANS, &NA1, &NA0,
548+ &alpha,
549+ (double*)(A_data), &SA0,
550+ (double*)x_data, &Sx,
551+ &dbeta,
552+ (double*)z_data, &Sz);
565553 }
566554 }
567555 else
568556 {
569557 PyErr_SetString(PyExc_AssertionError,
570- "xx is a double-strided matrix, and should have been "
571- "copied into a memory-contiguous one.");
558+ "A is neither C nor F-contiguous, it should have been copied into a memory-contiguous array;");
572559 %(fail)s
573560 }
574561 }
575- else if (dbeta != 1.0)
576- {
577- // the matrix has at least one dim of length 0
578- // so we do this loop, which either iterates over 0 elements
579- // or else it does the right thing for length-0 A.
580- dtype_%(z)s * zptr = (dtype_%(z)s*)(PyArray_DATA(%(z)s));
581- for (int i = 0; i < NA0; ++i)
582- {
583- zptr[i * Sz] = (dbeta == 0.0 ? 0.0 : zptr[i * Sz] * dbeta);
584- }
585- }
586562 }
587563 """
588564 return code % locals ()
@@ -613,7 +589,7 @@ def c_code(self, node, name, inp, out, sub):
613589 return code
614590
615591 def c_code_cache_version (self ):
616- return (15 , blas_header_version (), check_force_gemv_init ())
592+ return (16 , blas_header_version (), check_force_gemv_init ()) # leading integer is bumped here together with the generated C code; presumably it keys the compiled-code cache -- TODO confirm
617593
618594
619595cgemv_inplace = CGemv (inplace = True )
0 commit comments