Skip to content

Commit 8e9bc4d

Browse files
author
Timmy
committed
dtrsm right side
1 parent 5ee9e5f commit 8e9bc4d

File tree

1 file changed

+141
-5
lines changed

1 file changed

+141
-5
lines changed

src/library/blas/xtrsm.cc

Lines changed: 141 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1106,9 +1106,6 @@ static clblasStatus gpu_dtrsm128(
11061106
if (order != clblasColumnMajor)
11071107
return clblasNotImplemented;
11081108

1109-
//for now
1110-
if (side == clblasRight)
1111-
return clblasNotImplemented;
11121109

11131110
int inner_block_size = 16; // inner blocking size, <=32
11141111
int outer_block_size = 128;// outer blocking size, >BLOCK_SIZE
@@ -1285,8 +1282,147 @@ static clblasStatus gpu_dtrsm128(
12851282
}
12861283
else
12871284
{
1288-
clReleaseMemObject(X);
1289-
return clblasNotImplemented;
1285+
//
1286+
// Helper for C = alpha * B * A + beta * C
1287+
//
1288+
// In the calls below
1289+
// - the 2nd matrix shall be either A or InvA transposed according to transA
1290+
// - the 1st and 3rd matrices are either B and X
1291+
//
1292+
#define DGEMM_RIGHT(m,n,k, alpha, B, A, beta, C ) \
1293+
do { \
1294+
err = clblasDgemm(clblasColumnMajor, clblasNoTrans, transA , m, n, k, alpha, B, A, beta, C , 1, commandQueues, 0, NULL, events ) ; \
1295+
CL_CHECK(err); \
1296+
} while(0)
1297+
1298+
1299+
// side=R
1300+
/* invert the diagonals
1301+
* Allocate device memory for the inverted diagonal blocks, size=n*BLOCK_SIZE
1302+
*/
1303+
1304+
/* invert the diagonals
1305+
* Allocate device memory for the inverted diagonal blocks, size=m*nb
1306+
*/
1307+
size_t ldInvA = outer_block_size;
1308+
size_t offInvA = 0; //must be 0: needed by the _(X,i,j) macro
1309+
size_t size_InvA = ldInvA * BLOCKS(N, outer_block_size) * outer_block_size *sizeof(double);
1310+
InvA = clCreateBuffer(context, CL_MEM_READ_WRITE, size_InvA, NULL, &err);
1311+
CL_CHECK(err);
1312+
err = clearBuffer(commandQueues[0], InvA, size_InvA);
1313+
CL_CHECK(err);
1314+
1315+
err = diag_dtrtri128(commandQueues[0], N, uplo, diag, A, offA, InvA, ldA, inner_block_size, outer_block_size, events);
1316+
CL_CHECK(err);
1317+
1318+
1319+
if (transA == clblasNoTrans)
1320+
{
1321+
/* the non-transpose case */
1322+
if (uplo == clblasLower)
1323+
{
1324+
/* the lower case */
1325+
/* handle the first block seperately with alpha */
1326+
1327+
int nn = (N % outer_block_size == 0) ? outer_block_size : (N % outer_block_size);
1328+
i = N - nn;
1329+
DGEMM_RIGHT(M, nn, nn, alpha, _(B, 0, i), _(InvA, 0, i), zero, _(X, 0, i));
1330+
1331+
if (i - outer_block_size >= 0)
1332+
{
1333+
1334+
DGEMM_RIGHT(M, i, nn, neg_one, _(X, 0, i), _(A, i, 0), alpha, _(B, 0, 0));
1335+
1336+
/* the rest blocks */
1337+
for (i = N - nn - outer_block_size; i >= 0; i -= outer_block_size) {
1338+
DGEMM_RIGHT(M, outer_block_size, outer_block_size, one, _(B, 0, i), _(InvA, 0, i), zero, _(X, 0, i));
1339+
1340+
if (i - outer_block_size < 0)
1341+
break;
1342+
1343+
DGEMM_RIGHT(M, i, outer_block_size, neg_one, _(X, 0, i), _(A, i, 0), one, _(B, 0, 0));
1344+
}
1345+
}
1346+
}
1347+
else
1348+
{
1349+
/* the upper case */
1350+
/* handle the first block seperately with alpha */
1351+
int nn = min(outer_block_size, (int)N);
1352+
DGEMM_RIGHT(M, nn, nn, alpha, _(B, 0, 0), _(InvA, 0, 0), zero, _(X, 0, 0));
1353+
1354+
if (outer_block_size < N)
1355+
{
1356+
1357+
DGEMM_RIGHT(M, N - outer_block_size, outer_block_size, neg_one, _(X, 0, 0), _(A, 0, outer_block_size), alpha, _(B, 0, outer_block_size));
1358+
1359+
/* the rest blocks */
1360+
for (i = outer_block_size; i < N; i += outer_block_size) {
1361+
nn = min(outer_block_size, (int)N - i);
1362+
DGEMM_RIGHT(M, nn, nn, one, _(B, 0, i), _(InvA, 0, i), zero, _(X, 0, i));
1363+
1364+
if (i + outer_block_size >= N)
1365+
break;
1366+
1367+
DGEMM_RIGHT(M, N - i - outer_block_size, outer_block_size, neg_one, _(X, 0, i), _(A, i, i + outer_block_size), one, _(B, 0, i + outer_block_size));
1368+
}
1369+
}
1370+
}
1371+
}
1372+
else
1373+
{
1374+
1375+
/* the transpose case */
1376+
if (uplo == clblasLower)
1377+
{
1378+
/* the lower case */
1379+
/* handle the first block seperately with alpha */
1380+
1381+
int nn = min(outer_block_size, (int)N);
1382+
DGEMM_RIGHT(M, nn, nn, alpha, _(B, 0, 0), _(InvA, 0, 0), zero, _(X, 0, 0));
1383+
1384+
if (outer_block_size < N)
1385+
{
1386+
1387+
DGEMM_RIGHT(M, N - outer_block_size, outer_block_size, neg_one, _(X, 0, 0), _(A, outer_block_size, 0), alpha, _(B, 0, outer_block_size));
1388+
1389+
/* the rest blocks */
1390+
for (i = outer_block_size; i < N; i += outer_block_size) {
1391+
nn = min(outer_block_size, (int)N - i);
1392+
DGEMM_RIGHT(M, nn, nn, one, _(B, 0, i), _(InvA, 0, i), zero, _(X, 0, i));
1393+
1394+
if (i + outer_block_size >= N)
1395+
break;
1396+
1397+
DGEMM_RIGHT(M, N - i - outer_block_size, outer_block_size, neg_one, _(X, 0, i), _(A, outer_block_size + i, i), one, _(B, 0, i + outer_block_size));
1398+
}
1399+
}
1400+
}
1401+
else
1402+
{
1403+
/* the upper case */
1404+
/* handle the first block seperately with alpha */
1405+
int nn = (N % outer_block_size == 0) ? outer_block_size : (N % outer_block_size);
1406+
i = N - nn;
1407+
DGEMM_RIGHT(M, nn, nn, alpha, _(B, 0, i), _(InvA, 0, i), zero, _(X, 0, i));
1408+
1409+
if (i - outer_block_size >= 0)
1410+
{
1411+
1412+
DGEMM_RIGHT(M, i, nn, neg_one, _(X, 0, i), _(A, 0, i), alpha, _(B, 0, 0));
1413+
1414+
/* the rest blocks */
1415+
for (i = N - nn - outer_block_size; i >= 0; i -= outer_block_size) {
1416+
DGEMM_RIGHT(M, outer_block_size, outer_block_size, one, _(B, 0, i), _(InvA, 0, i), zero, _(X, 0, i));
1417+
1418+
if (i - outer_block_size < 0)
1419+
break;
1420+
1421+
DGEMM_RIGHT(M, i, outer_block_size, neg_one, _(X, 0, i), _(A, 0, i), one, _(B, 0, 0));
1422+
}
1423+
}
1424+
}
1425+
}
12901426
}
12911427

12921428
// Copy X(m,n) to B(m,n)

0 commit comments

Comments
 (0)