Skip to content

Commit 77e7b47

Browse files
committed
reduce header dependency
1 parent 985c069 commit 77e7b47

File tree

10 files changed

+79
-48
lines changed

10 files changed

+79
-48
lines changed

source/module_hamilt_lcao/module_gint/kernels/cuda/cuda_tools.cu

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,6 @@
44

55
#include "cuda_tools.cuh"
66

7-
cudaError_t check(cudaError_t result, const char *const func, const char *const file, const int line)
8-
{
9-
if (result != cudaSuccess)
10-
{
11-
fprintf(stderr, "CUDA Runtime Error at %s:%d code=%s \"%s\" \n", file, line, cudaGetErrorString(result), func);
12-
exit(EXIT_FAILURE);
13-
}
14-
return result;
15-
}
16-
cudaError_t __checkCudaLastError(const char *file, const int line)
17-
{
18-
cudaError_t result = cudaGetLastError();
19-
if (result != cudaSuccess)
20-
{
21-
fprintf(stderr, "%s(%i) : getLastCudaError():%s\n", file, line, cudaGetErrorString(result));
22-
assert(result == cudaSuccess);
23-
}
24-
return result;
25-
}
26-
277
void dump_cuda_array_to_file(const double* cuda_array,
288
int width,
299
int hight,

source/module_hamilt_lcao/module_gint/kernels/cuda/cuda_tools.cuh

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,32 @@
99
#include <iostream>
1010
#include <sstream>
1111

12-
#define checkCuda(val) check(val, #val, __FILE__, __LINE__)
13-
#define checkCudaLastError() __checkCudaLastError(__FILE__, __LINE__)
12+
#define checkCuda(val) check((val), #val, __FILE__, __LINE__)
13+
#define checkCudaLastError() __getLastCudaError(__FILE__, __LINE__)
1414

15-
cudaError_t check(cudaError_t result, const char *const func, const char *const file, const int line);
16-
cudaError_t __checkCudaLastError(const char *file, const int line);
15+
inline void check(cudaError_t result, char const *const func, const char *const file,
16+
int const line) {
17+
if (result) {
18+
fprintf(stderr, "CUDA error at %s:%d code=%d \"%s\" \n", file, line,
19+
static_cast<unsigned int>(result), func);
20+
exit(EXIT_FAILURE);
21+
}
22+
}
23+
24+
inline void __getLastCudaError(const char *file,
25+
const int line)
26+
{
27+
cudaError_t err = cudaGetLastError();
28+
29+
if (cudaSuccess != err) {
30+
fprintf(stderr,
31+
"%s(%i) : getLastCudaError() CUDA error :"
32+
" (%d) %s.\n",
33+
file, line, static_cast<int>(err),
34+
cudaGetErrorString(err));
35+
exit(EXIT_FAILURE);
36+
}
37+
}
1738

1839
static inline int ceildiv(int x, int y)
1940
{

source/module_hamilt_lcao/module_gint/temp_gint/gint_helper.h

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,11 @@
77

88
namespace ModuleGint
99
{
10-
11-
template <typename T>
12-
std::shared_ptr<const T> toConstSharedPtr(std::shared_ptr<T> ptr) {
13-
return std::static_pointer_cast<const T>(ptr);
14-
}
15-
16-
1710
inline int index3Dto1D(const int id_x, const int id_y, const int id_z,
1811
const int dim_x, const int dim_y, const int dim_z)
1912
{
2013
return id_z + id_y * dim_z + id_x * dim_y * dim_z;
21-
};
14+
}
2215

2316
inline Vec3i index1Dto3D(const int index_1d,
2417
const int dim_x, const int dim_y, const int dim_z)
@@ -27,7 +20,7 @@ inline Vec3i index1Dto3D(const int index_1d,
2720
int id_y = (index_1d - id_x * dim_y * dim_z) / dim_z;
2821
int id_z = index_1d % dim_z;
2922
return Vec3i(id_x, id_y, id_z);
30-
};
23+
}
3124

3225
// if exponent is an integer between 0 and 5 (the most common cases in gint) and
3326
// and exp is a variable that cannot be determined at compile time (which means the compiler cannot optimize the code),
@@ -52,17 +45,17 @@ inline double pow_int(const double base, const int exp)
5245
double result = std::pow(base, exp);
5346
return result;
5447
}
55-
};
48+
}
5649

5750
inline int floor_div(const int a, const int b)
5851
{
5952
// a ^ b < 0 means a and b have different signs
6053
return a / b - (a % b != 0 && (a ^ b) < 0);
61-
};
54+
}
6255

6356
inline int ceil_div(const int a, const int b)
6457
{
6558
return a / b + (a % b != 0 && (a ^ b) > 0);
66-
};
59+
}
6760

6861
}

source/module_hamilt_lcao/module_gint/temp_gint/kernel/cuda_mem_wrapper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#pragma once
22
#include <cuda_runtime.h>
33
#include "module_base/tool_quit.h"
4-
#include "module_hamilt_lcao/module_gint/kernels/cuda/cuda_tools.cuh"
4+
#include "gint_helper.cuh"
55

66
template <typename T>
77
class CudaMemWrapper

source/module_hamilt_lcao/module_gint/temp_gint/kernel/gemm_nn_vbatch.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include <cuda_runtime.h>
77
#include <stdio.h> // for fprintf and stderr
88

