Skip to content

Commit 5ee9e5f

Browse files
author
Timmy
committed
dtrsm lower left
1 parent 4067d14 commit 5ee9e5f

14 files changed

+1577
-50
lines changed

src/library/blas/trtri/TrtriClKernels.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ static cl_kernel triple_dgemm_update_192_96_PART1_R_clKernel = NULL;
1414
static cl_kernel triple_dgemm_update_192_96_PART2_R_clKernel = NULL;
1515

1616
/*mod 128 dtrsm*/
17+
/*upper*/
1718
static cl_kernel diag_dtrtri_upper_128_16_clKernel = NULL;
1819
static cl_kernel triple_dgemm_update_128_16_R_clKernel = NULL;
1920
static cl_kernel triple_dgemm_update_128_32_PART1_R_clKernel = NULL;
@@ -24,4 +25,16 @@ static cl_kernel triple_dgemm_update_128_ABOVE64_PART1_R_clKernel = NULL;
2425
static cl_kernel triple_dgemm_update_128_ABOVE64_PART2_R_clKernel = NULL;
2526
static cl_kernel triple_dgemm_update_128_ABOVE64_PART3_R_clKernel = NULL;
2627

28+
/*lower*/
29+
static cl_kernel diag_dtrtri_lower_128_16_clKernel = NULL;
30+
static cl_kernel triple_dgemm_update_128_16_PART1_L_clKernel = NULL;
31+
static cl_kernel triple_dgemm_update_128_16_PART2_L_clKernel = NULL;
32+
static cl_kernel triple_dgemm_update_128_32_PART1_L_clKernel = NULL;
33+
static cl_kernel triple_dgemm_update_128_32_PART2_L_clKernel = NULL;
34+
static cl_kernel triple_dgemm_update_128_64_PART1_L_clKernel = NULL;
35+
static cl_kernel triple_dgemm_update_128_64_PART2_L_clKernel = NULL;
36+
static cl_kernel triple_dgemm_update_128_ABOVE64_PART1_L_clKernel = NULL;
37+
static cl_kernel triple_dgemm_update_128_ABOVE64_PART2_L_clKernel = NULL;
38+
static cl_kernel triple_dgemm_update_128_ABOVE64_PART3_L_clKernel = NULL;
39+
2740
#endif

src/library/blas/trtri/TrtriKernelSourceIncludes.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "triple_dgemm_update_192_96_PART2_R.cpp"
1818

1919
/*mod 128 dtrsm*/
20+
/*upper*/
2021
#include "diag_dtrtri_upper_128_16.cpp"
2122
#include "triple_dgemm_update_128_16_R.cpp"
2223
#include "triple_dgemm_update_128_32_PART1_R.cpp"
@@ -27,4 +28,16 @@
2728
#include "triple_dgemm_update_128_ABOVE64_PART2_R.cpp"
2829
#include "triple_dgemm_update_128_ABOVE64_PART3_R.cpp"
2930

31+
/*lower*/
32+
#include "diag_dtrtri_lower_128_16.cpp"
33+
#include "triple_dgemm_update_128_16_PART1_L.cpp"
34+
#include "triple_dgemm_update_128_16_PART2_L.cpp"
35+
#include "triple_dgemm_update_128_32_PART1_L.cpp"
36+
#include "triple_dgemm_update_128_32_PART2_L.cpp"
37+
#include "triple_dgemm_update_128_64_PART1_L.cpp"
38+
#include "triple_dgemm_update_128_64_PART2_L.cpp"
39+
#include "triple_dgemm_update_128_ABOVE64_PART1_L.cpp"
40+
#include "triple_dgemm_update_128_ABOVE64_PART2_L.cpp"
41+
#include "triple_dgemm_update_128_ABOVE64_PART3_L.cpp"
42+
3043
#endif

src/library/blas/trtri/TrtriKernelSourceIncludes.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ extern unsigned char *triple_dgemm_update_192_96_PART2_R_bin;
4343
extern size_t triple_dgemm_update_192_96_PART2_R_binSize;
4444

4545
/*mod 128 dtrsm*/
46+
/*upper*/
4647
extern const char * const diag_dtrtri_upper_128_16_src;
4748
extern unsigned char *diag_dtrtri_upper_128_16_bin;
4849
extern size_t diag_dtrtri_upper_128_16_binSize;
@@ -79,4 +80,45 @@ extern const char * const triple_dgemm_update_128_ABOVE64_PART3_R_src;
7980
extern unsigned char *triple_dgemm_update_128_ABOVE64_PART3_R_bin;
8081
extern size_t triple_dgemm_update_128_ABOVE64_PART3_R_binSize;
8182

