// ---------------------------------------------------------------------
// Function pointer definition
// ---------------------------------------------------------------------

/*! typedef elemType (*devArg1X)(elemType*, elemType*, int);
	\brief The function pointer containing the user defined function to be applied <br>
	Input 1: The pointer to input data to the function <br>
	Input 2: The pointer to the coefficients provided by the user <br>
	Input 3: The current index position (centre of the stencil to be applied)
*/
/*! \struct templateFunc
	\brief C++03-style "template typedef" wrapper: exposes the user device
	function-pointer type devArg1X parameterised on the element type, so the
	kernels below can cast the raw function pointer they receive to the
	correct precision (double or float).
*/
template <typename elemType>
struct templateFunc
{
	// Signature: (pointer to input data, pointer to coefficients, centre index) -> stencil result
	typedef elemType (*devArg1X)(elemType*, elemType*, int);
};

4953// ---------------------------------------------------------------------
5054// Kernel Definition
@@ -65,12 +69,13 @@ typedef double (*devArg1X)(double*, double*, int);
6569 \param nx Total number of points in the x direction
6670*/
6771
72+ template <typename elemType>
6873__global__ void kernel2DXnpFun
6974(
70- double * dataOutput,
71- double * dataInput,
72- double * coe,
73- double * func,
75+ elemType * dataOutput,
76+ elemType * dataInput,
77+ elemType * coe,
78+ elemType * func,
7479 const int numStenLeft,
7580 const int numStenRight,
7681 const int numCoe,
@@ -83,8 +88,8 @@ __global__ void kernel2DXnpFun
8388 // Allocate the shared memory
8489 extern __shared__ int memory[];
8590
86- double * arrayLocal = (double *)&memory;
87- double * coeLocal = (double *)&arrayLocal[nxLocal * nyLocal];
91+ elemType * arrayLocal = (elemType *)&memory;
92+ elemType * coeLocal = (elemType *)&arrayLocal[nxLocal * nyLocal];
8893
8994 // Move the weigths into shared memory
9095 #pragma unroll
@@ -94,7 +99,7 @@ __global__ void kernel2DXnpFun
9499 }
95100
96101 // True matrix index
97- int globalIdx = blockDim .x * blockIdx .x + threadIdx .x ;
102+ int globalIdx = blockDim .x * blockIdx .x + threadIdx .x ;
98103 int globalIdy = blockDim .y * blockIdx .y + threadIdx .y ;
99104
100105 // Local matrix index
@@ -125,7 +130,7 @@ __global__ void kernel2DXnpFun
125130
126131 stenSet = localIdy * nxLocal + localIdx;
127132
128- dataOutput[globalIdy * nx + globalIdx] = ((devArg1X)func)(arrayLocal, coeLocal, stenSet);
133+ dataOutput[globalIdy * nx + globalIdx] = ((typename templateFunc<elemType>:: devArg1X)func)(arrayLocal, coeLocal, stenSet);
129134 }
130135
131136 // Set all left boundary blocks
@@ -145,7 +150,7 @@ __global__ void kernel2DXnpFun
145150 {
146151 stenSet = localIdy * nxLocal + threadIdx .x ;
147152
148- dataOutput[globalIdy * nx + globalIdx] = ((devArg1X)func)(arrayLocal, coeLocal, stenSet);
153+ dataOutput[globalIdy * nx + globalIdx] = ((typename templateFunc<elemType>:: devArg1X)func)(arrayLocal, coeLocal, stenSet);
149154 }
150155 }
151156
@@ -166,7 +171,7 @@ __global__ void kernel2DXnpFun
166171 {
167172 stenSet = localIdy * nxLocal + localIdx;
168173
169- dataOutput[globalIdy * nx + globalIdx] = ((devArg1X)func)(arrayLocal, coeLocal, stenSet);
174+ dataOutput[globalIdy * nx + globalIdx] = ((typename templateFunc<elemType>:: devArg1X)func)(arrayLocal, coeLocal, stenSet);
170175 }
171176 }
172177}
@@ -181,9 +186,10 @@ __global__ void kernel2DXnpFun
181186 \param offload Set to HOST to move data back to CPU or DEVICE to keep on the GPU
182187*/
183188
189+ template <typename elemType>
184190void cuStenCompute2DXnpFun
185191(
186- cuSten_t* pt_cuSten,
192+ cuSten_t<elemType> * pt_cuSten,
187193 bool offload
188194)
189195{
@@ -199,14 +205,14 @@ void cuStenCompute2DXnpFun
199205 dim3 gridDim (pt_cuSten->xGrid , pt_cuSten->yGrid );
200206
201207 // Load the weights
202- cudaMemPrefetchAsync (pt_cuSten->coe , pt_cuSten->numCoe * sizeof (double ), pt_cuSten->deviceNum , pt_cuSten->streams [1 ]);
208+ cudaMemPrefetchAsync (pt_cuSten->coe , pt_cuSten->numCoe * sizeof (elemType ), pt_cuSten->deviceNum , pt_cuSten->streams [1 ]);
203209
204210 // Preload the first block
205211 cudaStreamSynchronize (pt_cuSten->streams [1 ]);
206212
207213 // Prefetch the tile data
208- cudaMemPrefetchAsync (pt_cuSten->dataInput [0 ], pt_cuSten->nx * pt_cuSten->nyTile * sizeof (double ), pt_cuSten->deviceNum , pt_cuSten->streams [1 ]);
209- cudaMemPrefetchAsync (pt_cuSten->dataOutput [0 ], pt_cuSten->nx * pt_cuSten->nyTile * sizeof (double ), pt_cuSten->deviceNum , pt_cuSten->streams [1 ]);
214+ cudaMemPrefetchAsync (pt_cuSten->dataInput [0 ], pt_cuSten->nx * pt_cuSten->nyTile * sizeof (elemType ), pt_cuSten->deviceNum , pt_cuSten->streams [1 ]);
215+ cudaMemPrefetchAsync (pt_cuSten->dataOutput [0 ], pt_cuSten->nx * pt_cuSten->nyTile * sizeof (elemType ), pt_cuSten->deviceNum , pt_cuSten->streams [1 ]);
210216
211217 // Record the event
212218 cudaEventRecord (pt_cuSten->events [0 ], pt_cuSten->streams [1 ]);
@@ -246,8 +252,8 @@ void cuStenCompute2DXnpFun
246252 // Offload should the user want to
247253 if (offload == 1 )
248254 {
249- cudaMemPrefetchAsync (pt_cuSten->dataOutput [tile], pt_cuSten->nx * pt_cuSten->nyTile * sizeof (double ), cudaCpuDeviceId, pt_cuSten->streams [0 ]);
250- cudaMemPrefetchAsync (pt_cuSten->dataInput [tile], pt_cuSten->nx * pt_cuSten->nyTile * sizeof (double ), cudaCpuDeviceId, pt_cuSten->streams [0 ]);
255+ cudaMemPrefetchAsync (pt_cuSten->dataOutput [tile], pt_cuSten->nx * pt_cuSten->nyTile * sizeof (elemType ), cudaCpuDeviceId, pt_cuSten->streams [0 ]);
256+ cudaMemPrefetchAsync (pt_cuSten->dataInput [tile], pt_cuSten->nx * pt_cuSten->nyTile * sizeof (elemType ), cudaCpuDeviceId, pt_cuSten->streams [0 ]);
251257 }
252258
253259 // Load the next tile
@@ -257,8 +263,8 @@ void cuStenCompute2DXnpFun
257263 cudaStreamSynchronize (pt_cuSten->streams [1 ]);
258264
259265 // Prefetch the necessary tiles
260- cudaMemPrefetchAsync (pt_cuSten->dataOutput [tile + 1 ], pt_cuSten->nx * pt_cuSten->nyTile * sizeof (double ), pt_cuSten->deviceNum , pt_cuSten->streams [1 ]);
261- cudaMemPrefetchAsync (pt_cuSten->dataInput [tile + 1 ], pt_cuSten->nx * pt_cuSten->nyTile * sizeof (double ), pt_cuSten->deviceNum , pt_cuSten->streams [1 ]);
266+ cudaMemPrefetchAsync (pt_cuSten->dataOutput [tile + 1 ], pt_cuSten->nx * pt_cuSten->nyTile * sizeof (elemType ), pt_cuSten->deviceNum , pt_cuSten->streams [1 ]);
267+ cudaMemPrefetchAsync (pt_cuSten->dataInput [tile + 1 ], pt_cuSten->nx * pt_cuSten->nyTile * sizeof (elemType ), pt_cuSten->deviceNum , pt_cuSten->streams [1 ]);
262268
263269 // Record the event
264270 cudaEventRecord (pt_cuSten->events [1 ], pt_cuSten->streams [1 ]);
@@ -277,6 +283,56 @@ void cuStenCompute2DXnpFun
277283 }
278284}
279285
// ---------------------------------------------------------------------
// Explicit instantiation
// ---------------------------------------------------------------------

// The templated kernel and host wrapper are defined in this translation
// unit, so instantiate them here for the two supported element types
// (double and float) to keep their definitions out of the header.

template
__global__ void kernel2DXnpFun<double>
(
	double*,
	double*,
	double*,
	double*,
	const int,
	const int,
	const int,
	const int,
	const int,
	const int,
	const int
);

template
void cuStenCompute2DXnpFun<double>
(
	cuSten_t<double>*,
	bool
);

template
__global__ void kernel2DXnpFun<float>
(
	float*,
	float*,
	float*,
	float*,
	const int,
	const int,
	const int,
	const int,
	const int,
	const int,
	const int
);

template
void cuStenCompute2DXnpFun<float>
(
	cuSten_t<float>*,
	bool
);

280336// ---------------------------------------------------------------------
281337// End of file
282- // ---------------------------------------------------------------------
338+ // ---------------------------------------------------------------------
0 commit comments