Skip to content

Commit d6e6a78

Browse files
author
Timmy
committed
dtrsm reenablment 192
1 parent afe8fc0 commit d6e6a78

13 files changed

+1963
-20
lines changed

src/library/blas/trtri/TrtriClKernels.h

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,25 @@
33
#define TRTRI_CL_KERNELS_H
44
#include "CL/cl.h"
55

6+
/*mod 192 dtrsm*/
67
static cl_kernel diag_dtrtri_upper_192_12_clKernel = NULL;
7-
static cl_kernel triple_dgemm_update_192_12_R = NULL;
8-
static cl_kernel triple_dgemm_update_192_24_PART1_R = NULL;
9-
static cl_kernel triple_dgemm_update_192_24_PART2_R = NULL;
10-
static cl_kernel triple_dgemm_update_192_48_PART1_R = NULL;
11-
static cl_kernel triple_dgemm_update_192_48_PART2_R = NULL;
12-
static cl_kernel triple_dgemm_update_192_96_PART1_R = NULL;
13-
static cl_kernel triple_dgemm_update_192_96_PART2_R = NULL;
8+
static cl_kernel triple_dgemm_update_192_12_R_clKernel = NULL;
9+
static cl_kernel triple_dgemm_update_192_24_PART1_R_clKernel = NULL;
10+
static cl_kernel triple_dgemm_update_192_24_PART2_R_clKernel = NULL;
11+
static cl_kernel triple_dgemm_update_192_48_PART1_R_clKernel = NULL;
12+
static cl_kernel triple_dgemm_update_192_48_PART2_R_clKernel = NULL;
13+
static cl_kernel triple_dgemm_update_192_96_PART1_R_clKernel = NULL;
14+
static cl_kernel triple_dgemm_update_192_96_PART2_R_clKernel = NULL;
1415

16+
/*mod 128 dtrsm*/
17+
static cl_kernel diag_dtrtri_upper_128_16_clKernel = NULL;
18+
static cl_kernel triple_dgemm_update_128_16_R_clKernel = NULL;
19+
static cl_kernel triple_dgemm_update_128_32_PART1_R_clKernel = NULL;
20+
static cl_kernel triple_dgemm_update_128_32_PART2_R_clKernel = NULL;
21+
static cl_kernel triple_dgemm_update_128_64_PART1_R_clKernel = NULL;
22+
static cl_kernel triple_dgemm_update_128_64_PART2_R_clKernel = NULL;
23+
static cl_kernel triple_dgemm_update_128_ABOVE64_PART1_R_clKernel = NULL;
24+
static cl_kernel triple_dgemm_update_128_ABOVE64_PART2_R_clKernel = NULL;
25+
static cl_kernel triple_dgemm_update_128_ABOVE64_PART3_R_clKernel = NULL;
1526

1627
#endif

src/library/blas/trtri/TrtriKernelSourceIncludes.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#ifndef TRTRI_SOURCE_INCLUDES_CPP
77
#define TRTRI_SOURCE_INCLUDES_CPP
88

9+
/*mod 192 dtrsm*/
910
#include "diag_dtrtri_upper_192_12.cpp"
1011
#include "triple_dgemm_update_192_12_R.cpp"
1112
#include "triple_dgemm_update_192_24_PART1_R.cpp"
@@ -15,5 +16,15 @@
1516
#include "triple_dgemm_update_192_96_PART1_R.cpp"
1617
#include "triple_dgemm_update_192_96_PART2_R.cpp"
1718

19+
/*mod 128 dtrsm*/
20+
#include "diag_dtrtri_upper_128_16.cpp"
21+
#include "triple_dgemm_update_128_16_R.cpp"
22+
#include "triple_dgemm_update_128_32_PART1_R.cpp"
23+
#include "triple_dgemm_update_128_32_PART2_R.cpp"
24+
#include "triple_dgemm_update_128_64_PART1_R.cpp"
25+
#include "triple_dgemm_update_128_64_PART2_R.cpp"
26+
#include "triple_dgemm_update_128_ABOVE64_PART1_R.cpp"
27+
#include "triple_dgemm_update_128_ABOVE64_PART2_R.cpp"
28+
#include "triple_dgemm_update_128_ABOVE64_PART3_R.cpp"
1829

1930
#endif

src/library/blas/trtri/TrtriKernelSourceIncludes.h

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
const char * const TrtriBuildOptions = "-cl-std=CL2.0";
1010
const char * const TrtribinBuildOptions = "-cl-std=CL2.0";
1111

