[RFC] - Config is code #512
New file: `baseline.yaml`

```yaml
# baseline.yaml shows 5 key config patterns

output_dir: /tmp/torchtune/llama3_2_1B/full

# PATTERN 1: Simple Component Instantiation
tokenizer:
  _target_: mock.llama3_tokenizer
  path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model

# PATTERN 2: Component with Nested Instantiation
model:
  _target_: mock.llama3_2_1b
  # Nested component: attention config
  attn_config:
    _target_: mock.MultiHeadAttention
    num_heads: 32

# PATTERN 3: Component Needing Runtime Args (Partial)
optimizer:
  _target_: torch.optim.AdamW
  lr: 2e-5
  _partial_: true
  # params: None  # will be passed at instantiation time (not known now)

# PATTERN 4: Non-Instantiated Config Block (Plain Data)
data_args:
  batch_size: 4
  shuffle: True

# PATTERN 5: Plain Top-Level Hyperparameters
# Training params
epochs: 1
gradient_accumulation_steps: 8
```
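For reference, here is a minimal sketch of how a `_target_` / `_partial_` config like this is typically resolved after YAML parsing. The `instantiate` and `_import_object` names below are illustrative assumptions, not torchtune's or Hydra's actual implementation.

```python
# Hypothetical sketch: resolving _target_ / _partial_ blocks from a parsed YAML dict.
# Illustrative only; not the actual torchtune/Hydra resolver.
import functools
import importlib
from typing import Any


def _import_object(path: str) -> Any:
    """Resolve a dotted path like 'torch.optim.AdamW' to the object it names."""
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def instantiate(node: Any) -> Any:
    """Recursively build objects for dicts containing a _target_ key.

    - _target_: dotted path of the callable to construct
    - _partial_: if true, return functools.partial so runtime args
      (e.g. model parameters for the optimizer) can be supplied later
    - dicts/lists without _target_ pass through as plain data
    """
    if isinstance(node, dict) and "_target_" in node:
        cls = _import_object(node["_target_"])
        kwargs = {
            k: instantiate(v)
            for k, v in node.items()
            if k not in ("_target_", "_partial_")
        }
        if node.get("_partial_", False):
            return functools.partial(cls, **kwargs)
        return cls(**kwargs)
    if isinstance(node, dict):
        return {k: instantiate(v) for k, v in node.items()}
    if isinstance(node, list):
        return [instantiate(v) for v in node]
    return node
```

Under these assumptions, PATTERN 3 yields a `functools.partial(torch.optim.AdamW, lr=2e-5)` that the recipe later calls with `model.parameters()`, while PATTERN 4 and PATTERN 5 pass through untouched as plain data.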
New file: dataclass config with inner `Config` classes

````python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Dataclass config with inner Config classes.

Pros:
- Type safety for instantiated configs

Cons:
- Requires modifying target classes (not feasible for external libraries; need to use a wrapper)
- Boilerplate (every class needs Config + __init__)

e.g.:
```
class TokenizerWithConfig:
    @dataclass
    class Config:
        path: str

        def build(self) -> "TokenizerWithConfig":
            return TokenizerWithConfig(self)

    def __init__(self, config: Config):
        self.config = config
        self.path = config.path
```
"""

from dataclasses import dataclass

import torch

from mock_with_config import (
    ComponentConfig,
    LlamaModelWithConfig,
    MultiHeadAttentionWithConfig,
    TokenizerWithConfig,
)


@dataclass
class DataArgs:
    """Plain dataclass for non-instantiated config block (PATTERN 4)."""

    batch_size: int = 4
    shuffle: bool = True


def llama3_2_1b_full():
    output_dir = "/tmp/torchtune/llama3_2_1B/full"

    return {
        "output_dir": output_dir,
        # PATTERN 1: Simple Component Instantiation
        "tokenizer": TokenizerWithConfig.Config(
            path="/tmp/Llama-3.2-1B-Instruct/original/tokenizer.model",
        ),
        # PATTERN 2: Component with Nested Instantiation
        "model": LlamaModelWithConfig.Config(
            attn_config=MultiHeadAttentionWithConfig.Config(
                num_heads=32,
            )
        ),
        # PATTERN 3: Component Needing Runtime Args (Partial)
        "optimizer": ComponentConfig(
            component_cls=torch.optim.AdamW,
            kwargs={"lr": 2e-5},
        ),
        # PATTERN 4: Non-Instantiated Config Block (Plain Data)
        "data_args": DataArgs(
            batch_size=4,
            shuffle=True,
        ),
        # PATTERN 5: Plain Top-Level Hyperparameters
        "epochs": 1,
        "gradient_accumulation_steps": 8,
    }


if __name__ == "__main__":
    # =========================================================================
    # Scenario 1: Basic Instantiation
    # =========================================================================
    cfg = llama3_2_1b_full()

    # PATTERN 1: Simple Component Instantiation
    tokenizer = cfg["tokenizer"].build()

    # PATTERN 2: Component with Nested Instantiation
    model = cfg["model"].build()

    # PATTERN 3: Component Needing Runtime Args (Partial)
    optimizer = cfg["optimizer"].build(model.parameters())

    # =========================================================================
    # Scenario 2: Override Config Values
    # =========================================================================
    cfg2 = llama3_2_1b_full()

    # PATTERN 1: Simple Component Instantiation
    cfg2["tokenizer"].path = "/new/path"

    # PATTERN 2: Component with Nested Instantiation
    cfg2["model"].attn_config.num_heads = 64

    # PATTERN 3: Component Needing Runtime Args (Partial)
    cfg2["optimizer"].kwargs["lr"] = 1e-4

    model2 = cfg2["model"].build()
    optimizer2 = cfg2["optimizer"].build(model2.parameters())

    # =========================================================================
    # Scenario 3: Config Composition
    # =========================================================================
    def llama3_2_1b_large_lr():
        """Variant with larger learning rate and different model config."""
        base = llama3_2_1b_full()
        # Overrides
        base["optimizer"].kwargs["lr"] = 1e-3
        base["model"].attn_config.num_heads = 64
        return base

    cfg_variant = llama3_2_1b_large_lr()
    model_variant = cfg_variant["model"].build()
    optimizer_variant = cfg_variant["optimizer"].build(model_variant.parameters())
    assert optimizer_variant.param_groups[0]["lr"] == 1e-3
````

Inline review comments on this file:

> For this one and the fiddle one, I guess you can take this to extreme and make everything a (data)class, e.g. […]

> My issue with this pattern of […] Between the two, I personally prefer fiddle. But when I read […] On the composability side, if you look at […] TL;DR […]
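The `mock_with_config` module itself is not shown in this diff. The sketch below is one guess at what the generic `ComponentConfig` wrapper used for PATTERN 3 could look like, inferred only from how it is called above (`component_cls`, `kwargs`, `build(*runtime_args)`); treat the names and behavior as assumptions.

```python
# Hypothetical sketch of the ComponentConfig wrapper used for PATTERN 3.
# Inferred from its usage above; not the code that ships in mock_with_config.
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class ComponentConfig:
    """Holds a callable plus partial kwargs; runtime args are supplied at build()."""

    component_cls: Callable[..., Any]
    kwargs: dict[str, Any] = field(default_factory=dict)

    def build(self, *runtime_args: Any) -> Any:
        # e.g. ComponentConfig(torch.optim.AdamW, {"lr": 2e-5}).build(model.parameters())
        return self.component_cls(*runtime_args, **self.kwargs)
```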
New file: config using plain Python dicts

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Config using Plain Python Dicts.

Pros:
- Extremely simple
- No dependencies
- Easy to understand
- Flexible

Cons:
- No type hints (cfg["batch_szie"] typo won't be caught)
- No validation (cfg["batch_size"] = "invalid" won't error)
- Very loose, users can pass anything
"""

import torch.optim

from mock import llama3_2_1b, llama3_tokenizer, MultiHeadAttention


def llama3_2_1b_full():
    output_dir = "/tmp/torchtune/llama3_2_1B/full"
    batch_size = 4

    return {
        "output_dir": output_dir,
        # PATTERN 1: Simple Component Instantiation
        "tokenizer": {
            "cls": llama3_tokenizer,
            "kwargs": {
                "path": "/tmp/Llama-3.2-1B-Instruct/original/tokenizer.model",
            },
        },
        # PATTERN 2: Component with Nested Instantiation
        "model": {
            "cls": llama3_2_1b,
            "kwargs": {
                "attn_config": {
                    "cls": MultiHeadAttention,
                    "kwargs": {
                        "num_heads": 32,
                    },
                }
            },
        },
        # PATTERN 3: Component Needing Runtime Args (Partial)
        "optimizer": {
            "cls": torch.optim.AdamW,
            "kwargs": {
                "lr": 2e-5,
            },
        },
        # PATTERN 4: Non-Instantiated Config Block (Plain Data)
        "data_args": {
            "batch_size": batch_size,
            "shuffle": True,
        },
        # PATTERN 5: Plain Top-Level Hyperparameters
        "epochs": 1,
        "gradient_accumulation_steps": 8,
    }


if __name__ == "__main__":
    # =========================================================================
    # Scenario 1: Basic Instantiation
    # =========================================================================
    cfg = llama3_2_1b_full()

    # PATTERN 1: Simple Component Instantiation
    tokenizer = cfg["tokenizer"]["cls"](**cfg["tokenizer"]["kwargs"])

    # PATTERN 2: Component with Nested Instantiation
    attn_config = cfg["model"]["kwargs"]["attn_config"]["cls"](
        **cfg["model"]["kwargs"]["attn_config"]["kwargs"]
    )
    model = cfg["model"]["cls"](attn_config=attn_config)

    # PATTERN 3: Component Needing Runtime Args (Partial)
    optimizer = cfg["optimizer"]["cls"](
        model.parameters(), **cfg["optimizer"]["kwargs"]
    )

    # =========================================================================
    # Scenario 2: Override Config Values
    # =========================================================================
    cfg2 = llama3_2_1b_full()

    # PATTERN 1: Simple Component Instantiation
    cfg2["tokenizer"]["kwargs"]["path"] = "/new/tokenizer"

    # PATTERN 2: Component with Nested Instantiation
    cfg2["model"]["kwargs"]["attn_config"]["kwargs"]["num_heads"] = 64

    # PATTERN 3: Component Needing Runtime Args (Partial)
    cfg2["optimizer"]["kwargs"]["lr"] = 1e-4

    model2 = cfg2["model"]["cls"](
        attn_config=cfg2["model"]["kwargs"]["attn_config"]["cls"](
            **cfg2["model"]["kwargs"]["attn_config"]["kwargs"]
        )
    )
    optimizer2 = cfg2["optimizer"]["cls"](
        model2.parameters(), **cfg2["optimizer"]["kwargs"]
    )

    # =========================================================================
    # Scenario 3: Config Composition
    # =========================================================================
    def llama3_2_1b_large_lr():
        """Variant with larger learning rate."""
        base = llama3_2_1b_full()
        base["optimizer"]["kwargs"]["lr"] = 1e-3
        base["model"]["kwargs"]["attn_config"]["kwargs"]["num_heads"] = 64
        return base

    cfg_variant = llama3_2_1b_large_lr()
    attn_config_variant = cfg_variant["model"]["kwargs"]["attn_config"]["cls"](
        **cfg_variant["model"]["kwargs"]["attn_config"]["kwargs"]
    )
    model_variant = cfg_variant["model"]["cls"](attn_config=attn_config_variant)
    optimizer_variant = cfg_variant["optimizer"]["cls"](
        model_variant.parameters(), **cfg_variant["optimizer"]["kwargs"]
    )
```

Inline review comments on this file:

> seems too flexible, hard to read, and error-prone

> agreed, it's horrible. I put it here to make a contrast
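To make the readability point concrete, here is a hypothetical helper (not part of this PR) that would be needed to hide the repeated `cls` / `kwargs` plumbing of the dict approach:

```python
# Hypothetical helper for the plain-dict approach; not included in this PR.
# It centralizes the cfg[...]["cls"](**cfg[...]["kwargs"]) plumbing shown above.
from typing import Any


def build(spec: dict, *runtime_args: Any) -> Any:
    """Instantiate a {"cls": ..., "kwargs": {...}} spec, recursing into nested specs."""
    kwargs = {
        k: build(v) if isinstance(v, dict) and "cls" in v else v
        for k, v in spec.get("kwargs", {}).items()
    }
    return spec["cls"](*runtime_args, **kwargs)


# Usage against the config above:
#   model = build(cfg["model"])                        # recurses into attn_config
#   optimizer = build(cfg["optimizer"], model.parameters())
```

Even with such a helper, the keys stay unchecked strings, which is the error-proneness concern raised in the comments above.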
Follow-up discussion in the PR conversation:

> can we get free cli override with tyro?
> https://github.com/pytorch/torchtitan/blob/2ea6197b957936bdd4941e59a000cf31987a3184/torchtitan/config/manager.py#L56

> that looks nice, didn't know about it
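For context on the tyro suggestion: a minimal sketch of the free CLI overrides it gives for dataclass configs. The field names below are illustrative, not this PR's config classes; see the torchtitan link above for a real usage.

```python
# Minimal tyro sketch (illustrative field names, not this PR's config classes).
# tyro.cli() builds an argument parser from the dataclass and returns a populated
# instance, so every field becomes a CLI override for free, e.g.:
#   python train.py --epochs 3 --optimizer.lr 1e-4
from dataclasses import dataclass, field

import tyro


@dataclass
class OptimizerArgs:
    lr: float = 2e-5


@dataclass
class TrainConfig:
    output_dir: str = "/tmp/torchtune/llama3_2_1B/full"
    epochs: int = 1
    gradient_accumulation_steps: int = 8
    optimizer: OptimizerArgs = field(default_factory=OptimizerArgs)


if __name__ == "__main__":
    cfg = tyro.cli(TrainConfig)
    print(cfg)
```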