Skip to content

Commit 43ab4e0

Browse files
WestonJB authored and mgates3 committed
Added trmv, trsv
1 parent e458083 commit 43ab4e0

14 files changed

+1392
-0
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,9 @@ add_library(
380380
src/device_rotm.cc
381381
src/device_rotmg.cc
382382
src/device_trmm.cc
383+
src/device_trmv.cc
383384
src/device_trsm.cc
385+
src/device_trsv.cc
384386
src/device_utils.cc
385387
src/cublas_wrappers.cc
386388
src/rocblas_wrappers.cc

include/blas/device_blas.hh

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,88 @@ void symv(
498498
std::complex<double>* y, int64_t incy,
499499
blas::Queue& queue );
500500

501+
//------------------------------------------------------------------------------
502+
/// trmv, float overload (declaration only; no body is visible here).
/// Takes layout, uplo, trans, and diag selectors, the size n, a read-only
/// matrix pointer A with lda, a non-const vector pointer x with incx, and the
/// blas::Queue to use.
/// NOTE(review): presumably the BLAS triangular matrix-vector multiply run on
/// the queue's device, with lda/incx being leading dimension and stride --
/// confirm against the definition.
void trmv(
503+
blas::Layout layout,
504+
blas::Uplo uplo,
505+
blas::Op trans,
506+
blas::Diag diag,
507+
int64_t n,
508+
float const* A, int64_t lda,
509+
float* x, int64_t incx,
510+
blas::Queue& queue );
511+
512+
/// trmv, double overload (declaration only; no body is visible here).
/// Takes layout, uplo, trans, and diag selectors, the size n, a read-only
/// matrix pointer A with lda, a non-const vector pointer x with incx, and the
/// blas::Queue to use.
/// NOTE(review): presumably the BLAS triangular matrix-vector multiply run on
/// the queue's device, with lda/incx being leading dimension and stride --
/// confirm against the definition.
void trmv(
513+
blas::Layout layout,
514+
blas::Uplo uplo,
515+
blas::Op trans,
516+
blas::Diag diag,
517+
int64_t n,
518+
double const* A, int64_t lda,
519+
double* x, int64_t incx,
520+
blas::Queue& queue );
521+
522+
/// trmv, std::complex<float> overload (declaration only; no body is visible
/// here).
/// Takes layout, uplo, trans, and diag selectors, the size n, a read-only
/// matrix pointer A with lda, a non-const vector pointer x with incx, and the
/// blas::Queue to use.
/// NOTE(review): presumably the BLAS triangular matrix-vector multiply run on
/// the queue's device, with lda/incx being leading dimension and stride --
/// confirm against the definition.
void trmv(
523+
blas::Layout layout,
524+
blas::Uplo uplo,
525+
blas::Op trans,
526+
blas::Diag diag,
527+
int64_t n,
528+
std::complex<float> const* A, int64_t lda,
529+
std::complex<float>* x, int64_t incx,
530+
blas::Queue& queue );
531+
532+
/// trmv, std::complex<double> overload (declaration only; no body is visible
/// here).
/// Takes layout, uplo, trans, and diag selectors, the size n, a read-only
/// matrix pointer A with lda, a non-const vector pointer x with incx, and the
/// blas::Queue to use.
/// NOTE(review): presumably the BLAS triangular matrix-vector multiply run on
/// the queue's device, with lda/incx being leading dimension and stride --
/// confirm against the definition.
void trmv(
533+
blas::Layout layout,
534+
blas::Uplo uplo,
535+
blas::Op trans,
536+
blas::Diag diag,
537+
int64_t n,
538+
std::complex<double> const* A, int64_t lda,
539+
std::complex<double>* x, int64_t incx,
540+
blas::Queue& queue );
541+
542+
//------------------------------------------------------------------------------
543+
/// trsv, float overload (declaration only; no body is visible here).
/// Takes layout, uplo, trans, and diag selectors, the size n, a read-only
/// matrix pointer A with lda, a non-const vector pointer x with incx, and the
/// blas::Queue to use.
/// NOTE(review): presumably the BLAS triangular solve with a single
/// right-hand-side vector, overwriting x, run on the queue's device -- confirm
/// against the definition.
void trsv(
544+
blas::Layout layout,
545+
blas::Uplo uplo,
546+
blas::Op trans,
547+
blas::Diag diag,
548+
int64_t n,
549+
float const* A, int64_t lda,
550+
float* x, int64_t incx,
551+
blas::Queue& queue );
552+
553+
/// trsv, double overload (declaration only; no body is visible here).
/// Takes layout, uplo, trans, and diag selectors, the size n, a read-only
/// matrix pointer A with lda, a non-const vector pointer x with incx, and the
/// blas::Queue to use.
/// NOTE(review): presumably the BLAS triangular solve with a single
/// right-hand-side vector, overwriting x, run on the queue's device -- confirm
/// against the definition.
void trsv(
554+
blas::Layout layout,
555+
blas::Uplo uplo,
556+
blas::Op trans,
557+
blas::Diag diag,
558+
int64_t n,
559+
double const* A, int64_t lda,
560+
double* x, int64_t incx,
561+
blas::Queue& queue );
562+
563+
/// trsv, std::complex<float> overload (declaration only; no body is visible
/// here).
/// Takes layout, uplo, trans, and diag selectors, the size n, a read-only
/// matrix pointer A with lda, a non-const vector pointer x with incx, and the
/// blas::Queue to use.
/// NOTE(review): presumably the BLAS triangular solve with a single
/// right-hand-side vector, overwriting x, run on the queue's device -- confirm
/// against the definition.
void trsv(
564+
blas::Layout layout,
565+
blas::Uplo uplo,
566+
blas::Op trans,
567+
blas::Diag diag,
568+
int64_t n,
569+
std::complex<float> const* A, int64_t lda,
570+
std::complex<float>* x, int64_t incx,
571+
blas::Queue& queue );
572+
573+
/// trsv, std::complex<double> overload (declaration only; no body is visible
/// here).
/// Takes layout, uplo, trans, and diag selectors, the size n, a read-only
/// matrix pointer A with lda, a non-const vector pointer x with incx, and the
/// blas::Queue to use.
/// NOTE(review): presumably the BLAS triangular solve with a single
/// right-hand-side vector, overwriting x, run on the queue's device -- confirm
/// against the definition.
void trsv(
574+
blas::Layout layout,
575+
blas::Uplo uplo,
576+
blas::Op trans,
577+
blas::Diag diag,
578+
int64_t n,
579+
std::complex<double> const* A, int64_t lda,
580+
std::complex<double>* x, int64_t incx,
581+
blas::Queue& queue );
582+
501583
//==============================================================================
502584
// Level 3 BLAS
503585

src/cublas_wrappers.cc

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1192,6 +1192,162 @@ void symv(
11921192
(cuDoubleComplex*) dy, incdy ) );
11931193
}
11941194

