@@ -1106,9 +1106,6 @@ static clblasStatus gpu_dtrsm128(
1106
1106
if (order != clblasColumnMajor)
1107
1107
return clblasNotImplemented;
1108
1108
1109
- // for now
1110
- if (side == clblasRight)
1111
- return clblasNotImplemented;
1112
1109
1113
1110
int inner_block_size = 16 ; // inner blocking size, <=32
1114
1111
int outer_block_size = 128 ;// outer blocking size, >BLOCK_SIZE
@@ -1285,8 +1282,147 @@ static clblasStatus gpu_dtrsm128(
1285
1282
}
1286
1283
else
1287
1284
{
1288
- clReleaseMemObject (X);
1289
- return clblasNotImplemented;
1285
+ //
1286
+ // Helper for C = alpha * B * A + beta * C
1287
+ //
1288
+ // In the calls below
1289
+ // - the 2nd matrix shall be either A or InvA transposed according to transA
1290
+ // - the 1st and 3rd matrices are each either B or X
1291
+ //
1292
+ #define DGEMM_RIGHT (m,n,k, alpha, B, A, beta, C ) \
1293
+ do { \
1294
+ err = clblasDgemm (clblasColumnMajor, clblasNoTrans, transA , m, n, k, alpha, B, A, beta, C , 1 , commandQueues, 0 , NULL , events ) ; \
1295
+ CL_CHECK (err); \
1296
+ } while (0 )
1297
+
1298
+
1299
+ // side=R
1300
+ /* invert the diagonals
1301
+ * Allocate device memory for the inverted diagonal blocks, size=n*BLOCK_SIZE
1302
+ */
1303
+
1304
+ /* invert the diagonals
1305
+ * Allocate device memory for the inverted diagonal blocks, size = ldInvA * BLOCKS(N, outer_block_size) * outer_block_size doubles
1306
+ */
1307
+ size_t ldInvA = outer_block_size;
1308
+ size_t offInvA = 0 ; // must be 0: needed by the _(X,i,j) macro
1309
+ size_t size_InvA = ldInvA * BLOCKS (N, outer_block_size) * outer_block_size *sizeof (double );
1310
+ InvA = clCreateBuffer (context, CL_MEM_READ_WRITE, size_InvA, NULL , &err);
1311
+ CL_CHECK (err);
1312
+ err = clearBuffer (commandQueues[0 ], InvA, size_InvA);
1313
+ CL_CHECK (err);
1314
+
1315
+ err = diag_dtrtri128 (commandQueues[0 ], N, uplo, diag, A, offA, InvA, ldA, inner_block_size, outer_block_size, events);
1316
+ CL_CHECK (err);
1317
+
1318
+
1319
+ if (transA == clblasNoTrans)
1320
+ {
1321
+ /* the non-transpose case */
1322
+ if (uplo == clblasLower)
1323
+ {
1324
+ /* the lower case */
1325
+ /* handle the first block separately with alpha */
1326
+
1327
+ int nn = (N % outer_block_size == 0 ) ? outer_block_size : (N % outer_block_size);
1328
+ i = N - nn;
1329
+ DGEMM_RIGHT (M, nn, nn, alpha, _ (B, 0 , i), _ (InvA, 0 , i), zero, _ (X, 0 , i));
1330
+
1331
+ if (i - outer_block_size >= 0 )
1332
+ {
1333
+
1334
+ DGEMM_RIGHT (M, i, nn, neg_one, _ (X, 0 , i), _ (A, i, 0 ), alpha, _ (B, 0 , 0 ));
1335
+
1336
+ /* the remaining blocks */
1337
+ for (i = N - nn - outer_block_size; i >= 0 ; i -= outer_block_size) {
1338
+ DGEMM_RIGHT (M, outer_block_size, outer_block_size, one, _ (B, 0 , i), _ (InvA, 0 , i), zero, _ (X, 0 , i));
1339
+
1340
+ if (i - outer_block_size < 0 )
1341
+ break ;
1342
+
1343
+ DGEMM_RIGHT (M, i, outer_block_size, neg_one, _ (X, 0 , i), _ (A, i, 0 ), one, _ (B, 0 , 0 ));
1344
+ }
1345
+ }
1346
+ }
1347
+ else
1348
+ {
1349
+ /* the upper case */
1350
+ /* handle the first block separately with alpha */
1351
+ int nn = min (outer_block_size, (int )N);
1352
+ DGEMM_RIGHT (M, nn, nn, alpha, _ (B, 0 , 0 ), _ (InvA, 0 , 0 ), zero, _ (X, 0 , 0 ));
1353
+
1354
+ if (outer_block_size < N)
1355
+ {
1356
+
1357
+ DGEMM_RIGHT (M, N - outer_block_size, outer_block_size, neg_one, _ (X, 0 , 0 ), _ (A, 0 , outer_block_size), alpha, _ (B, 0 , outer_block_size));
1358
+
1359
+ /* the remaining blocks */
1360
+ for (i = outer_block_size; i < N; i += outer_block_size) {
1361
+ nn = min (outer_block_size, (int )N - i);
1362
+ DGEMM_RIGHT (M, nn, nn, one, _ (B, 0 , i), _ (InvA, 0 , i), zero, _ (X, 0 , i));
1363
+
1364
+ if (i + outer_block_size >= N)
1365
+ break ;
1366
+
1367
+ DGEMM_RIGHT (M, N - i - outer_block_size, outer_block_size, neg_one, _ (X, 0 , i), _ (A, i, i + outer_block_size), one, _ (B, 0 , i + outer_block_size));
1368
+ }
1369
+ }
1370
+ }
1371
+ }
1372
+ else
1373
+ {
1374
+
1375
+ /* the transpose case */
1376
+ if (uplo == clblasLower)
1377
+ {
1378
+ /* the lower case */
1379
+ /* handle the first block separately with alpha */
1380
+
1381
+ int nn = min (outer_block_size, (int )N);
1382
+ DGEMM_RIGHT (M, nn, nn, alpha, _ (B, 0 , 0 ), _ (InvA, 0 , 0 ), zero, _ (X, 0 , 0 ));
1383
+
1384
+ if (outer_block_size < N)
1385
+ {
1386
+
1387
+ DGEMM_RIGHT (M, N - outer_block_size, outer_block_size, neg_one, _ (X, 0 , 0 ), _ (A, outer_block_size, 0 ), alpha, _ (B, 0 , outer_block_size));
1388
+
1389
+ /* the remaining blocks */
1390
+ for (i = outer_block_size; i < N; i += outer_block_size) {
1391
+ nn = min (outer_block_size, (int )N - i);
1392
+ DGEMM_RIGHT (M, nn, nn, one, _ (B, 0 , i), _ (InvA, 0 , i), zero, _ (X, 0 , i));
1393
+
1394
+ if (i + outer_block_size >= N)
1395
+ break ;
1396
+
1397
+ DGEMM_RIGHT (M, N - i - outer_block_size, outer_block_size, neg_one, _ (X, 0 , i), _ (A, outer_block_size + i, i), one, _ (B, 0 , i + outer_block_size));
1398
+ }
1399
+ }
1400
+ }
1401
+ else
1402
+ {
1403
+ /* the upper case */
1404
+ /* handle the first block separately with alpha */
1405
+ int nn = (N % outer_block_size == 0 ) ? outer_block_size : (N % outer_block_size);
1406
+ i = N - nn;
1407
+ DGEMM_RIGHT (M, nn, nn, alpha, _ (B, 0 , i), _ (InvA, 0 , i), zero, _ (X, 0 , i));
1408
+
1409
+ if (i - outer_block_size >= 0 )
1410
+ {
1411
+
1412
+ DGEMM_RIGHT (M, i, nn, neg_one, _ (X, 0 , i), _ (A, 0 , i), alpha, _ (B, 0 , 0 ));
1413
+
1414
+ /* the remaining blocks */
1415
+ for (i = N - nn - outer_block_size; i >= 0 ; i -= outer_block_size) {
1416
+ DGEMM_RIGHT (M, outer_block_size, outer_block_size, one, _ (B, 0 , i), _ (InvA, 0 , i), zero, _ (X, 0 , i));
1417
+
1418
+ if (i - outer_block_size < 0 )
1419
+ break ;
1420
+
1421
+ DGEMM_RIGHT (M, i, outer_block_size, neg_one, _ (X, 0 , i), _ (A, 0 , i), one, _ (B, 0 , 0 ));
1422
+ }
1423
+ }
1424
+ }
1425
+ }
1290
1426
}
1291
1427
1292
1428
// Copy X(m,n) to B(m,n)
0 commit comments