Skip to content

Commit d0937dc

Browse files
authored
[src] Initialize CuDevice member; cleanup cu-common.{h,cc} (#4567)
* Initialize CuDevice::curand_handle_ to NULL in constructor. * Remove stray include guards from cu-common.cc. * Rearrange #include's per coding guidelines.
1 parent 1714f6c commit d0937dc

File tree

3 files changed

+17
-21
lines changed

3 files changed

+17
-21
lines changed

src/cudamatrix/cu-common.cc

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,18 @@
1818
// See the Apache 2 License for the specific language governing permissions and
1919
// limitations under the License.
2020

21-
#ifndef KALDI_CUDAMATRIX_COMMON_H_
22-
#define KALDI_CUDAMATRIX_COMMON_H_
21+
#if HAVE_CUDA
2322

24-
// This file contains some #includes, forward declarations
25-
// and typedefs that are needed by all the main header
26-
// files in this directory.
27-
#include <mutex>
28-
#include "base/kaldi-common.h"
29-
#include "matrix/kaldi-blas.h"
30-
#include "cudamatrix/cu-device.h"
3123
#include "cudamatrix/cu-common.h"
24+
25+
#include <cuda.h>
26+
27+
#include "base/kaldi-common.h"
3228
#include "cudamatrix/cu-matrixdim.h"
29+
#include "matrix/kaldi-blas.h"
3330

3431
namespace kaldi {
3532

36-
#if HAVE_CUDA == 1
37-
3833
#ifdef USE_NVTX
3934
NvtxTracer::NvtxTracer(const char* name) {
4035
const uint32_t colors[] = { 0xff00ff00, 0xff0000ff, 0xffffff00, 0xffff00ff, 0xff00ffff, 0xffff0000, 0xffffffff };
@@ -91,6 +86,7 @@ void GetBlockSizesForSimpleMatrixOperation(int32 num_rows,
9186
}
9287

9388
const char* cublasGetStatusString(cublasStatus_t status) {
89+
// Defined in CUDA include file: cublas.h or cublas_api.h
9490
switch(status) {
9591
case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
9692
case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
@@ -108,6 +104,7 @@ const char* cublasGetStatusString(cublasStatus_t status) {
108104

109105
const char* cusparseGetStatusString(cusparseStatus_t status) {
110106
// detail info come from http://docs.nvidia.com/cuda/cusparse/index.html#cusparsestatust
107+
// Defined in CUDA include file: cusparse.h
111108
switch(status) {
112109
case CUSPARSE_STATUS_SUCCESS: return "CUSPARSE_STATUS_SUCCESS";
113110
case CUSPARSE_STATUS_NOT_INITIALIZED: return "CUSPARSE_STATUS_NOT_INITIALIZED";
@@ -129,6 +126,7 @@ const char* cusparseGetStatusString(cusparseStatus_t status) {
129126

130127
const char* curandGetStatusString(curandStatus_t status) {
131128
// detail info come from http://docs.nvidia.com/cuda/curand/group__HOST.html
129+
// Defined in CUDA include file: curand.h
132130
switch(status) {
133131
case CURAND_STATUS_SUCCESS: return "CURAND_STATUS_SUCCESS";
134132
case CURAND_STATUS_VERSION_MISMATCH: return "CURAND_STATUS_VERSION_MISMATCH";
@@ -146,9 +144,7 @@ const char* curandGetStatusString(curandStatus_t status) {
146144
}
147145
return "CURAND_STATUS_UNKNOWN_ERROR";
148146
}
149-
#endif
150-
151-
} // namespace
152147

148+
} // namespace kaldi
153149

154-
#endif // KALDI_CUDAMATRIX_COMMON_H_
150+
#endif // HAVE_CUDA

src/cudamatrix/cu-common.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
#include "cudamatrix/cu-matrixdim.h" // for CU1DBLOCK and CU2DBLOCK
3030
#include "matrix/matrix-common.h"
3131

32-
#if HAVE_CUDA == 1
32+
#if HAVE_CUDA
33+
3334
#include <cublas_v2.h>
3435
#include <cuda_runtime_api.h>
3536
#include <curand.h>
@@ -136,12 +137,11 @@ const char* cusparseGetStatusString(cusparseStatus_t status);
136137

137138
/** This is analogous to the CUDA function cudaGetErrorString(). **/
138139
const char* curandGetStatusString(curandStatus_t status);
139-
}
140140

141-
#else
142-
namespace kaldi {
141+
} // namespace kaldi
142+
143+
#else // HAVE CUDA
143144
#define NVTX_RANGE(name)
144-
};
145145
#endif // HAVE_CUDA
146146

147147
namespace kaldi {
@@ -160,7 +160,6 @@ template<typename Real> class CuSparseMatrix;
160160

161161
template<typename Real> class CuBlockMatrix; // this has no non-CU counterpart.
162162

163-
164163
} // namespace kaldi
165164

166165
#endif // KALDI_CUDAMATRIX_CU_COMMON_H_

src/cudamatrix/cu-device.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,7 @@ CuDevice::CuDevice():
598598
device_id_copy_(-1),
599599
cublas_handle_(NULL),
600600
cusparse_handle_(NULL),
601+
curand_handle_(NULL),
601602
cusolverdn_handle_(NULL) {
602603
}
603604

0 commit comments

Comments
 (0)