12-
12+
/*mod 192 dtrsm*/
1313
extern const char * const diag_dtrtri_upper_192_12_src;
1414
extern unsigned char *diag_dtrtri_upper_192_12_bin;
1515
extern size_t diag_dtrtri_upper_192_12_binSize;
@@ -42,4 +42,41 @@ extern const char * const triple_dgemm_update_192_96_PART2_R_src;
4242
extern unsigned char *triple_dgemm_update_192_96_PART2_R_bin;
4343
extern size_t triple_dgemm_update_192_96_PART2_R_binSize;
4444

45+
/*mod 128 dtrsm*/
46+
extern const char * const diag_dtrtri_upper_128_16_src;
47+
extern unsigned char *diag_dtrtri_upper_128_16_bin;
48+
extern size_t diag_dtrtri_upper_128_16_binSize;
49+
50+
extern const char * const triple_dgemm_update_128_16_R_src;
51+
extern unsigned char *triple_dgemm_update_128_16_R_bin;
52+
extern size_t triple_dgemm_update_128_16_R_binSize;
53+
54+
extern const char * const triple_dgemm_update_128_32_PART1_R_src;
55+
extern unsigned char *triple_dgemm_update_128_32_PART1_R_bin;
56+
extern size_t triple_dgemm_update_128_32_PART1_R_binSize;
57+
58+
extern const char * const triple_dgemm_update_128_32_PART2_R_src;
59+
extern unsigned char *triple_dgemm_update_128_32_PART2_R_bin;
60+
extern size_t triple_dgemm_update_128_32_PART2_R_binSize;
61+
62+
extern const char * const triple_dgemm_update_128_64_PART1_R_src;
63+
extern unsigned char *triple_dgemm_update_128_64_PART1_R_bin;
64+
extern size_t triple_dgemm_update_128_64_PART1_R_binSize;
65+
66+
extern const char * const triple_dgemm_update_128_64_PART2_R_src;
67+
extern unsigned char *triple_dgemm_update_128_64_PART2_R_bin;
68+
extern size_t triple_dgemm_update_128_64_PART2_R_binSize;
69+
70+
extern const char * const triple_dgemm_update_128_ABOVE64_PART1_R_src;
71+
extern unsigned char *triple_dgemm_update_128_ABOVE64_PART1_R_bin;
72+
extern size_t triple_dgemm_update_128_ABOVE64_PART1_R_binSize;
73+
74+
extern const char * const triple_dgemm_update_128_ABOVE64_PART2_R_src;
75+
extern unsigned char *triple_dgemm_update_128_ABOVE64_PART2_R_bin;
76+
extern size_t triple_dgemm_update_128_ABOVE64_PART2_R_binSize;
77+
78+
extern const char * const triple_dgemm_update_128_ABOVE64_PART3_R_src;
79+
extern unsigned char *triple_dgemm_update_128_ABOVE64_PART3_R_bin;
80+
extern size_t triple_dgemm_update_128_ABOVE64_PART3_R_binSize;
81+
4582
#endif
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
/*******************************************************************************
2+
* Hand-tuned kernel
3+
******************************************************************************/
4+
5+
#ifndef KERNEL_DIAG_DTRTRI_UPPER_128_16_SRC_CPP
6+
#define KERNEL_DIAG_DTRTRI_UPPER_128_16_SRC_CPP
7+
#pragma message("#define KERNEL_DIAG_DTRTRI_UPPER_128_16_SRC_CPP.")
8+
9+
#ifndef STRINGIFY
10+
#define STRINGIFY2(...) #__VA_ARGS__
11+
#define STRINGIFY(...) STRINGIFY2(__VA_ARGS__)
12+
#endif
13+
14+
unsigned char *diag_dtrtri_upper_128_16_bin = 0;
15+
size_t diag_dtrtri_upper_128_16_binSize = 0;
16+
17+
const char * const diag_dtrtri_upper_128_16_src = STRINGIFY(
18+
#define BLOCK_SIZE 16 \n
19+
#define NB 128 \n
20+
#define ZERO ( 0.0) \n
21+
#define ONE ( 1.0) \n
22+
#ifdef DOUBLE_PRECISION \n
23+
#ifdef cl_khr_fp64 \n
24+
#pragma OPENCL EXTENSION cl_khr_fp64 : enable \n
25+
#else \n
26+
#pragma OPENCL EXTENSION cl_amd_fp64 : enable \n
27+
#endif \n
28+
#endif \n
29+
__kernel void diag_dtrtri_upper_128_16_src(\n
30+
int isDiagUnit,\n
31+
__global double const * restrict A, \n
32+
uint offA, \n
33+
__global double *d_dinvA, \n
34+
uint lda, \n
35+
uint na)\n
36+
{ \n
37+
int i, j;\n
38+
double Ystx = 0;\n
39+
__local double *y = 0;\n
40+
double switcher;\n
41+
double neg_switcher;\n
42+
43+
// Thread index
44+
int tx = get_local_id(0);\n
45+
46+
// Thread index
47+
int gx = get_global_id(0);\n
48+
49+
// Block index
50+
int bx = get_group_id(0);\n
51+
52+
A = A + offA;\n
53+
54+
__global const double *Aoff = A + bx*lda*BLOCK_SIZE + bx*BLOCK_SIZE;\n
55+
int NumBLperNB = NB/BLOCK_SIZE;\n
56+
d_dinvA += bx/NumBLperNB*NB*NB + (bx % NumBLperNB)*(NB*BLOCK_SIZE + BLOCK_SIZE);\n
57+
58+
__local double Bs[BLOCK_SIZE*BLOCK_SIZE];\n
59+
__local double workspace[BLOCK_SIZE]; \n // workspace used to store the current working column
60+
61+
// load A
62+
#pragma unroll \n
63+
for( i=0; i < BLOCK_SIZE; i++ )\n
64+
{\n
65+
if(tx <= i && i+bx*BLOCK_SIZE < na )\n
66+
{\n
67+
Bs[i*BLOCK_SIZE+tx] = *(Aoff+i*lda+tx);\n
68+
}\n
69+
else\n
70+
{\n
71+
Bs[i*BLOCK_SIZE+tx] = ZERO;\n
72+
}\n
73+
}\n
74+
// read in the whole square block of my A and zero out the non data triangular
75+
76+
// Synchronize to make sure the matrices are loaded
77+
//__syncthreads();
78+
barrier(CLK_LOCAL_MEM_FENCE);\n
79+
80+
// solve the diagonals
81+
82+
if(isDiagUnit == 1)\n
83+
{\n
84+
Bs[tx*BLOCK_SIZE+tx] = ONE;\n
85+
}\n
86+
else\n
87+
{\n
88+
if( Bs[tx*BLOCK_SIZE+tx] == ZERO )\n
89+
{\n
90+
Bs[tx*BLOCK_SIZE+tx] = ONE; \n
91+
}\n
92+
else \n
93+
{\n
94+
Bs[tx*BLOCK_SIZE+tx] = ONE / ( Bs[tx*BLOCK_SIZE+tx]) ;\n
95+
}\n
96+
}\n
97+
98+
/* the upper case */
99+
for( i=0; i < BLOCK_SIZE; i++ ) {\n
100+
Ystx = ZERO;\n
101+
if( tx < i)\n
102+
{\n
103+
switcher = ONE;\n
104+
}\n
105+
else\n
106+
{\n
107+
switcher = ZERO;\n
108+
}\n
109+
110+
//dtrmv
111+
workspace[tx] = *(Bs+i*BLOCK_SIZE+tx);\n
112+
y = Bs+i*BLOCK_SIZE;\n
113+
114+
#pragma unroll\n
115+
//for( j=tx; j < i; j++ )
116+
for( j=0; j < i; j++ )\n
117+
{\n
118+
Ystx += switcher * (*(Bs+j*BLOCK_SIZE+tx)*workspace[j]);\n
119+
}\n
120+
121+
//sscal
122+
// if (tx != i) y[tx]=switcher*Ystx*(-Bs[i*BLOCK_SIZE+i]);
123+
124+
if( tx != i)\n
125+
{\n
126+
switcher = ONE;\n
127+
neg_switcher = ZERO;\n
128+
}\n
129+
else\n
130+
{\n
131+
switcher = ZERO;\n
132+
neg_switcher = ONE;\n
133+
}\n
134+
135+
y[tx] = switcher *Ystx*(-Bs[i*BLOCK_SIZE+i])+neg_switcher*y[tx];\n
136+
137+
// __syncthreads();
138+
barrier(CLK_LOCAL_MEM_FENCE);\n
139+
}\n
140+
141+
// write back A
142+
#pragma unroll\n
143+
for( i=0; i < BLOCK_SIZE; i++ )\n
144+
{\n
145+
*(d_dinvA+i*NB+tx) = Bs[i*BLOCK_SIZE+tx];\n
146+
}\n
147+
148+
}\n
149+
);
150+
#endif

0 commit comments

Comments
 (0)