Skip to content

Commit df1e911

Browse files
authored
[src] Fix: feature dim might differ from lda dim in CUDA ivector extractor (#4689)
* Feature dim might differ from lda dim in ivector extractor * Cosmetics renaming
1 parent d366a93 commit df1e911

File tree

3 files changed

+27
-19
lines changed

3 files changed

+27
-19
lines changed

src/cudafeat/feature-online-batched-ivector-cuda.cc

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@
2020

2121
namespace kaldi {
2222
BatchedIvectorExtractorCuda::BatchedIvectorExtractorCuda(
23-
const OnlineIvectorExtractionConfig &config, int32_t chunk_size,
23+
const OnlineIvectorExtractionConfig &config,
24+
int32_t feat_dim, int32_t chunk_size,
2425
int32_t num_lanes, int32_t num_channels)
2526
: cmvn_(NULL),
27+
feat_dim_(feat_dim),
2628
chunk_size_(chunk_size),
2729
max_lanes_(num_lanes),
2830
num_channels_(num_channels) {
@@ -57,13 +59,14 @@ BatchedIvectorExtractorCuda::BatchedIvectorExtractorCuda(
5759
kUndefined);
5860
spliced_feats_.Resize(num_lanes * chunk_size, feat_dim_ * size, kUndefined);
5961
tmp_feats_.Resize(num_lanes * chunk_size, feat_dim_, kUndefined);
62+
lda_feats_.Resize(num_lanes * chunk_size, lda_dim_, kUndefined);
6063
posteriors_.Resize(num_lanes * chunk_size, num_gauss_, kUndefined);
6164

6265
gamma_.Resize(num_lanes * num_gauss_, kUndefined);
6366
gamma_stash_.Resize(num_channels * num_gauss_, kUndefined);
6467

65-
X_.Resize(num_lanes * num_gauss_, feat_dim_, kUndefined);
66-
X_stash_.Resize(num_channels * num_gauss_, feat_dim_, kUndefined);
68+
X_.Resize(num_lanes * num_gauss_, lda_dim_, kUndefined);
69+
X_stash_.Resize(num_channels * num_gauss_, lda_dim_, kUndefined);
6770

6871
linear_.Resize(num_lanes * ivector_dim_);
6972
sp_quadratic_.Resize(num_lanes, ivector_dim_ * (ivector_dim_ + 1) / 2);
@@ -142,14 +145,14 @@ void BatchedIvectorExtractorCuda::Read(
142145

143146
// compute derived variables
144147
ivector_dim_ = ie_M[0].NumCols();
145-
feat_dim_ = ie_M[0].NumRows();
148+
lda_dim_ = ie_M[0].NumRows();
146149

147-
ie_Sigma_inv_M_f_.Resize(num_gauss_ * feat_dim_, ivector_dim_, kUndefined);
150+
ie_Sigma_inv_M_f_.Resize(num_gauss_ * lda_dim_, ivector_dim_, kUndefined);
148151

149152
ie_U_.Resize(num_gauss_, ivector_dim_ * (ivector_dim_ + 1) / 2);
150153

151154
SpMatrix<float> tmp_sub_U(ivector_dim_);
152-
Matrix<float> tmp_Sigma_inv_M(feat_dim_, ivector_dim_);
155+
Matrix<float> tmp_Sigma_inv_M(lda_dim_, ivector_dim_);
153156
for (int32 i = 0; i < num_gauss_; i++) {
154157
// compute matrix ie_Sigma_inv_M[i]
155158
tmp_sub_U.AddMat2Sp(1, ie_M[i], kTrans, ie_Sigma_inv[i], 0);
@@ -160,7 +163,7 @@ void BatchedIvectorExtractorCuda::Read(
160163
tmp_Sigma_inv_M.AddSpMat(1, ie_Sigma_inv[i], ie_M[i], kNoTrans, 0);
161164

162165
// copy into global matrix
163-
CuSubMatrix<float> window(ie_Sigma_inv_M_f_, i * feat_dim_, feat_dim_, 0,
166+
CuSubMatrix<float> window(ie_Sigma_inv_M_f_, i * lda_dim_, lda_dim_, 0,
164167
ivector_dim_);
165168
window.CopyFromMat(tmp_Sigma_inv_M);
166169
}
@@ -183,11 +186,11 @@ void BatchedIvectorExtractorCuda::GetIvectors(
183186
// Stash feats
184187
StashFeats(tmp_feats_, &norm_feats_stash_, lanes, num_lanes);
185188

186-
// LDA transform spliced feats back into tmp_feats
187-
LDATransform(spliced_feats_, &tmp_feats_, lanes, num_lanes);
189+
// LDA transform spliced feats
190+
LDATransform(spliced_feats_, &lda_feats_, lanes, num_lanes);
188191

189192
// compute posteriors based normalized lda feats
190-
ComputePosteriors(tmp_feats_, lanes, num_lanes);
193+
ComputePosteriors(lda_feats_, lanes, num_lanes);
191194
}
192195

193196
// non-normalized pipeline
@@ -198,20 +201,20 @@ void BatchedIvectorExtractorCuda::GetIvectors(
198201
// Stash feats
199202
StashFeats(feats, &feats_stash_, lanes, num_lanes);
200203

201-
// LDA transform spliced feats back into tmp_feats
202-
LDATransform(spliced_feats_, &tmp_feats_, lanes, num_lanes);
204+
// LDA transform spliced feats
205+
LDATransform(spliced_feats_, &lda_feats_, lanes, num_lanes);
203206
}
204207

205208
// compute ivector stats
206-
ComputeIvectorStats(tmp_feats_, lanes, num_lanes);
209+
ComputeIvectorStats(lda_feats_, lanes, num_lanes);
207210

208211
// compute ivectors for the stats
209212
ComputeIvectorsFromStats(ivectors, lanes, num_lanes);
210213
}
211214

212215
void BatchedIvectorExtractorCuda::InitializeChannels(const LaneDesc *lanes,
213216
int32_t num_lanes) {
214-
initialize_channels(num_gauss_, feat_dim_, gamma_stash_.Data(), num_gauss_,
217+
initialize_channels(num_gauss_, lda_dim_, gamma_stash_.Data(), num_gauss_,
215218
X_stash_.Data(), X_stash_.Stride(),
216219
X_stash_.Stride() * num_gauss_, lanes, num_lanes);
217220
}
@@ -273,7 +276,7 @@ void BatchedIvectorExtractorCuda::ComputePosteriors(CuMatrix<BaseFloat> &feats,
273276
posteriors_.AddMatMat(1.0, feats, kNoTrans, ubm_means_inv_vars_, kTrans, 1.0);
274277

275278
// square feats
276-
square_batched_matrix(chunk_size_, feat_dim_, feats.Data(), feats.Stride(),
279+
square_batched_matrix(chunk_size_, lda_dim_, feats.Data(), feats.Stride(),
277280
feats.Stride() * chunk_size_, feats.Data(),
278281
feats.Stride(), feats.Stride() * chunk_size_, lanes,
279282
num_lanes);
@@ -300,7 +303,7 @@ void BatchedIvectorExtractorCuda::ComputeIvectorStats(
300303
num_gauss_, info_.posterior_scale, lanes, num_lanes);
301304

302305
#if CUDA_VERSION >= 9010
303-
int32_t m = feat_dim_;
306+
int32_t m = lda_dim_;
304307
int32_t n = num_gauss_;
305308
int32_t k = chunk_size_;
306309
float alpha = info_.posterior_scale;
@@ -323,7 +326,7 @@ void BatchedIvectorExtractorCuda::ComputeIvectorStats(
323326
#endif
324327

325328
apply_and_update_stash(
326-
num_gauss_, feat_dim_, gamma_.Data(), gamma_stash_.Data(), num_gauss_,
329+
num_gauss_, lda_dim_, gamma_.Data(), gamma_stash_.Data(), num_gauss_,
327330
X_.Data(), X_.Stride(), X_.Stride() * num_gauss_, X_stash_.Data(),
328331
X_stash_.Stride(), X_stash_.Stride() * num_gauss_, lanes, num_lanes);
329332
}
@@ -336,7 +339,7 @@ void BatchedIvectorExtractorCuda::ComputeIvectorsFromStats(
336339
// need to set this term to zero because batched_compute_linear_term
337340
// uses atomics with a +=
338341
linear_.SetZero();
339-
batched_compute_linear_term(num_gauss_, feat_dim_, ivector_dim_,
342+
batched_compute_linear_term(num_gauss_, lda_dim_, ivector_dim_,
340343
ie_Sigma_inv_M_f_.Data(),
341344
ie_Sigma_inv_M_f_.Stride(), X_.Data(),
342345
X_.Stride(), X_.Stride() * num_gauss_,

src/cudafeat/feature-online-batched-ivector-cuda.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ namespace kaldi {
3030
class BatchedIvectorExtractorCuda {
3131
public:
3232
BatchedIvectorExtractorCuda(const OnlineIvectorExtractionConfig &config,
33+
int32_t feat_dim,
3334
int32_t chunk_size, int32_t num_lanes,
3435
int32_t num_channels);
3536
~BatchedIvectorExtractorCuda();
@@ -58,6 +59,7 @@ class BatchedIvectorExtractorCuda {
5859
int32_t num_lanes);
5960

6061
int32 FeatDim() const { return feat_dim_; }
62+
int32 LdaDim() const { return lda_dim_; }
6163
int32 IvectorDim() const { return ivector_dim_; }
6264
int32 NumGauss() const { return num_gauss_; }
6365

@@ -106,6 +108,7 @@ class BatchedIvectorExtractorCuda {
106108
CudaOnlineCmvnState naive_cmvn_state_;
107109
CudaOnlineBatchedCmvn *cmvn_;
108110
int32_t feat_dim_;
111+
int32_t lda_dim_;
109112
int32_t ivector_dim_;
110113
int32_t num_gauss_;
111114

@@ -121,8 +124,9 @@ class BatchedIvectorExtractorCuda {
121124
CuMatrix<BaseFloat> ie_Sigma_inv_M_f_;
122125

123126
// temporary memory unique per batch element
124-
CuMatrix<BaseFloat> spliced_feats_;
125127
CuMatrix<BaseFloat> tmp_feats_;
128+
CuMatrix<BaseFloat> spliced_feats_;
129+
CuMatrix<BaseFloat> lda_feats_;
126130
CuMatrix<BaseFloat> posteriors_;
127131
CuMatrix<BaseFloat> X_;
128132
CuVector<BaseFloat> gamma_;

src/cudafeat/online-batched-feature-pipeline-cuda.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ OnlineBatchedFeaturePipelineCuda::OnlineBatchedFeaturePipelineCuda(
8787
info_.ivector_extractor_info.Init(ivector_extraction_opts);
8888

8989
ivector_ = new BatchedIvectorExtractorCuda(ivector_extraction_opts,
90+
FeatureDim(),
9091
max_chunk_size_frames_,
9192
max_lanes_, num_channels_);
9293
}

0 commit comments

Comments
 (0)