Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions fireredasr/models/fireredasr_aed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from fireredasr.models.module.transformer_decoder import TransformerDecoder


@torch.compile(mode="max-autotune")
class FireRedAsrAed(torch.nn.Module):
@classmethod
def from_args(cls, args):
Expand All @@ -28,6 +29,7 @@ def transcribe(self, padded_input, input_lengths,
beam_size=1, nbest=1, decode_max_len=0,
softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0):
enc_outputs, _, enc_mask = self.encoder(padded_input, input_lengths)
self.decoder.clear()
nbest_hyps = self.decoder.batch_beam_search(
enc_outputs, enc_mask,
beam_size, nbest, decode_max_len,
Expand Down
Loading