Skip to content

Commit 5c95eaa

Browse files
committed
Add templates for double and float types
1 parent 79c2afa commit 5c95eaa

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+2410
-724
lines changed

cuSten/src/kernels/2d_x_np_fun_kernel.cu

Lines changed: 77 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,18 @@
3737
// Function pointer definition
3838
// ---------------------------------------------------------------------
3939

40-
/*! typedef double (*devArg1X)(double*, double*, int);
40+
/*! typedef elemType (*devArg1X)(elemType*, elemType*, int);
4141
\brief The function pointer containing the user defined function to be applied <br>
4242
Input 1: The pointer to input data to the function <br>
4343
Input 2: The pointer to the coefficients provided by the user <br>
4444
Input 3: The current index position (centre of the stencil to be applied)
4545
*/
4646

47-
typedef double (*devArg1X)(double*, double*, int);
47+
template <typename elemType>
48+
struct templateFunc
49+
{
50+
typedef elemType (*devArg1X)(elemType*, elemType*, int);
51+
};
4852

4953
// ---------------------------------------------------------------------
5054
// Kernel Definition
@@ -65,12 +69,13 @@ typedef double (*devArg1X)(double*, double*, int);
6569
\param nx Total number of points in the x direction
6670
*/
6771

72+
template <typename elemType>
6873
__global__ void kernel2DXnpFun
6974
(
70-
double* dataOutput,
71-
double* dataInput,
72-
double* coe,
73-
double* func,
75+
elemType* dataOutput,
76+
elemType* dataInput,
77+
elemType* coe,
78+
elemType* func,
7479
const int numStenLeft,
7580
const int numStenRight,
7681
const int numCoe,
@@ -83,8 +88,8 @@ __global__ void kernel2DXnpFun
8388
// Allocate the shared memory
8489
extern __shared__ int memory[];
8590

86-
double* arrayLocal = (double*)&memory;
87-
double* coeLocal = (double*)&arrayLocal[nxLocal * nyLocal];
91+
elemType* arrayLocal = (elemType*)&memory;
92+
elemType* coeLocal = (elemType*)&arrayLocal[nxLocal * nyLocal];
8893

8994
// Move the weigths into shared memory
9095
#pragma unroll
@@ -94,7 +99,7 @@ __global__ void kernel2DXnpFun
9499
}
95100

96101
// True matrix index
97-
int globalIdx = blockDim.x * blockIdx.x + threadIdx.x;
102+
int globalIdx = blockDim.x * blockIdx.x + threadIdx.x;
98103
int globalIdy = blockDim.y * blockIdx.y + threadIdx.y;
99104

100105
// Local matrix index
@@ -125,7 +130,7 @@ __global__ void kernel2DXnpFun
125130

126131
stenSet = localIdy * nxLocal + localIdx;
127132

128-
dataOutput[globalIdy * nx + globalIdx] = ((devArg1X)func)(arrayLocal, coeLocal, stenSet);
133+
dataOutput[globalIdy * nx + globalIdx] = ((typename templateFunc<elemType>::devArg1X)func)(arrayLocal, coeLocal, stenSet);
129134
}
130135

131136
// Set all left boundary blocks
@@ -145,7 +150,7 @@ __global__ void kernel2DXnpFun
145150
{
146151
stenSet = localIdy * nxLocal + threadIdx.x;
147152

148-
dataOutput[globalIdy * nx + globalIdx] = ((devArg1X)func)(arrayLocal, coeLocal, stenSet);
153+
dataOutput[globalIdy * nx + globalIdx] = ((typename templateFunc<elemType>::devArg1X)func)(arrayLocal, coeLocal, stenSet);
149154
}
150155
}
151156

@@ -166,7 +171,7 @@ __global__ void kernel2DXnpFun
166171
{
167172
stenSet = localIdy * nxLocal + localIdx;
168173

169-
dataOutput[globalIdy * nx + globalIdx] = ((devArg1X)func)(arrayLocal, coeLocal, stenSet);
174+
dataOutput[globalIdy * nx + globalIdx] = ((typename templateFunc<elemType>::devArg1X)func)(arrayLocal, coeLocal, stenSet);
170175
}
171176
}
172177
}
@@ -181,9 +186,10 @@ __global__ void kernel2DXnpFun
181186
\param offload Set to HOST to move data back to CPU or DEVICE to keep on the GPU
182187
*/
183188

