@@ -423,6 +423,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
423423 int Sz = PyArray_STRIDES(%(z)s)[0] / elemsize;
424424 int Sx = PyArray_STRIDES(%(x)s)[0] / elemsize;
425425
426+ dtype_%(A)s* A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
426427 dtype_%(x)s* x_data = (dtype_%(x)s*) PyArray_DATA(%(x)s);
427428 dtype_%(z)s* z_data = (dtype_%(z)s*) PyArray_DATA(%(z)s);
428429 // gemv expects pointers to the beginning of memory arrays,
@@ -435,17 +436,28 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
435436
436437 if (NA0 * NA1)
437438 {
438- // If A is neither C- nor F-contiguous, we make a copy.
439- // TODO:
440- // - if one stride is equal to "- elemsize", we can still call
441- // gemv on reversed matrix and vectors
442- // - if the copy is too long, maybe call vector/vector dot on
443- // each row instead
444- if ((PyArray_STRIDES(%(A)s)[0] < 0)
445- || (PyArray_STRIDES(%(A)s)[1] < 0)
446- || ((PyArray_STRIDES(%(A)s)[0] != elemsize)
447- && (PyArray_STRIDES(%(A)s)[1] != elemsize)))
439+ if (((SA0 < 0) || (SA1 < 0))
440+ && (abs(SA0) == 1 || (abs(SA1) == 1))
441+ )
448442 {
443+ // At least one stride of A is negative and one has magnitude 1: treat A as C- or F-contiguous by reversing the iteration order instead of copying.
444+
445+ if (SA0 < 0){
446+ A_data += (NA0 -1) * SA0; // SA0 < 0: move to row NA0-1, which lies lowest in memory
447+ SA0 = -SA0; // From here on, treat the row stride as positive
448+ Sz = -Sz; // Negate z's stride so the output vector is walked in reverse
449+ }
450+ if (SA1 < 0){
451+ A_data += (NA1 -1) * SA1; // SA1 < 0: move to column NA1-1, which lies lowest in memory
452+ SA1 = -SA1; // From here on, treat the column stride as positive
453+ Sx = -Sx; // Negate x's stride so the input vector is walked in reverse
454+ }
455+
456+ } else if ((SA0 < 0) || (SA1 < 0) || ((SA0 != 1) && (SA1 != 1))) {
457+ // The sign flip above did not apply and A still has a negative stride or no unit stride: make a copy.
458+ // TODO: if the copy is too expensive, maybe call vector/vector dot on each row instead.
459+ // NOTE(review): the printf just below looks like leftover debug output -- confirm it is intended.
460+ printf("GEMV: Making a copy SA0=%%d, SA1=%%d\\ n", SA0, SA1);
449461 npy_intp dims[2];
450462 dims[0] = NA0;
451463 dims[1] = NA1;
@@ -458,16 +470,17 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
458470 %(A)s = A_copy;
459471 SA0 = (NA0 > 1) ? (PyArray_STRIDES(%(A)s)[0] / elemsize) : (NA1 + 1);
460472 SA1 = (NA1 > 1) ? (PyArray_STRIDES(%(A)s)[1] / elemsize) : (NA0 + 1);
473+ A_data = (dtype_%(A)s*) PyArray_DATA(%(A)s);
461474 }
462475
463- if (PyArray_STRIDES(%(A)s)[0] == elemsize )
476+ if (SA0 == 1 )
464477 {
465478 if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
466479 {
467480 float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
468481 sgemv_(&NOTRANS, &NA0, &NA1,
469482 &alpha,
470- (float*)(PyArray_DATA(%(A)s) ), &SA1,
483+ (float*)(A_data ), &SA1,
471484 (float*)x_data, &Sx,
472485 &fbeta,
473486 (float*)z_data, &Sz);
@@ -477,7 +490,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
477490 double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0];
478491 dgemv_(&NOTRANS, &NA0, &NA1,
479492 &alpha,
480- (double*)(PyArray_DATA(%(A)s) ), &SA1,
493+ (double*)(A_data ), &SA1,
481494 (double*)x_data, &Sx,
482495 &dbeta,
483496 (double*)z_data, &Sz);
@@ -489,7 +502,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
489502 %(fail)s
490503 }
491504 }
492- else if (PyArray_STRIDES(%(A)s)[1] == elemsize )
505+ else if (SA1 == 1 )
493506 {
494507 if (PyArray_DESCR(%(A)s)->type_num == NPY_FLOAT)
495508 {
@@ -506,14 +519,14 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
506519 z_data[0] = 0.f;
507520 }
508521 z_data[0] += alpha*sdot_(&NA1,
509- (float*)(PyArray_DATA(%(A)s) ), &SA1,
522+ (float*)(A_data ), &SA1,
510523 (float*)x_data, &Sx);
511524 }
512525 else
513526 {
514527 sgemv_(&TRANS, &NA1, &NA0,
515528 &alpha,
516- (float*)(PyArray_DATA(%(A)s) ), &SA0,
529+ (float*)(A_data ), &SA0,
517530 (float*)x_data, &Sx,
518531 &fbeta,
519532 (float*)z_data, &Sz);
@@ -534,14 +547,14 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
534547 z_data[0] = 0.;
535548 }
536549 z_data[0] += alpha*ddot_(&NA1,
537- (double*)(PyArray_DATA(%(A)s) ), &SA1,
550+ (double*)(A_data ), &SA1,
538551 (double*)x_data, &Sx);
539552 }
540553 else
541554 {
542555 dgemv_(&TRANS, &NA1, &NA0,
543556 &alpha,
544- (double*)(PyArray_DATA(%(A)s) ), &SA0,
557+ (double*)(A_data ), &SA0,
545558 (double*)x_data, &Sx,
546559 &dbeta,
547560 (double*)z_data, &Sz);
@@ -603,7 +616,7 @@ def c_code(self, node, name, inp, out, sub):
603616 return code
604617
605618 def c_code_cache_version (self ):
606- return (14 , blas_header_version (), check_force_gemv_init ())
619+ return (15 , blas_header_version (), check_force_gemv_init ())
607620
608621
609622cgemv_inplace = CGemv (inplace = True )
0 commit comments