9-
#include "module_hamilt_lcao/module_gint/kernels/cuda/cuda_tools.cuh"
9+
#include "gint_helper.cuh"
1010
#include <functional>
1111

1212

@@ -395,8 +395,8 @@ void vbatched_gemm_nn_impl(int max_m,
395395
for (int i = 0; i < batchCount; i += max_batch_count)
396396
{
397397
const int ibatch = min(max_batch_count, batchCount - i);
398-
dim3 dimGrid(ceildiv(max_n, BLK_M),
399-
ceildiv(max_m, BLK_N),
398+
dim3 dimGrid(ceil_div(max_n, BLK_M),
399+
ceil_div(max_m, BLK_N),
400400
ibatch);
401401
const T* alpha_tmp = nullptr;
402402
if (alpha != nullptr)

source/module_hamilt_lcao/module_gint/temp_gint/kernel/gemm_tn_vbatch.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include <cuda_runtime.h>
77
#include <stdio.h> // for fprintf and stderr
88

9-
#include "module_hamilt_lcao/module_gint/kernels/cuda/cuda_tools.cuh"
9+
#include "gint_helper.cuh"
1010
#include <functional>
1111

1212

@@ -420,8 +420,8 @@ void vbatched_gemm_tn_impl(int max_m,
420420
for (int i = 0; i < batchCount; i += max_batch_count)
421421
{
422422
const int ibatch = min(max_batch_count, batchCount - i);
423-
dim3 dimGrid(ceildiv(max_n, BLK_M),
424-
ceildiv(max_m, BLK_N),
423+
dim3 dimGrid(ceil_div(max_n, BLK_M),
424+
ceil_div(max_m, BLK_N),
425425
ibatch);
426426
const T* alpha_tmp = nullptr;
427427
if (alpha != nullptr)

source/module_hamilt_lcao/module_gint/temp_gint/kernel/gint_gpu_vars.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include "module_cell/unitcell.h"
77
#include "module_cell/atom_spec.h"
88
#include "module_hamilt_lcao/module_gint/temp_gint/biggrid_info.h"
9-
#include "module_hamilt_lcao/module_gint/kernels/cuda/cuda_tools.cuh"
9+
#include "gint_helper.cuh"
1010
#include "module_hamilt_lcao/module_gint/kernels/cuda/gemm_selector.cuh"
1111

1212
namespace ModuleGint

source/module_hamilt_lcao/module_gint/temp_gint/kernel/gint_helper.cuh

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#pragma once
2+
#include <cstdio>
23

34
// if exponent is an integer between 0 and 5 (the most common cases in gint) and
45
// and exp is a variable that cannot be determined at compile time (which means the compiler cannot optimize the code),
@@ -24,7 +25,7 @@ __forceinline__ __device__ T pow_int(const T base, const int exp)
2425
double result = std::pow(base, exp);
2526
return result;
2627
}
27-
};
28+
}
2829

2930
template<typename T>
3031
__forceinline__ __device__ T warpReduceSum(T val)
@@ -35,4 +36,40 @@ __forceinline__ __device__ T warpReduceSum(T val)
3536
val += __shfl_xor_sync(0xffffffff, val, 2, 32);
3637
val += __shfl_xor_sync(0xffffffff, val, 1, 32);
3738
return val;
38-
}
39+
}
40+
41+
inline int ceil_div(const int a, const int b)
42+
{
43+
return a / b + (a % b != 0 && (a ^ b) > 0);
44+
}
45+
46+
inline void check(cudaError_t result, char const *const func, const char *const file,
47+
int const line) {
48+
if (result) {
49+
fprintf(stderr, "CUDA error at %s:%d code=%d \"%s\" \n", file, line,
50+
static_cast<unsigned int>(result), func);
51+
exit(EXIT_FAILURE);
52+
}
53+
}
54+
55+
inline void __getLastCudaError(const char *file,
56+
const int line)
57+
{
58+
cudaError_t err = cudaGetLastError();
59+
60+
if (cudaSuccess != err) {
61+
fprintf(stderr,
62+
"%s(%i) : getLastCudaError() CUDA error :"
63+
" (%d) %s.\n",
64+
file, line, static_cast<int>(err),
65+
cudaGetErrorString(err));
66+
exit(EXIT_FAILURE);
67+
}
68+
}
69+
70+
// This will output the proper CUDA error strings in the event
71+
// that a CUDA host call returns an error
72+
#define checkCuda(val) check((val), #val, __FILE__, __LINE__)
73+
74+
// This will output the proper error string when calling cudaGetLastError
75+
#define checkCudaLastError() __getLastCudaError(__FILE__, __LINE__)

source/module_hamilt_lcao/module_gint/temp_gint/kernel/phi_operator_gpu.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include <cuda_runtime.h>
44

55
#include "module_hamilt_lcao/module_gint/temp_gint/batch_biggrid.h"
6-
#include "module_hamilt_lcao/module_gint/kernels/cuda/cuda_tools.cuh"
6+
#include "gint_helper.cuh"
77
#include "gint_gpu_vars.h"
88
#include "cuda_mem_wrapper.h"
99

source/module_hamilt_lcao/module_gint/temp_gint/kernel/set_const_mem.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include "set_const_mem.cuh"
2-
#include "module_hamilt_lcao/module_gint/kernels/cuda/cuda_tools.cuh"
2+
#include "gint_helper.cuh"
33

44
__constant__ double ylmcoe_d[100];
55

0 commit comments

Comments
 (0)