1+ #include " cublas_v2.h"
2+ #include < cassert>
3+ #include < cuda_runtime.h>
4+ #include < math.h>
5+ #include < stdio.h>
6+ #include < stdlib.h>
7+ #include < vector>
8+
9+ int main () {
10+ constexpr size_t ROWS = 6 ;
11+ constexpr size_t COLUMNS = 5 ;
12+ constexpr float ALPHA = 1 .0f ;
13+ constexpr float BETA = 0 .0f ;
14+
15+ cublasHandle_t handle;
16+
17+ std::vector<float > hostA (ROWS * COLUMNS);
18+ std::vector<float > hostB (COLUMNS);
19+ std::vector<float > hostC (ROWS);
20+
21+ int index = 11 ;
22+ for (size_t i = 0 ; i < COLUMNS; i++) {
23+ for (size_t j = 0 ; j < ROWS; j++) {
24+ hostA[(i * ROWS) + j] = static_cast <float >(index++);
25+ }
26+ }
27+
28+ std::fill (std::begin (hostB), std::end (hostB), 1 .0f );
29+
30+ // hostA:
31+ // [11, 17, 23, 29, 35]
32+ // [12, 18, 24, 30, 36]
33+ // [13, 19, 25, 31, 37]
34+ // [14, 20, 26, 32, 38]
35+ // [15, 21, 27, 33, 39]
36+ // [16, 22, 28, 34, 40]
37+
38+ // hostB:
39+ // [1, 1, 1, 1, 1]
40+
41+ // hostC:
42+ // [0, 0, 0, 0, 0, 0]
43+
44+ float *deviceA = nullptr ;
45+ float *deviceB = nullptr ;
46+ float *deviceC = nullptr ;
47+
48+ cudaMalloc ((void **)&deviceA, ROWS * COLUMNS * sizeof (float ));
49+ cudaMalloc ((void **)&deviceB, COLUMNS * sizeof (float ));
50+ cudaMalloc ((void **)&deviceC, ROWS * sizeof (float ));
51+
52+ cublasCreate (&handle);
53+
54+ cublasSetMatrix (ROWS, COLUMNS, sizeof (float ), hostA.data (), ROWS, deviceA,
55+ ROWS);
56+ cublasSetVector (COLUMNS, sizeof (float ), hostB.data (), 1 , deviceB, 1 );
57+ cublasSetVector (ROWS, sizeof (float ), hostC.data (), 1 , deviceC, 1 );
58+ cublasSgemv (handle, CUBLAS_OP_N, ROWS, COLUMNS, &ALPHA, deviceA, ROWS,
59+ deviceB, 1 , &BETA, deviceC, 1 );
60+ cublasGetVector (ROWS, sizeof (float ), deviceC, 1 , hostC.data (), 1 );
61+
62+ cudaFree (deviceA);
63+ cudaFree (deviceB);
64+ cudaFree (deviceC);
65+
66+ assert (hostC[0 ] == 115 ); // [11, 17, 23, 29, 35] [1]
67+ assert (hostC[1 ] == 120 ); // [12, 18, 24, 30, 36] [1]
68+ assert (hostC[2 ] == 125 ); // [13, 19, 25, 31, 37] * [1]
69+ assert (hostC[3 ] == 130 ); // [14, 20, 26, 32, 38] [1]
70+ assert (hostC[4 ] == 135 ); // [15, 21, 27, 33, 39] [1]
71+ assert (hostC[5 ] == 140 ); // [16, 22, 28, 34, 40]
72+
73+ cublasDestroy (handle);
74+ }
0 commit comments