Skip to content

Commit 3a8a024

Browse files
committed
update tests
Signed-off-by: Jason <[email protected]>
1 parent 020fb42 commit 3a8a024

File tree

5 files changed

+12
-12
lines changed

5 files changed

+12
-12
lines changed

tests/collections/audio/test_audio_models_flow_matching.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def flow_matching_base_config(request):
7979
'time_max': flow['time_max'],
8080
}
8181

82-
loss = {'_target_': 'nemo.collections.audio.losses.MSELoss', 'ndim': 4}
82+
loss = {'_target_': 'nemo.collections.audio.losses.audio.MSELoss', 'ndim': 4}
8383

8484
estimator = {
8585
'_target_': 'nemo.collections.audio.parts.submodules.transformerunet.SpectrogramTransformerUNet',

tests/collections/audio/test_audio_models_mask.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def mask_model_rnn_params():
111111
}
112112

113113
loss = {
114-
'_target_': 'nemo.collections.audio.losses.SDRLoss',
114+
'_target_': 'nemo.collections.audio.losses.audio.SDRLoss',
115115
'scale_invariant': True,
116116
}
117117

@@ -212,7 +212,7 @@ def mask_model_flexarray():
212212
}
213213

214214
loss = {
215-
'_target_': 'nemo.collections.audio.losses.SDRLoss',
215+
'_target_': 'nemo.collections.audio.losses.audio.SDRLoss',
216216
'scale_invariant': True,
217217
}
218218

tests/collections/audio/test_audio_models_predictive.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def predictive_model_ncsn():
111111
}
112112

113113
loss = {
114-
'_target_': 'nemo.collections.audio.losses.MSELoss', # computed in the time domain
114+
'_target_': 'nemo.collections.audio.losses.audio.MSELoss', # computed in the time domain
115115
}
116116

117117
model_config = DictConfig(
@@ -183,7 +183,7 @@ def predictive_model_conformer():
183183
}
184184

185185
loss = {
186-
'_target_': 'nemo.collections.audio.losses.MSELoss', # computed in the time domain
186+
'_target_': 'nemo.collections.audio.losses.audio.MSELoss', # computed in the time domain
187187
}
188188

189189
model_config = DictConfig(
@@ -255,7 +255,7 @@ def predictive_model_streaming_conformer():
255255
}
256256

257257
loss = {
258-
'_target_': 'nemo.collections.audio.losses.MSELoss', # computed in the time domain
258+
'_target_': 'nemo.collections.audio.losses.audio.MSELoss', # computed in the time domain
259259
}
260260

261261
model_config = DictConfig(
@@ -318,7 +318,7 @@ def predictive_model_transformer_unet_params_base():
318318
}
319319

320320
loss = {
321-
'_target_': 'nemo.collections.audio.losses.MSELoss', # computed in the time domain
321+
'_target_': 'nemo.collections.audio.losses.audio.MSELoss', # computed in the time domain
322322
}
323323

324324
model_config = DictConfig(
@@ -384,7 +384,7 @@ def predictive_model_conformer_unet():
384384
}
385385

386386
loss = {
387-
'_target_': 'nemo.collections.audio.losses.MSELoss', # computed in the time domain
387+
'_target_': 'nemo.collections.audio.losses.audio.MSELoss', # computed in the time domain
388388
}
389389

390390
model_config = DictConfig(
@@ -456,7 +456,7 @@ def predictive_model_streaming_conformer_unet():
456456
}
457457

458458
loss = {
459-
'_target_': 'nemo.collections.audio.losses.MSELoss', # computed in the time domain
459+
'_target_': 'nemo.collections.audio.losses.audio.MSELoss', # computed in the time domain
460460
}
461461

462462
model_config = DictConfig(

tests/collections/audio/test_audio_models_schroedinger_bridge.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ def schroedinger_bridge_model_ncsn_params():
112112
'pad_dimension_to': 0, # no padding in the frequency dimension
113113
}
114114

115-
loss_encoded = {'_target_': 'nemo.collections.audio.losses.MSELoss', 'ndim': 4} # computed in the time domain
115+
loss_encoded = {'_target_': 'nemo.collections.audio.losses.audio.MSELoss', 'ndim': 4} # computed in the time domain
116116

117-
loss_time = {'_target_': 'nemo.collections.audio.losses.MAELoss'}
117+
loss_time = {'_target_': 'nemo.collections.audio.losses.audio.MAELoss'}
118118

119119
noise_schedule = {
120120
'_target_': 'nemo.collections.audio.parts.submodules.schroedinger_bridge.SBNoiseScheduleVE',

tests/collections/audio/test_audio_models_score_based.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def score_based_base_config():
8787
'snr': 0.5,
8888
}
8989

90-
loss = {'_target_': 'nemo.collections.audio.losses.MSELoss', 'ndim': 4}
90+
loss = {'_target_': 'nemo.collections.audio.losses.audio.MSELoss', 'ndim': 4}
9191

9292
trainer = {
9393
'max_epochs': -1,

0 commit comments

Comments
 (0)