Skip to content

Commit 81a0560

Browse files
committed
feat(rl): add opt-in legacy length check and max duration config
1 parent 9405b53 commit 81a0560

File tree

5 files changed

+100
-5
lines changed

5 files changed

+100
-5
lines changed

README_RL.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,8 @@ Sample logging:
199199
- `rl.kl_eps`: add a small epsilon to the KL denominator for extra numerical stability (default: 0.0).
200200
- `rl.density_eps`: add a small epsilon to Gaussian density weighting for stability (default: 0.0).
201201
- `rl.align_kl_steps`: share the ODE skip mask between policy/ref rollouts for a less noisy KL (default: `false`).
202+
- `rl.max_duration`: Maximum allowed mel frames (default: 4096). Batches whose longest sample exceeds this are skipped to avoid truncation.
203+
- `rl.legacy_length_check`: If `true`, restores the legacy F5R filtering that skips batches whose longest text is longer than the longest mel sequence. Note that enabling it bypasses the `rl.max_duration` check. Default: `false`.
202204
- `wer_mode`: `char | word` (default: `char`, matching F5R).
203205
- `ref_source`: `text | audio` (default: `text`; set `audio` to match ASR-vs-ASR reward in F5R).
204206

src/f5_tts/configs/F5TTS_RL.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ rl:
6262
kl_eps: 0.0 # set > 0 for numerical stability in KL denominator
6363
density_eps: 0.0 # set > 0 for stability in Gaussian density weighting
6464
align_kl_steps: false # share ODE skip mask between policy/ref for stable KL
65+
max_duration: 4096
66+
legacy_length_check: False # Set to True to replicate the legacy F5R behavior of skipping batches where text_len > mel_len (note: bypasses max_duration)
6567
ref_model_ckpt: null
6668
ref_model_use_ema: True
6769
rewards:

src/f5_tts/rl/trainer_grpo.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ def __init__(
118118
kl_eps: float = 0.0,
119119
density_eps: float = 0.0,
120120
align_kl_steps: bool = False,
121+
max_duration: int = 4096,
122+
legacy_length_check: bool = False,
121123
):
122124
if accelerate_kwargs is None:
123125
accelerate_kwargs = {}
@@ -176,8 +178,8 @@ def __init__(
176178
"kl_eps": kl_eps,
177179
"density_eps": density_eps,
178180
"align_kl_steps": align_kl_steps,
179-
"reward_mode": reward_combiner.mode,
180-
"reward_weights": reward_combiner.weights,
181+
"max_duration": max_duration,
182+
"legacy_length_check": legacy_length_check,
181183
"reward_providers": reward_providers,
182184
}
183185
self.accelerator.init_trackers(
@@ -258,6 +260,10 @@ def __init__(
258260
self.kl_eps = kl_eps
259261
self.density_eps = density_eps
260262
self.align_kl_steps = align_kl_steps
263+
self.max_duration = max_duration
264+
self.legacy_length_check = legacy_length_check
265+
self.max_duration = max_duration
266+
self.legacy_length_check = legacy_length_check
261267

262268
self.noise_scheduler = noise_scheduler
263269
self.duration_predictor = duration_predictor
@@ -554,8 +560,12 @@ def train(self, train_dataset: Dataset, num_workers: int = 16, resumable_with_se
554560
text_inputs = batch["text"]
555561
mel_spec = batch["mel"].permute(0, 2, 1)
556562
mel_lengths = batch["mel_lengths"]
557-
text_len = max(len(item) for item in text_inputs)
558-
if text_len > max(mel_lengths):
563+
564+
if self.legacy_length_check:
565+
text_len = max(len(item) for item in text_inputs)
566+
if text_len > max(mel_lengths):
567+
continue
568+
elif max(mel_lengths) > self.max_duration:
559569
continue
560570

561571
dur_loss = None

src/f5_tts/train/train_rl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ def main(model_cfg):
8181
allow_extra_keys=model_cfg.ckpts.get("allow_extra_keys", False),
8282
bnb_optimizer=model_cfg.optim.get("bnb_optimizer", False),
8383
prompt_length_mode=model_cfg.rl.get("prompt_length_mode", "min"),
84+
max_duration=model_cfg.rl.get("max_duration", 4096),
85+
legacy_length_check=model_cfg.rl.get("legacy_length_check", False),
8486
)
8587

8688
train_dataset = load_dataset(model_cfg.datasets.name, tokenizer, mel_spec_kwargs=model_cfg.model.mel_spec)

tests/test_rl_integration.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def test_grpo_single_step_updates_params(tmp_path):
166166
)
167167
trainer.train(DummyDataset(), num_workers=0)
168168
after = model.transformer.proj_out.weight.detach()
169-
assert not torch.equal(before, after)
169+
assert not torch.equal(before.to(after.device), after)
170170

171171

172172
def test_grpo_kl_eps_stability(tmp_path):
@@ -694,3 +694,82 @@ def test_audio_pack_metadata():
694694
line = metadata.read_text(encoding="utf-8").splitlines()[0]
695695
audio_name = json.loads(line).get("audio")
696696
assert (pack_dir / audio_name).exists()
697+
698+
699+
def test_grpo_length_checks(tmp_path):
    """Verify the batch-skipping rules for the legacy and max-duration length checks."""
    model = _make_cfm(output_dist="gaussian", objective="grpo")
    combiner = RewardCombiner([DummyRewardProvider()])

    # Trainer wired with throwaway config; max_duration is kept tiny (100)
    # so the duration guard is easy to trip in the assertions below.
    trainer = GRPOTrainer(
        model,
        reward_combiner=combiner,
        epochs=1,
        learning_rate=1e-3,
        num_warmup_updates=0,
        save_per_updates=1000,
        keep_last_n_checkpoints=0,
        checkpoint_path=str(tmp_path),
        batch_size_per_gpu=1,
        batch_size_type="sample",
        max_samples=1,
        grad_accumulation_steps=1,
        max_grad_norm=1.0,
        logger=None,
        mel_spec_type="vocos",
        vocoder=DummyVocoder(),
        repeat_count=1,
        mini_repeat_count=1,
        prompt_frac_range=(0.5, 0.5),
        steps=3,
        cfg_strength=1.0,
        sway_sampling_coef=None,
        max_duration=100,  # Small duration for testing
    )

    # Stub out accelerator/optimizer pieces so no real training step runs.
    trainer.accelerator.sync_gradients = True

    class _NoOpContext:
        # Minimal stand-in for the context manager returned by accumulate().
        def __enter__(self):
            return None

        def __exit__(self, *exc_info):
            return None

    trainer.accelerator.accumulate = lambda _model: _NoOpContext()
    trainer.optimizer.step = lambda: None
    trainer.optimizer.zero_grad = lambda: None

    # Mirrors the skip logic of the training loop, since running the full
    # loop end-to-end is impractical here. Returns True when the batch
    # would be processed, False when it would be skipped.
    def would_process(batch):
        longest_mel = max(batch["mel_lengths"])
        if trainer.legacy_length_check:
            longest_text = max(len(item) for item in batch["text"])
            if longest_text > longest_mel:
                return False
        elif longest_mel > trainer.max_duration:
            return False
        return True

    oversized_text_batch = {
        "text": ["very long text string that exceeds mel length"],
        "mel_lengths": torch.tensor([10]),
    }

    # Case 1: legacy mode on, text longer than mel -> batch is dropped.
    trainer.legacy_length_check = True
    assert would_process(oversized_text_batch) is False

    # Case 2: legacy mode off, the very same batch is kept.
    trainer.legacy_length_check = False
    assert would_process(oversized_text_batch) is True

    # Case 3: mel length beyond max_duration (100) -> batch is dropped.
    long_mel_batch = {
        "text": ["short"],
        "mel_lengths": torch.tensor([101]),
    }
    assert would_process(long_mel_batch) is False

0 commit comments

Comments
 (0)