189+
template <typename elemType>
184190
void cuStenCompute2DXnpFun
185191
(
186-
cuSten_t* pt_cuSten,
192+
cuSten_t<elemType>* pt_cuSten,
187193
bool offload
188194
)
189195
{
@@ -199,14 +205,14 @@ void cuStenCompute2DXnpFun
199205
dim3 gridDim(pt_cuSten->xGrid, pt_cuSten->yGrid);
200206

201207
// Load the weights
202-
cudaMemPrefetchAsync(pt_cuSten->coe, pt_cuSten->numCoe * sizeof(double), pt_cuSten->deviceNum, pt_cuSten->streams[1]);
208+
cudaMemPrefetchAsync(pt_cuSten->coe, pt_cuSten->numCoe * sizeof(elemType), pt_cuSten->deviceNum, pt_cuSten->streams[1]);
203209

204210
// Preload the first block
205211
cudaStreamSynchronize(pt_cuSten->streams[1]);
206212

207213
// Prefetch the tile data
208-
cudaMemPrefetchAsync(pt_cuSten->dataInput[0], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(double), pt_cuSten->deviceNum, pt_cuSten->streams[1]);
209-
cudaMemPrefetchAsync(pt_cuSten->dataOutput[0], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(double), pt_cuSten->deviceNum, pt_cuSten->streams[1]);
214+
cudaMemPrefetchAsync(pt_cuSten->dataInput[0], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(elemType), pt_cuSten->deviceNum, pt_cuSten->streams[1]);
215+
cudaMemPrefetchAsync(pt_cuSten->dataOutput[0], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(elemType), pt_cuSten->deviceNum, pt_cuSten->streams[1]);
210216

211217
// Record the event
212218
cudaEventRecord(pt_cuSten->events[0], pt_cuSten->streams[1]);
@@ -246,8 +252,8 @@ void cuStenCompute2DXnpFun
246252
// Offload should the user want to
247253
if (offload == 1)
248254
{
249-
cudaMemPrefetchAsync(pt_cuSten->dataOutput[tile], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(double), cudaCpuDeviceId, pt_cuSten->streams[0]);
250-
cudaMemPrefetchAsync(pt_cuSten->dataInput[tile], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(double), cudaCpuDeviceId, pt_cuSten->streams[0]);
255+
cudaMemPrefetchAsync(pt_cuSten->dataOutput[tile], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(elemType), cudaCpuDeviceId, pt_cuSten->streams[0]);
256+
cudaMemPrefetchAsync(pt_cuSten->dataInput[tile], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(elemType), cudaCpuDeviceId, pt_cuSten->streams[0]);
251257
}
252258

253259
// Load the next tile
@@ -257,8 +263,8 @@ void cuStenCompute2DXnpFun
257263
cudaStreamSynchronize(pt_cuSten->streams[1]);
258264

259265
// Prefetch the necessary tiles
260-
cudaMemPrefetchAsync(pt_cuSten->dataOutput[tile + 1], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(double), pt_cuSten->deviceNum, pt_cuSten->streams[1]);
261-
cudaMemPrefetchAsync(pt_cuSten->dataInput[tile + 1], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(double), pt_cuSten->deviceNum, pt_cuSten->streams[1]);
266+
cudaMemPrefetchAsync(pt_cuSten->dataOutput[tile + 1], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(elemType), pt_cuSten->deviceNum, pt_cuSten->streams[1]);
267+
cudaMemPrefetchAsync(pt_cuSten->dataInput[tile + 1], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(elemType), pt_cuSten->deviceNum, pt_cuSten->streams[1]);
262268

263269
// Record the event
264270
cudaEventRecord(pt_cuSten->events[1], pt_cuSten->streams[1]);
@@ -277,6 +283,56 @@ void cuStenCompute2DXnpFun
277283
}
278284
}
279285