83+
/*lower*/
84+
extern const char * const diag_dtrtri_lower_128_16_src;
85+
extern unsigned char *diag_dtrtri_lower_128_16_bin;
86+
extern size_t diag_dtrtri_lower_128_16_binSize;
87+
88+
extern const char * const triple_dgemm_update_128_16_PART1_L_src;
89+
extern unsigned char *triple_dgemm_update_128_16_PART1_L_bin;
90+
extern size_t triple_dgemm_update_128_16_PART1_L_binSize;
91+
92+
extern const char * const triple_dgemm_update_128_16_PART2_L_src;
93+
extern unsigned char *triple_dgemm_update_128_16_PART2_L_bin;
94+
extern size_t triple_dgemm_update_128_16_PART2_L_binSize;
95+
96+
extern const char * const triple_dgemm_update_128_32_PART1_L_src;
97+
extern unsigned char *triple_dgemm_update_128_32_PART1_L_bin;
98+
extern size_t triple_dgemm_update_128_32_PART1_L_binSize;
99+
100+
extern const char * const triple_dgemm_update_128_32_PART2_L_src;
101+
extern unsigned char *triple_dgemm_update_128_32_PART2_L_bin;
102+
extern size_t triple_dgemm_update_128_32_PART2_L_binSize;
103+
104+
extern const char * const triple_dgemm_update_128_64_PART1_L_src;
105+
extern unsigned char *triple_dgemm_update_128_64_PART1_L_bin;
106+
extern size_t triple_dgemm_update_128_64_PART1_L_binSize;
107+
108+
extern const char * const triple_dgemm_update_128_64_PART2_L_src;
109+
extern unsigned char *triple_dgemm_update_128_64_PART2_L_bin;
110+
extern size_t triple_dgemm_update_128_64_PART2_L_binSize;
111+
112+
extern const char * const triple_dgemm_update_128_ABOVE64_PART1_L_src;
113+
extern unsigned char *triple_dgemm_update_128_ABOVE64_PART1_L_bin;
114+
extern size_t triple_dgemm_update_128_ABOVE64_PART1_L_binSize;
115+
116+
extern const char * const triple_dgemm_update_128_ABOVE64_PART2_L_src;
117+
extern unsigned char *triple_dgemm_update_128_ABOVE64_PART2_L_bin;
118+
extern size_t triple_dgemm_update_128_ABOVE64_PART2_L_binSize;
119+
120+
extern const char * const triple_dgemm_update_128_ABOVE64_PART3_L_src;
121+
extern unsigned char *triple_dgemm_update_128_ABOVE64_PART3_L_bin;
122+
extern size_t triple_dgemm_update_128_ABOVE64_PART3_L_binSize;
123+
82124
#endif
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
/*******************************************************************************
2+
* Hand-tuned kernel
3+
******************************************************************************/
4+
5+
#ifndef KERNEL_DIAG_DTRTRI_LOWER_128_16_SRC_CPP
6+
#define KERNEL_DIAG_DTRTRI_LOWER_128_16_SRC_CPP
7+
#pragma message("#define KERNEL_DIAG_DTRTRI_UPPER_128_16_SRC_CPP.")
8+
9+
#ifndef STRINGIFY
10+
#define STRINGIFY2(...) #__VA_ARGS__
11+
#define STRINGIFY(...) STRINGIFY2(__VA_ARGS__)
12+
#endif
13+
14+
unsigned char *diag_dtrtri_lower_128_16_bin = 0;
15+
size_t diag_dtrtri_lower_128_16_binSize = 0;
16+
17+
const char * const diag_dtrtri_lower_128_16_src = STRINGIFY(
18+
#define BLOCK_SIZE 16 \n
19+
#define NB 128 \n
20+
#define ZERO ( 0.0) \n
21+
#define ONE ( 1.0) \n
22+
#ifdef DOUBLE_PRECISION \n
23+
#ifdef cl_khr_fp64 \n
24+
#pragma OPENCL EXTENSION cl_khr_fp64 : enable \n
25+
#else \n
26+
#pragma OPENCL EXTENSION cl_amd_fp64 : enable \n
27+
#endif \n
28+
#endif \n
29+
__kernel void diag_dtrtri_lower_128_16_src(\n
30+
int isDiagUnit,\n
31+
__global double const * restrict A, \n
32+
uint offA, \n
33+
__global double *d_dinvA, \n
34+
uint lda, \n
35+
uint na)\n
36+
{ \n
37+
int i, j;\n
38+
double Ystx = 0; \n
39+
__local double *Bw = 0, *x = 0, *y = 0; \n
40+
double switcher; \n
41+
double neg_switcher; \n
42+
43+
44+
// Thread index
45+
int tx = get_local_id(0); \n
46+
int txw; \n
47+
48+
int gx = get_global_id(0); \n
49+
50+
// Block index
51+
int bx = get_group_id(0); \n
52+
53+
A = A + offA; \n
54+
55+
__global const double *Aoff = A + bx*lda*BLOCK_SIZE + bx*BLOCK_SIZE; \n
56+
int NumBLperNB = NB / BLOCK_SIZE; \n
57+
d_dinvA += bx / NumBLperNB*NB*NB + (bx % NumBLperNB)*(NB*BLOCK_SIZE + BLOCK_SIZE); \n
58+
59+
__local double Bs[BLOCK_SIZE*BLOCK_SIZE]; \n
60+
__local double workspace[BLOCK_SIZE]; \n // workspace used to store the current working column
61+
62+
// load A
63+
#pragma unroll\n
64+
for (i = 0; i < BLOCK_SIZE; i++)\n
65+
{ \n
66+
if (tx >= i && gx < na)\n
67+
{ \n
68+
Bs[i*BLOCK_SIZE + tx] = *(Aoff + i*lda + tx); \n
69+
}\n
70+
else\n
71+
{ \n
72+
Bs[i*BLOCK_SIZE + tx] = ZERO; \n
73+
}\n
74+
}\n
75+
76+
// read in the whole square block of my A and zero out the non data triangular
77+
// not the upper or lower diagonal
78+
79+
// Synchronize to make sure the matrices are loaded
80+
//__syncthreads();
81+
barrier(CLK_LOCAL_MEM_FENCE); \n
82+
83+
84+
// solve the diagonals
85+
86+
if (isDiagUnit == 1)\n
87+
{ \n
88+
Bs[tx*BLOCK_SIZE + tx] = ONE; \n
89+
}\n
90+
else\n
91+
{ \n
92+
if (Bs[tx*BLOCK_SIZE + tx] == ZERO)\n
93+
{ \n
94+
Bs[tx*BLOCK_SIZE + tx] = ONE; \n
95+
}\n
96+
else\n
97+
{ \n
98+
Bs[tx*BLOCK_SIZE + tx] = ONE / (Bs[tx*BLOCK_SIZE + tx]); \n
99+
}\n
100+
}\n
101+
102+
/*
103+
* the lower case
104+
*/
105+
106+
107+
if (!(tx < BLOCK_SIZE - 1))\n
108+
{ \n
109+
switcher = ONE; \n
110+
}\n
111+
else\n
112+
{ \n
113+
switcher = ZERO; \n
114+
}\n
115+
116+
Bs[(BLOCK_SIZE - 1)*BLOCK_SIZE + tx] = switcher * Bs[(BLOCK_SIZE - 1)*BLOCK_SIZE + tx]; \n
117+
// zero out the last column, except the diagonal element
118+
119+
for (i = BLOCK_SIZE - 2; i >= 0; i--) {\n
120+
Ystx = ZERO; \n
121+
122+
if (tx > i)\n
123+
{ \n
124+
switcher = ONE; \n
125+
}\n
126+
else\n
127+
{ \n
128+
switcher = ZERO; \n
129+
}\n
130+
131+
//dtrmv
132+
Bw = Bs + (i + 1)*BLOCK_SIZE + i + 1; \n
133+
workspace[tx] = *(Bs + i*BLOCK_SIZE + tx); \n
134+
x = workspace + i + 1; \n
135+
y = Bs + i*BLOCK_SIZE; \n
136+
137+
txw = (tx - i - 1); \n
138+
139+
#pragma unroll\n
140+
for (j = 0; j < BLOCK_SIZE - i - 1; j++)\n
141+
Ystx += switcher*(*(Bw + j*BLOCK_SIZE + txw)*x[j]); \n
142+
143+
//sscal
144+
145+
if (tx != i)\n
146+
{ \n
147+
switcher = ONE; \n
148+
neg_switcher = ZERO; \n
149+
}\n
150+
else\n
151+
{ \n
152+
switcher = ZERO; \n
153+
neg_switcher = ONE; \n
154+
}\n
155+
156+
y[tx] = switcher * Ystx*(-Bs[i*BLOCK_SIZE + i]) + neg_switcher *y[tx]; \n
157+
158+
//__syncthreads();
159+
barrier(CLK_LOCAL_MEM_FENCE); \n
160+
161+
}\n
162+
163+
// write back A
164+
#pragma unroll\n
165+
for (i = 0; i < BLOCK_SIZE; i++)\n
166+
*(d_dinvA + i*NB + tx) = Bs[i*BLOCK_SIZE + tx]; \n
167+
}\n
168+
);
169+
#endif

0 commit comments

Comments
 (0)