Skip to content

Commit f55bc9a

Browse files
authored
[src] Build fixes when CUDA version is less than 9.1 (#3901)
1 parent a4f2f34 commit f55bc9a

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

src/cudafeat/feature-online-batched-ivector-cuda.cc

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ BatchedIvectorExtractorCuda::BatchedIvectorExtractorCuda(
2626
chunk_size_(chunk_size),
2727
max_lanes_(num_lanes),
2828
num_channels_(num_channels) {
29+
#if CUDA_VERSION < 9010
30+
// some components require newer cuda versions. If you see this error
31+
// upgrade to a more recent CUDA version.
32+
KALDI_ERR << "BatchedIvectorExtractorCuda requires CUDA 9.1 or newer.";
33+
#endif
2934
info_.Init(config);
3035
Read(config);
3136

@@ -290,6 +295,7 @@ void BatchedIvectorExtractorCuda::ComputeIvectorStats(
290295
posteriors_.Stride() * chunk_size_, gamma_.Data(),
291296
num_gauss_, info_.posterior_scale, lanes, num_lanes);
292297

298+
#if CUDA_VERSION >= 9010
293299
int32_t m = feat_dim_;
294300
int32_t n = num_gauss_;
295301
int32_t k = chunk_size_;
@@ -310,6 +316,7 @@ void BatchedIvectorExtractorCuda::ComputeIvectorStats(
310316
GetCublasHandle(), CUBLAS_OP_N, CUBLAS_OP_T, m, n, k, &alpha, A,
311317
CUDA_R_32F, lda, strideA, B, CUDA_R_32F, ldb, strideB, &beta, C,
312318
CUDA_R_32F, ldc, strideC, num_lanes, CUDA_R_32F, CUBLAS_GEMM_DEFAULT))
319+
#endif
313320

314321
apply_and_update_stash(
315322
num_gauss_, feat_dim_, gamma_.Data(), gamma_stash_.Data(), num_gauss_,
@@ -354,9 +361,7 @@ void BatchedIvectorExtractorCuda::ComputeIvectorsFromStats(
354361
ivector_dim_, lanes, num_lanes);
355362

356363
#if CUDA_VERSION >= 9010
357-
358364
int nrhs = 1;
359-
360365
// perform factorization in batched
361366
CUSOLVER_SAFE_CALL(cusolverDnSpotrfBatched(
362367
GetCusolverDnHandle(), CUBLAS_FILL_MODE_LOWER, ivector_dim_, quad_array_,
@@ -367,18 +372,10 @@ void BatchedIvectorExtractorCuda::ComputeIvectorsFromStats(
367372
GetCusolverDnHandle(), CUBLAS_FILL_MODE_LOWER, ivector_dim_, nrhs,
368373
quad_array_, ivector_dim_, ivec_array_, ivector_dim_, d_infoArray_,
369374
num_lanes));
375+
#endif
370376

371377
// cusolver solves in place. Ivectors are now in linear_
372378

373-
#else
374-
// We could make a fallback if necessary. This would likely just loop
375-
// over each matrix and call Invert not batched. This would be very slow and
376-
// throwing an error is probably better force people to use a more recent
377-
// version of CUDA.
378-
KALDI_ERR << "Online Ivectors in CUDA is not supported by your CUDA version. "
379-
<< "Upgrade to CUDA 9.1 or later";
380-
#endif
381-
382379
// Create a submatrix which points to the first element of each ivector
383380
CuSubMatrix<BaseFloat> ivector0(linear_.Data(), num_lanes, 1, ivector_dim_);
384381
// remove prior

0 commit comments

Comments
 (0)