286+
// ---------------------------------------------------------------------
287+
// Explicit instantiation
288+
// ---------------------------------------------------------------------
289+
290+
template
291+
__global__ void kernel2DXnpFun<double>
292+
(
293+
double*,
294+
double*,
295+
double*,
296+
double*,
297+
const int,
298+
const int,
299+
const int,
300+
const int,
301+
const int,
302+
const int,
303+
const int
304+
);
305+
306+
template
307+
void cuStenCompute2DXnpFun<double>
308+
(
309+
cuSten_t<double>*,
310+
bool
311+
);
312+
313+
template
314+
__global__ void kernel2DXnpFun<float>
315+
(
316+
float*,
317+
float*,
318+
float*,
319+
float*,
320+
const int,
321+
const int,
322+
const int,
323+
const int,
324+
const int,
325+
const int,
326+
const int
327+
);
328+
329+
template
330+
void cuStenCompute2DXnpFun<float>
331+
(
332+
cuSten_t<float>*,
333+
bool
334+
);
335+
280336
// ---------------------------------------------------------------------
281337
// End of file
282-
// ---------------------------------------------------------------------
338+
// ---------------------------------------------------------------------

cuSten/src/kernels/2d_x_np_kernel.cu

Lines changed: 66 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,12 @@
5252
\param nx Total number of points in the x direction
5353
*/
5454

55+
template <typename elemType>
5556
__global__ void kernel2DXnp
5657
(
57-
double* dataOutput,
58-
double* dataInput,
59-
const double* weights,
58+
elemType* dataOutput,
59+
elemType* dataInput,
60+
const elemType* weights,
6061
const int numSten,
6162
const int numStenLeft,
6263
const int numStenRight,
@@ -69,8 +70,8 @@ __global__ void kernel2DXnp
6970
// Allocate the shared memory
7071
extern __shared__ int memory[];
7172

72-
double* arrayLocal = (double*)&memory;
73-
double* weigthsLocal = (double*)&arrayLocal[nxLocal * nyLocal];
73+
elemType* arrayLocal = (elemType*)&memory;
74+
elemType* weigthsLocal = (elemType*)&arrayLocal[nxLocal * nyLocal];
7475

7576
// Move the weigths into shared memory
7677
#pragma unroll
@@ -80,15 +81,15 @@ __global__ void kernel2DXnp
8081
}
8182

8283
// True matrix index
83-
int globalIdx = blockDim.x * blockIdx.x + threadIdx.x;
84+
int globalIdx = blockDim.x * blockIdx.x + threadIdx.x;
8485
int globalIdy = blockDim.y * blockIdx.y + threadIdx.y;
8586

8687
// Local matrix index
8788
int localIdx = threadIdx.x + numStenLeft;
8889
int localIdy = threadIdx.y;
8990

9091
// Local sum variable
91-
double sum = 0.0;
92+
elemType sum = 0.0;
9293

9394
// Set index for summing stencil
9495
int stenSet;
@@ -187,10 +188,10 @@ __global__ void kernel2DXnp
187188
\param offload Set to HOST to move data back to CPU or DEVICE to keep on the GPU
188189
*/
189190

