@@ -26,6 +26,11 @@ BatchedIvectorExtractorCuda::BatchedIvectorExtractorCuda(
2626 chunk_size_ (chunk_size),
2727 max_lanes_(num_lanes),
2828 num_channels_(num_channels) {
29+ #if CUDA_VERSION < 9010
30+ // Some components require newer CUDA versions. If you see this error,
31+ // upgrade to a more recent CUDA version.
32+ KALDI_ERR << " BatchedIvectorExtractorCuda requires CUDA 9.1 or newer." ;
33+ #endif
2934 info_.Init (config);
3035 Read (config);
3136
@@ -290,6 +295,7 @@ void BatchedIvectorExtractorCuda::ComputeIvectorStats(
290295 posteriors_.Stride () * chunk_size_, gamma_.Data (),
291296 num_gauss_, info_.posterior_scale , lanes, num_lanes);
292297
298+ #if CUDA_VERSION >= 9010
293299 int32_t m = feat_dim_;
294300 int32_t n = num_gauss_;
295301 int32_t k = chunk_size_;
@@ -310,6 +316,7 @@ void BatchedIvectorExtractorCuda::ComputeIvectorStats(
310316 GetCublasHandle (), CUBLAS_OP_N, CUBLAS_OP_T, m, n, k, &alpha, A,
311317 CUDA_R_32F, lda, strideA, B, CUDA_R_32F, ldb, strideB, &beta, C,
312318 CUDA_R_32F, ldc, strideC, num_lanes, CUDA_R_32F, CUBLAS_GEMM_DEFAULT))
319+ #endif
313320
314321 apply_and_update_stash (
315322 num_gauss_, feat_dim_, gamma_.Data (), gamma_stash_.Data (), num_gauss_,
@@ -354,9 +361,7 @@ void BatchedIvectorExtractorCuda::ComputeIvectorsFromStats(
354361 ivector_dim_, lanes, num_lanes);
355362
356363#if CUDA_VERSION >= 9010
357-
358364 int nrhs = 1 ;
359-
360365 // perform the factorization in batched mode
361366 CUSOLVER_SAFE_CALL (cusolverDnSpotrfBatched (
362367 GetCusolverDnHandle (), CUBLAS_FILL_MODE_LOWER, ivector_dim_, quad_array_,
@@ -367,18 +372,10 @@ void BatchedIvectorExtractorCuda::ComputeIvectorsFromStats(
367372 GetCusolverDnHandle (), CUBLAS_FILL_MODE_LOWER, ivector_dim_, nrhs,
368373 quad_array_, ivector_dim_, ivec_array_, ivector_dim_, d_infoArray_,
369374 num_lanes));
375+ #endif
370376
371377 // cusolver solves in place. Ivectors are now in linear_
372378
373- #else
374- // We could make a fallback if necessary. This would likely just loop
375- // over each matrix and call Invert not batched. This would be very slow and
376- // throwing an error is probably better force people to use a more recent
377- // version of CUDA.
378- KALDI_ERR << " Online Ivectors in CUDA is not supported by your CUDA version. "
379- << " Upgrade to CUDA 9.1 or later" ;
380- #endif
381-
382379 // Create a submatrix which points to the first element of each ivector
383380 CuSubMatrix<BaseFloat> ivector0 (linear_.Data (), num_lanes, 1 , ivector_dim_);
384381 // remove prior
0 commit comments