Skip to content

Commit fad3cad

Browse files
committed
Clean up decoders
Signed-off-by: Vladimir Bataev <[email protected]>
1 parent cc0c3a1 commit fad3cad

File tree

2 files changed

+39
-18
lines changed

2 files changed

+39
-18
lines changed

nemo/collections/asr/parts/submodules/transducer_decoding/rnnt_label_looping.py

Lines changed: 19 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -233,19 +233,10 @@ def __init__(
233233
self.max_symbols = max_symbols_per_step
234234
self.preserve_alignments = preserve_alignments
235235
self.preserve_frame_confidence = preserve_frame_confidence
236-
self.allow_cuda_graphs = allow_cuda_graphs
237236
self._SOS = self._blank_index
238237
self._init_confidence_method(confidence_method_cfg=confidence_method_cfg)
239238
assert self._SOS == self._blank_index # "blank as pad" algorithm only
240239

241-
self.state = None
242-
self.full_graph = None
243-
self.separate_graphs = None
244-
245-
self.cuda_graphs_mode = None
246-
self.cuda_graphs_allow_fallback = True
247-
self.maybe_enable_cuda_graphs()
248-
249240
self.fusion_models = fusion_models or []
250241
self.fusion_models_alpha = fusion_models_alpha or []
251242

@@ -254,6 +245,25 @@ def __init__(
254245
if enable_per_stream_biasing
255246
else None
256247
)
248+
if allow_cuda_graphs:
249+
for fusion_model in self._all_fusion_models():
250+
if not fusion_model.compatible_with_cuda_graphs():
251+
logging.warning(
252+
"Fusion model used that is incompatible with CUDA graphs."
253+
" Switching off CUDA graphs, decoding may be slow."
254+
)
255+
allow_cuda_graphs = False
256+
break
257+
258+
self.allow_cuda_graphs = allow_cuda_graphs
259+
260+
self.state = None
261+
self.full_graph = None
262+
self.separate_graphs = None
263+
264+
self.cuda_graphs_mode = None
265+
self.cuda_graphs_allow_fallback = True
266+
self.maybe_enable_cuda_graphs()
257267

258268
@property
259269
def per_stream_biasing_enabled(self):

nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py

Lines changed: 20 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -257,21 +257,12 @@ def __init__(
257257
self.preserve_alignments = preserve_alignments
258258
self.preserve_frame_confidence = preserve_frame_confidence
259259
self.preserve_alignments = preserve_alignments or preserve_frame_confidence
260-
self.allow_cuda_graphs = allow_cuda_graphs
261260
self.include_duration = include_duration
262261
self.include_duration_confidence = include_duration_confidence
263262
self._SOS = self._blank_index
264263
self._init_confidence_method(confidence_method_cfg=confidence_method_cfg)
265264
assert self._SOS == self._blank_index # "blank as pad" algorithm only
266265

267-
self.state = None
268-
self.full_graph = None
269-
self.separate_graphs = None
270-
271-
self.cuda_graphs_mode = None
272-
self.cuda_graphs_allow_fallback = True
273-
self.maybe_enable_cuda_graphs()
274-
275266
self.fusion_models = fusion_models or []
276267
self.fusion_models_alpha = fusion_models_alpha or []
277268

@@ -281,6 +272,26 @@ def __init__(
281272
else None
282273
)
283274

275+
if allow_cuda_graphs:
276+
for fusion_model in self._all_fusion_models():
277+
if not fusion_model.compatible_with_cuda_graphs():
278+
logging.warning(
279+
"Fusion model used that is incompatible with CUDA graphs."
280+
" Switching off CUDA graphs, decoding may be slow."
281+
)
282+
allow_cuda_graphs = False
283+
break
284+
285+
self.allow_cuda_graphs = allow_cuda_graphs
286+
287+
self.state = None
288+
self.full_graph = None
289+
self.separate_graphs = None
290+
291+
self.cuda_graphs_mode = None
292+
self.cuda_graphs_allow_fallback = True
293+
self.maybe_enable_cuda_graphs()
294+
284295
@property
285296
def per_stream_biasing_enabled(self):
286297
return self.biasing_multi_model is not None

0 commit comments

Comments
 (0)