191+
template <typename elemType>
190192
void cuStenCompute2DXnp
191193
(
192-
cuSten_t* pt_cuSten,
193-
194+
cuSten_t<elemType>* pt_cuSten,
194195
bool offload
195196
)
196197
{
@@ -210,14 +211,14 @@ void cuStenCompute2DXnp
210211
int local_ny = pt_cuSten->BLOCK_Y;
211212

212213
// Load the weights
213-
cudaMemPrefetchAsync(pt_cuSten->weights, pt_cuSten->numSten * sizeof(double), pt_cuSten->deviceNum, pt_cuSten->streams[1]);
214+
cudaMemPrefetchAsync(pt_cuSten->weights, pt_cuSten->numSten * sizeof(elemType), pt_cuSten->deviceNum, pt_cuSten->streams[1]);
214215

215216
// Preload the first block
216217
cudaStreamSynchronize(pt_cuSten->streams[1]);
217218

218219
// Prefetch the tile data
219-
cudaMemPrefetchAsync(pt_cuSten->dataInput[0], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(double), pt_cuSten->deviceNum, pt_cuSten->streams[1]);
220-
cudaMemPrefetchAsync(pt_cuSten->dataOutput[0], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(double), pt_cuSten->deviceNum, pt_cuSten->streams[1]);
220+
cudaMemPrefetchAsync(pt_cuSten->dataInput[0], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(elemType), pt_cuSten->deviceNum, pt_cuSten->streams[1]);
221+
cudaMemPrefetchAsync(pt_cuSten->dataOutput[0], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(elemType), pt_cuSten->deviceNum, pt_cuSten->streams[1]);
221222

222223
// Record the event
223224
cudaEventRecord(pt_cuSten->events[0], pt_cuSten->streams[1]);
@@ -238,8 +239,8 @@ void cuStenCompute2DXnp
238239
// Offload should the user want to
239240
if (offload == 1)
240241
{
241-
cudaMemPrefetchAsync(pt_cuSten->dataOutput[tile], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(double), cudaCpuDeviceId, pt_cuSten->streams[0]);
242-
cudaMemPrefetchAsync(pt_cuSten->dataInput[tile], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(double), cudaCpuDeviceId, pt_cuSten->streams[0]);
242+
cudaMemPrefetchAsync(pt_cuSten->dataOutput[tile], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(elemType), cudaCpuDeviceId, pt_cuSten->streams[0]);
243+
cudaMemPrefetchAsync(pt_cuSten->dataInput[tile], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(elemType), cudaCpuDeviceId, pt_cuSten->streams[0]);
243244
}
244245

245246
// Load the next tile
@@ -249,8 +250,8 @@ void cuStenCompute2DXnp
249250
cudaStreamSynchronize(pt_cuSten->streams[1]);
250251

251252
// Prefetch the necessary tiles
252-
cudaMemPrefetchAsync(pt_cuSten->dataOutput[tile + 1], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(double), pt_cuSten->deviceNum, pt_cuSten->streams[1]);
253-
cudaMemPrefetchAsync(pt_cuSten->dataInput[tile + 1], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(double), pt_cuSten->deviceNum, pt_cuSten->streams[1]);
253+
cudaMemPrefetchAsync(pt_cuSten->dataOutput[tile + 1], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(elemType), pt_cuSten->deviceNum, pt_cuSten->streams[1]);
254+
cudaMemPrefetchAsync(pt_cuSten->dataInput[tile + 1], pt_cuSten->nx * pt_cuSten->nyTile * sizeof(elemType), pt_cuSten->deviceNum, pt_cuSten->streams[1]);
254255

255256
// Record the event
256257
cudaEventRecord(pt_cuSten->events[1], pt_cuSten->streams[1]);
@@ -269,6 +270,54 @@ void cuStenCompute2DXnp
269270
}
270271
}
271272

273+
// ---------------------------------------------------------------------
274+
// Explicit instantiation
275+
// ---------------------------------------------------------------------
276+
277+
template
278+
__global__ void kernel2DXnp<double>
279+
(
280+
double*,
281+
double*,
282+
const double*,
283+
const int,
284+
const int,
285+
const int,
286+
const int,
287+
const int,
288+
const int,
289+
const int
290+
);
291+
292+
template
293+
void cuStenCompute2DXnp<double>
294+
(
295+
cuSten_t<double>*,
296+
bool
297+
);
298+
299+
template
300+
__global__ void kernel2DXnp<float>
301+
(
302+
float*,
303+
float*,
304+
const float*,
305+
const int,
306+
const int,
307+
const int,
308+
const int,
309+
const int,
310+
const int,
311+
const int
312+
);
313+
314+
template
315+
void cuStenCompute2DXnp<float>
316+
(
317+
cuSten_t<float>*,
318+
bool
319+
);
320+
272321
// ---------------------------------------------------------------------
273322
// End of file
274-
// ---------------------------------------------------------------------
323+
// ---------------------------------------------------------------------

0 commit comments

Comments (0)