1195+
//------------------------------------------------------------------------------
1196+
// trmv
1197+
//------------------------------------------------------------------------------
1198+
/// trmv wrapper, float: forwards its arguments to cublasStrmv using the
/// handle obtained from queue.handle(). uplo, trans, and diag are translated
/// with uplo2cublas, op2cublas, and diag2cublas; n, dA/ldda, and dx/incdx are
/// passed through unchanged. The cuBLAS call is wrapped in blas_dev_call
/// (presumably status checking -- its definition is not visible here).
/// NOTE(review): n, ldda, and incdx are int64_t here; if cublasStrmv takes
/// int they narrow implicitly at the call -- confirm callers range-check.
void trmv(
1199+
blas::Uplo uplo,
1200+
blas::Op trans,
1201+
blas::Diag diag,
1202+
int64_t n,
1203+
float const* dA, int64_t ldda,
1204+
float* dx, int64_t incdx,
1205+
blas::Queue& queue )
1206+
{
1207+
blas_dev_call(
1208+
cublasStrmv(
1209+
queue.handle(),
1210+
uplo2cublas( uplo ), op2cublas( trans ), diag2cublas( diag ),
1211+
n,
1212+
dA, ldda,
1213+
dx, incdx ) );
1214+
}
1215+
1216+
//------------------------------------------------------------------------------
1217+
/// trmv wrapper, double: forwards its arguments to cublasDtrmv using the
/// handle obtained from queue.handle(). uplo, trans, and diag are translated
/// with uplo2cublas, op2cublas, and diag2cublas; n, dA/ldda, and dx/incdx are
/// passed through unchanged. The cuBLAS call is wrapped in blas_dev_call
/// (presumably status checking -- its definition is not visible here).
/// NOTE(review): n, ldda, and incdx are int64_t here; if cublasDtrmv takes
/// int they narrow implicitly at the call -- confirm callers range-check.
void trmv(
1218+
blas::Uplo uplo,
1219+
blas::Op trans,
1220+
blas::Diag diag,
1221+
int64_t n,
1222+
double const* dA, int64_t ldda,
1223+
double* dx, int64_t incdx,
1224+
blas::Queue& queue )
1225+
{
1226+
blas_dev_call(
1227+
cublasDtrmv(
1228+
queue.handle(),
1229+
uplo2cublas( uplo ), op2cublas( trans ), diag2cublas( diag ),
1230+
n,
1231+
dA, ldda,
1232+
dx, incdx ) );
1233+
}
1234+
1235+
//------------------------------------------------------------------------------
1236+
/// trmv wrapper, std::complex<float>: forwards its arguments to cublasCtrmv
/// using the handle obtained from queue.handle(). uplo, trans, and diag are
/// translated with uplo2cublas, op2cublas, and diag2cublas; dA and dx are
/// cast to cuComplex* for the call. The cuBLAS call is wrapped in
/// blas_dev_call (presumably status checking -- definition not visible here).
/// NOTE(review): n, ldda, and incdx are int64_t here; if cublasCtrmv takes
/// int they narrow implicitly at the call -- confirm callers range-check.
void trmv(
1237+
blas::Uplo uplo,
1238+
blas::Op trans,
1239+
blas::Diag diag,
1240+
int64_t n,
1241+
std::complex<float> const* dA, int64_t ldda,
1242+
std::complex<float>* dx, int64_t incdx,
1243+
blas::Queue& queue )
1244+
{
1245+
blas_dev_call(
1246+
cublasCtrmv(
1247+
queue.handle(),
1248+
uplo2cublas( uplo ), op2cublas( trans ), diag2cublas( diag ),
1249+
n,
1250+
// NOTE(review): this C-style cast also removes the const qualifier that
// the parameter list puts on dA; a cast to a const pointer would keep it.
(cuComplex*) dA, ldda,
1251+
(cuComplex*) dx, incdx ) );
1252+
}
1253+
1254+
//------------------------------------------------------------------------------
1255+
/// trmv wrapper, std::complex<double>: forwards its arguments to cublasZtrmv
/// using the handle obtained from queue.handle(). uplo, trans, and diag are
/// translated with uplo2cublas, op2cublas, and diag2cublas; dA and dx are
/// cast to cuDoubleComplex* for the call. The cuBLAS call is wrapped in
/// blas_dev_call (presumably status checking -- definition not visible here).
/// NOTE(review): n, ldda, and incdx are int64_t here; if cublasZtrmv takes
/// int they narrow implicitly at the call -- confirm callers range-check.
void trmv(
1256+
blas::Uplo uplo,
1257+
blas::Op trans,
1258+
blas::Diag diag,
1259+
int64_t n,
1260+
std::complex<double> const* dA, int64_t ldda,
1261+
std::complex<double>* dx, int64_t incdx,
1262+
blas::Queue& queue )
1263+
{
1264+
blas_dev_call(
1265+
cublasZtrmv(
1266+
queue.handle(),
1267+
uplo2cublas( uplo ), op2cublas( trans ), diag2cublas( diag ),
1268+
n,
1269+
// NOTE(review): this C-style cast also removes the const qualifier that
// the parameter list puts on dA; a cast to a const pointer would keep it.
(cuDoubleComplex*) dA, ldda,
1270+
(cuDoubleComplex*) dx, incdx ) );
1271+
}
1272+
1273+
//------------------------------------------------------------------------------
1274+
// trsv
1275+
//------------------------------------------------------------------------------
1276+
/// trsv wrapper, float: forwards its arguments to cublasStrsv using the
/// handle obtained from queue.handle(). uplo, trans, and diag are translated
/// with uplo2cublas, op2cublas, and diag2cublas; n, dA/ldda, and dx/incdx are
/// passed through unchanged. The cuBLAS call is wrapped in blas_dev_call
/// (presumably status checking -- its definition is not visible here).
/// NOTE(review): n, ldda, and incdx are int64_t here; if cublasStrsv takes
/// int they narrow implicitly at the call -- confirm callers range-check.
void trsv(
1277+
blas::Uplo uplo,
1278+
blas::Op trans,
1279+
blas::Diag diag,
1280+
int64_t n,
1281+
float const* dA, int64_t ldda,
1282+
float* dx, int64_t incdx,
1283+
blas::Queue& queue )
1284+
{
1285+
blas_dev_call(
1286+
cublasStrsv(
1287+
queue.handle(),
1288+
uplo2cublas( uplo ), op2cublas( trans ), diag2cublas( diag ),
1289+
n,
1290+
dA, ldda,
1291+
dx, incdx ) );
1292+
}
1293+
1294+
//------------------------------------------------------------------------------
1295+
/// trsv wrapper, double: forwards its arguments to cublasDtrsv using the
/// handle obtained from queue.handle(). uplo, trans, and diag are translated
/// with uplo2cublas, op2cublas, and diag2cublas; n, dA/ldda, and dx/incdx are
/// passed through unchanged. The cuBLAS call is wrapped in blas_dev_call
/// (presumably status checking -- its definition is not visible here).
/// NOTE(review): n, ldda, and incdx are int64_t here; if cublasDtrsv takes
/// int they narrow implicitly at the call -- confirm callers range-check.
void trsv(
1296+
blas::Uplo uplo,
1297+
blas::Op trans,
1298+
blas::Diag diag,
1299+
int64_t n,
1300+
double const* dA, int64_t ldda,
1301+
double* dx, int64_t incdx,
1302+
blas::Queue& queue )
1303+
{
1304+
blas_dev_call(
1305+
cublasDtrsv(
1306+
queue.handle(),
1307+
uplo2cublas( uplo ), op2cublas( trans ), diag2cublas( diag ),
1308+
n,
1309+
dA, ldda,
1310+
dx, incdx ) );
1311+
}
1312+
1313+
//------------------------------------------------------------------------------
1314+
/// trsv wrapper, std::complex<float>: forwards its arguments to cublasCtrsv
/// using the handle obtained from queue.handle(). uplo, trans, and diag are
/// translated with uplo2cublas, op2cublas, and diag2cublas; dA and dx are
/// cast to cuComplex* for the call. The cuBLAS call is wrapped in
/// blas_dev_call (presumably status checking -- definition not visible here).
/// NOTE(review): n, ldda, and incdx are int64_t here; if cublasCtrsv takes
/// int they narrow implicitly at the call -- confirm callers range-check.
void trsv(
1315+
blas::Uplo uplo,
1316+
blas::Op trans,
1317+
blas::Diag diag,
1318+
int64_t n,
1319+
std::complex<float> const* dA, int64_t ldda,
1320+
std::complex<float>* dx, int64_t incdx,
1321+
blas::Queue& queue )
1322+
{
1323+
blas_dev_call(
1324+
cublasCtrsv(
1325+
queue.handle(),
1326+
uplo2cublas( uplo ), op2cublas( trans ), diag2cublas( diag ),
1327+
n,
1328+
// NOTE(review): this C-style cast also removes the const qualifier that
// the parameter list puts on dA; a cast to a const pointer would keep it.
(cuComplex*) dA, ldda,
1329+
(cuComplex*) dx, incdx ) );
1330+
}
1331+
1332+
//------------------------------------------------------------------------------
1333+
/// trsv wrapper, std::complex<double>: forwards its arguments to cublasZtrsv
/// using the handle obtained from queue.handle(). uplo, trans, and diag are
/// translated with uplo2cublas, op2cublas, and diag2cublas; dA and dx are
/// cast to cuDoubleComplex* for the call. The cuBLAS call is wrapped in
/// blas_dev_call (presumably status checking -- definition not visible here).
/// NOTE(review): n, ldda, and incdx are int64_t here; if cublasZtrsv takes
/// int they narrow implicitly at the call -- confirm callers range-check.
void trsv(
1334+
blas::Uplo uplo,
1335+
blas::Op trans,
1336+
blas::Diag diag,
1337+
int64_t n,
1338+
std::complex<double> const* dA, int64_t ldda,
1339+
std::complex<double>* dx, int64_t incdx,
1340+
blas::Queue& queue )
1341+
{
1342+
blas_dev_call(
1343+
cublasZtrsv(
1344+
queue.handle(),
1345+
uplo2cublas( uplo ), op2cublas( trans ), diag2cublas( diag ),
1346+
n,
1347+
// NOTE(review): this C-style cast also removes the const qualifier that
// the parameter list puts on dA; a cast to a const pointer would keep it.
(cuDoubleComplex*) dA, ldda,
1348+
(cuDoubleComplex*) dx, incdx ) );
1349+
}
1350+
11951351
//==============================================================================
11961352
// Level 3 BLAS - Device Interfaces
11971353

