|
15 | 15 |
|
16 | 16 | from configparser import ConfigParser |
17 | 17 | from dataclasses import dataclass |
18 | | -from typing import List, Optional, Union |
| 18 | +from typing import Annotated, List, Optional, Tuple, Union |
19 | 19 | from unittest.mock import Mock, patch |
20 | 20 |
|
21 | 21 | import fiddle as fdl |
@@ -439,9 +439,9 @@ class Model: |
class Optimizer:
    """Dummy optimizer config.

    Every field carries a default, so ``Optimizer()`` is constructible with no
    arguments — entrypoints in this file rely on that to supply a fallback
    optimizer when the caller passes ``None``.
    """

    learning_rate: float = 0.001
    weight_decay: float = 1e-5
    betas: Tuple[float, float] = (0.9, 0.999)
445 | 445 |
|
446 | 446 |
|
447 | 447 | @run.cli.factory |
@@ -506,6 +506,22 @@ def train_model( |
506 | 506 | return {"model": model, "optimizer": optimizer, "epochs": epochs, "batch_size": batch_size} |
507 | 507 |
|
508 | 508 |
|
@run.cli.entrypoint(
    namespace="my_llm",
    skip_confirmation=True,
)
def train_model_default_optimizer(
    model: Model,
    optimizer: Annotated[Optional[Optimizer], run.Config[Optimizer]] = None,
    epochs: int = 10,
    batch_size: int = 32,
):
    """Train ``model``, constructing a default ``Optimizer`` when none is supplied.

    Thin wrapper around ``train_model`` used to exercise the CLI's handling of
    an ``Optional``-annotated config argument left at its ``None`` default.
    """
    resolved_optimizer = Optimizer() if optimizer is None else optimizer
    return train_model(model, resolved_optimizer, epochs, batch_size)
| 524 | + |
509 | 525 | @run.cli.factory(target=train_model) |
510 | 526 | def custom_defaults() -> run.Partial["train_model"]: |
511 | 527 | return run.Partial( |
@@ -585,6 +601,37 @@ def test_with_defaults(self, runner, app): |
585 | 601 | for i in range(1, 31): |
586 | 602 | assert f"Epoch {i}/30" in output |
587 | 603 |
|
| 604 | + def test_with_defaults_no_optimizer(self, runner, app): |
| 605 | + # Test CLI execution with default factory |
| 606 | + result = runner.invoke( |
| 607 | + app, |
| 608 | + [ |
| 609 | + "my_llm", |
| 610 | + "train_model_default_optimizer", |
| 611 | + "model=my_model(hidden_size=1024)", |
| 612 | + "epochs=30", |
| 613 | + "run.skip_confirmation=True", |
| 614 | + ], |
| 615 | + env={"INCLUDE_WORKSPACE_FILE": "false"}, |
| 616 | + ) |
| 617 | + assert result.exit_code == 0 |
| 618 | + |
| 619 | + # Parse the output to check the values |
| 620 | + output = result.stdout |
| 621 | + assert "Training model with the following configuration:" in output |
| 622 | + assert "Model: Model(hidden_size=1024, num_layers=3, activation='relu')" in output |
| 623 | + assert ( |
| 624 | + "Optimizer: Optimizer(learning_rate=0.001, weight_decay=1e-05, betas=(0.9, 0.999))" |
| 625 | + in output |
| 626 | + ) |
| 627 | + assert "Epochs: 30" in output |
| 628 | + assert "Batch size: 32" in output |
| 629 | + assert "Training completed!" in output |
| 630 | + |
| 631 | + # Check that all epochs were simulated |
| 632 | + for i in range(1, 31): |
| 633 | + assert f"Epoch {i}/30" in output |
| 634 | + |
588 | 635 | def test_experiment_entrypoint(self): |
589 | 636 | def dummy_pretrain(log_dir: str): |
590 | 637 | pass |
|
0 commit comments