1818// See the Apache 2 License for the specific language governing permissions and
1919// limitations under the License.
2020
21- #ifndef KALDI_CUDAMATRIX_COMMON_H_
22- #define KALDI_CUDAMATRIX_COMMON_H_
21+ #if HAVE_CUDA
2322
24- // This file contains some #includes, forward declarations
25- // and typedefs that are needed by all the main header
26- // files in this directory.
27- #include < mutex>
28- #include " base/kaldi-common.h"
29- #include " matrix/kaldi-blas.h"
30- #include " cudamatrix/cu-device.h"
3123#include " cudamatrix/cu-common.h"
24+
25+ #include < cuda.h>
26+
27+ #include " base/kaldi-common.h"
3228#include " cudamatrix/cu-matrixdim.h"
29+ #include " matrix/kaldi-blas.h"
3330
3431namespace kaldi {
3532
36- #if HAVE_CUDA == 1
37-
3833#ifdef USE_NVTX
3934NvtxTracer::NvtxTracer (const char * name) {
4035 const uint32_t colors[] = { 0xff00ff00 , 0xff0000ff , 0xffffff00 , 0xffff00ff , 0xff00ffff , 0xffff0000 , 0xffffffff };
@@ -91,6 +86,7 @@ void GetBlockSizesForSimpleMatrixOperation(int32 num_rows,
9186}
9287
9388const char * cublasGetStatusString (cublasStatus_t status) {
89+ // Defined in CUDA include file: cublas.h or cublas_api.h
9490 switch (status) {
9591 case CUBLAS_STATUS_SUCCESS: return " CUBLAS_STATUS_SUCCESS" ;
9692 case CUBLAS_STATUS_NOT_INITIALIZED: return " CUBLAS_STATUS_NOT_INITIALIZED" ;
@@ -108,6 +104,7 @@ const char* cublasGetStatusString(cublasStatus_t status) {
108104
109105const char * cusparseGetStatusString (cusparseStatus_t status) {
110106 // detail info come from http://docs.nvidia.com/cuda/cusparse/index.html#cusparsestatust
107+ // Defined in CUDA include file: cusparse.h
111108 switch (status) {
112109 case CUSPARSE_STATUS_SUCCESS: return " CUSPARSE_STATUS_SUCCESS" ;
113110 case CUSPARSE_STATUS_NOT_INITIALIZED: return " CUSPARSE_STATUS_NOT_INITIALIZED" ;
@@ -129,6 +126,7 @@ const char* cusparseGetStatusString(cusparseStatus_t status) {
129126
130127const char * curandGetStatusString (curandStatus_t status) {
131128 // detail info come from http://docs.nvidia.com/cuda/curand/group__HOST.html
129+ // Defined in CUDA include file: curand.h
132130 switch (status) {
133131 case CURAND_STATUS_SUCCESS: return " CURAND_STATUS_SUCCESS" ;
134132 case CURAND_STATUS_VERSION_MISMATCH: return " CURAND_STATUS_VERSION_MISMATCH" ;
@@ -146,9 +144,7 @@ const char* curandGetStatusString(curandStatus_t status) {
146144 }
147145 return " CURAND_STATUS_UNKNOWN_ERROR" ;
148146}
149- #endif
150-
151- } // namespace
152147
148+ } // namespace kaldi
153149
154- #endif // KALDI_CUDAMATRIX_COMMON_H_
150+ #endif // HAVE_CUDA
0 commit comments