src/device_internal.hh

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -737,6 +737,80 @@ void symv(
737737
std::complex<double>* dy, int64_t incdy,
738738
blas::Queue& queue );
739739

740+
//------------------------------------------------------------------------------
741+
/// trmv declaration without a layout parameter, float overload: takes uplo,
/// trans, diag, the size n, a read-only matrix pointer A with lda, a
/// non-const vector pointer x with incx, and the blas::Queue to use.
/// NOTE(review): the matching-signature definition shown for
/// src/cublas_wrappers.cc names these parameters dA/ldda/dx/incdx -- harmless,
/// but consider using the same names in both places.
void trmv(
742+
blas::Uplo uplo,
743+
blas::Op trans,
744+
blas::Diag diag,
745+
int64_t n,
746+
float const* A, int64_t lda,
747+
float* x, int64_t incx,
748+
blas::Queue& queue );
749+
750+
/// trmv declaration without a layout parameter, double overload: takes uplo,
/// trans, diag, the size n, a read-only matrix pointer A with lda, a
/// non-const vector pointer x with incx, and the blas::Queue to use.
/// NOTE(review): the matching-signature definition shown for
/// src/cublas_wrappers.cc names these parameters dA/ldda/dx/incdx -- harmless,
/// but consider using the same names in both places.
void trmv(
751+
blas::Uplo uplo,
752+
blas::Op trans,
753+
blas::Diag diag,
754+
int64_t n,
755+
double const* A, int64_t lda,
756+
double* x, int64_t incx,
757+
blas::Queue& queue );
758+
759+
/// trmv declaration without a layout parameter, std::complex<float> overload:
/// takes uplo, trans, diag, the size n, a read-only matrix pointer A with lda,
/// a non-const vector pointer x with incx, and the blas::Queue to use.
/// NOTE(review): the matching-signature definition shown for
/// src/cublas_wrappers.cc names these parameters dA/ldda/dx/incdx -- harmless,
/// but consider using the same names in both places.
void trmv(
760+
blas::Uplo uplo,
761+
blas::Op trans,
762+
blas::Diag diag,
763+
int64_t n,
764+
std::complex<float> const* A, int64_t lda,
765+
std::complex<float>* x, int64_t incx,
766+
blas::Queue& queue );
767+
768+
/// trmv declaration without a layout parameter, std::complex<double>
/// overload: takes uplo, trans, diag, the size n, a read-only matrix pointer A
/// with lda, a non-const vector pointer x with incx, and the blas::Queue to
/// use.
/// NOTE(review): the matching-signature definition shown for
/// src/cublas_wrappers.cc names these parameters dA/ldda/dx/incdx -- harmless,
/// but consider using the same names in both places.
void trmv(
769+
blas::Uplo uplo,
770+
blas::Op trans,
771+
blas::Diag diag,
772+
int64_t n,
773+
std::complex<double> const* A, int64_t lda,
774+
std::complex<double>* x, int64_t incx,
775+
blas::Queue& queue );
776+
777+
//------------------------------------------------------------------------------
778+
/// trsv declaration without a layout parameter, float overload: takes uplo,
/// trans, diag, the size n, a read-only matrix pointer A with lda, a
/// non-const vector pointer x with incx, and the blas::Queue to use.
/// NOTE(review): the matching-signature definition shown for
/// src/cublas_wrappers.cc names these parameters dA/ldda/dx/incdx -- harmless,
/// but consider using the same names in both places.
void trsv(
779+
blas::Uplo uplo,
780+
blas::Op trans,
781+
blas::Diag diag,
782+
int64_t n,
783+
float const* A, int64_t lda,
784+
float* x, int64_t incx,
785+
blas::Queue& queue );
786+
787+
/// trsv declaration without a layout parameter, double overload: takes uplo,
/// trans, diag, the size n, a read-only matrix pointer A with lda, a
/// non-const vector pointer x with incx, and the blas::Queue to use.
/// NOTE(review): the matching-signature definition shown for
/// src/cublas_wrappers.cc names these parameters dA/ldda/dx/incdx -- harmless,
/// but consider using the same names in both places.
void trsv(
788+
blas::Uplo uplo,
789+
blas::Op trans,
790+
blas::Diag diag,
791+
int64_t n,
792+
double const* A, int64_t lda,
793+
double* x, int64_t incx,
794+
blas::Queue& queue );
795+
796+
/// trsv declaration without a layout parameter, std::complex<float> overload:
/// takes uplo, trans, diag, the size n, a read-only matrix pointer A with lda,
/// a non-const vector pointer x with incx, and the blas::Queue to use.
/// NOTE(review): the matching-signature definition shown for
/// src/cublas_wrappers.cc names these parameters dA/ldda/dx/incdx -- harmless,
/// but consider using the same names in both places.
void trsv(
797+
blas::Uplo uplo,
798+
blas::Op trans,
799+
blas::Diag diag,
800+
int64_t n,
801+
std::complex<float> const* A, int64_t lda,
802+
std::complex<float>* x, int64_t incx,
803+
blas::Queue& queue );
804+
805+
/// trsv declaration without a layout parameter, std::complex<double>
/// overload: takes uplo, trans, diag, the size n, a read-only matrix pointer A
/// with lda, a non-const vector pointer x with incx, and the blas::Queue to
/// use.
/// NOTE(review): the matching-signature definition shown for
/// src/cublas_wrappers.cc names these parameters dA/ldda/dx/incdx -- harmless,
/// but consider using the same names in both places.
void trsv(
806+
blas::Uplo uplo,
807+
blas::Op trans,
808+
blas::Diag diag,
809+
int64_t n,
810+
std::complex<double> const* A, int64_t lda,
811+
std::complex<double>* x, int64_t incx,
812+
blas::Queue& queue );
813+
740814
//==============================================================================
741815
// Level 3 BLAS - Device Interfaces
742816

0 commit comments

Comments
 (0)