
Commit 2d95e4a
Author: Felipe Mello
first commit
1 parent 801a454 commit 2d95e4a

File tree: 11 files changed, +838 −0 lines changed

apps/sft/eval_utils.py

Lines changed: 6 additions & 0 deletions

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 """Utility functions for evaluation to make main.py more concise."""

 import logging
brainstorming/configs/base.py

Lines changed: 41 additions & 0 deletions

@@ -0,0 +1,41 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Base config class for all approaches.

In a real project, this would live in forge/config/base.py and be imported as:
    from forge.config.base import FullFinetuneConfig
"""

from dataclasses import dataclass
from typing import Any


@dataclass
class FullFinetuneConfig:
    """
    Base config for all approaches.

    Uses `Any` type hints for components that differ across approaches:
    - Plain dicts: str/dict
    - Fiddle: fdl.Config/fdl.Partial
    - Partial: functools.partial
    - Hydra: Spec dataclasses
    - Dataclasses: Inner Config classes
    """

    output_dir: str

    # Components (type varies by approach)
    tokenizer: Any
    model: Any
    optimizer: Any
    data_args: Any

    # Plain hyperparameters (same across all approaches)
    epochs: int = 1
    gradient_accumulation_steps: int = 8
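
To make the docstring concrete: a minimal usage sketch, not part of the commit, showing how the plain-dict approach later in this commit might fill the `Any` fields. The relative import of `FullFinetuneConfig` and the placeholder strings for the mock components are assumptions made purely for illustration.

```python
# Hypothetical usage sketch (not in the commit): filling FullFinetuneConfig
# with the {"cls": ..., "kwargs": ...} dict convention used later in this commit.
import torch

from base import FullFinetuneConfig  # i.e. brainstorming/configs/base.py above

cfg = FullFinetuneConfig(
    output_dir="/tmp/torchtune/llama3_2_1B/full",
    # "mock.llama3_tokenizer" is a placeholder string standing in for the real callable.
    tokenizer={"cls": "mock.llama3_tokenizer",
               "kwargs": {"path": "/tmp/Llama-3.2-1B-Instruct/original/tokenizer.model"}},
    model={"cls": "mock.llama3_2_1b", "kwargs": {}},
    optimizer={"cls": torch.optim.AdamW, "kwargs": {"lr": 2e-5}},
    data_args={"batch_size": 4, "shuffle": True},
)

print(cfg.epochs, cfg.gradient_accumulation_steps)  # defaults: 1 8
```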
Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@

# baseline.yaml shows 5 key config patterns

output_dir: /tmp/torchtune/llama3_2_1B/full

# PATTERN 1: Simple Component Instantiation
tokenizer:
  _target_: mock.llama3_tokenizer
  path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model

# PATTERN 2: Component with Nested Instantiation
model:
  _target_: mock.llama3_2_1b
  # Nested component: attention config
  attn_config:
    _target_: mock.MultiHeadAttention
    num_heads: 32

# PATTERN 3: Component Needing Runtime Args (Partial)
optimizer:
  _target_: torch.optim.AdamW
  lr: 2e-5
  _partial_: true
  # params: None  # will be passed at instantiation time (not known now)

# PATTERN 4: Non-Instantiated Config Block (Plain Data)
data_args:
  batch_size: 4
  shuffle: True

# PATTERN 5: Plain Top-Level Hyperparameters
# Training params
epochs: 1
gradient_accumulation_steps: 8
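
The `_target_`/`_partial_` keys follow the Hydra `instantiate` convention. The consuming code is not part of this excerpt; below is a minimal sketch, assuming Hydra/OmegaConf is the intended backend and that the file above is saved as `baseline.yaml`, of how each pattern would be resolved.

```python
# Hypothetical consumer sketch (not in the commit), assuming Hydra's instantiate API
# is what resolves the _target_ / _partial_ keys in baseline.yaml.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("baseline.yaml")

tokenizer = instantiate(cfg.tokenizer)         # PATTERN 1: builds mock.llama3_tokenizer(path=...)
model = instantiate(cfg.model)                 # PATTERN 2: nested attn_config is built recursively
optim_partial = instantiate(cfg.optimizer)     # PATTERN 3: _partial_: true -> functools.partial(AdamW, lr=2e-5)
optimizer = optim_partial(model.parameters())  # runtime args supplied now

data_args = cfg.data_args                      # PATTERN 4: plain data, no _target_, stays as config
epochs = cfg.epochs                            # PATTERN 5: plain hyperparameter
```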
Lines changed: 131 additions & 0 deletions

@@ -0,0 +1,131 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Dataclass config with inner Config classes.

Pros:
- Type safety for instantiated configs

Cons:
- Requires modifying target classes (not feasible for external libraries; need to use a wrapper)
- Boilerplate (every class needs Config + __init__)

e.g.:

```
class TokenizerWithConfig:
    @dataclass
    class Config:
        path: str

        def build(self) -> "TokenizerWithConfig":
            return TokenizerWithConfig(self)

    def __init__(self, config: Config):
        self.config = config
        self.path = config.path
```
"""

from dataclasses import dataclass

import torch

from mock_with_config import (
    ComponentConfig,
    LlamaModelWithConfig,
    MultiHeadAttentionWithConfig,
    TokenizerWithConfig,
)


@dataclass
class DataArgs:
    """Plain dataclass for non-instantiated config block (PATTERN 4)."""

    batch_size: int = 4
    shuffle: bool = True


def llama3_2_1b_full():
    output_dir = "/tmp/torchtune/llama3_2_1B/full"

    return {
        "output_dir": output_dir,
        # PATTERN 1: Simple Component Instantiation
        "tokenizer": TokenizerWithConfig.Config(
            path="/tmp/Llama-3.2-1B-Instruct/original/tokenizer.model",
        ),
        # PATTERN 2: Component with Nested Instantiation
        "model": LlamaModelWithConfig.Config(
            attn_config=MultiHeadAttentionWithConfig.Config(
                num_heads=32,
            )
        ),
        # PATTERN 3: Component Needing Runtime Args (Partial)
        "optimizer": ComponentConfig(
            component_cls=torch.optim.AdamW,
            kwargs={"lr": 2e-5},
        ),
        # PATTERN 4: Non-Instantiated Config Block (Plain Data)
        "data_args": DataArgs(
            batch_size=4,
            shuffle=True,
        ),
        # PATTERN 5: Plain Top-Level Hyperparameters
        "epochs": 1,
        "gradient_accumulation_steps": 8,
    }


if __name__ == "__main__":
    # =========================================================================
    # Scenario 1: Basic Instantiation
    # =========================================================================
    cfg = llama3_2_1b_full()

    # PATTERN 1: Simple Component Instantiation
    tokenizer = cfg["tokenizer"].build()

    # PATTERN 2: Component with Nested Instantiation
    model = cfg["model"].build()

    # PATTERN 3: Component Needing Runtime Args (Partial)
    optimizer = cfg["optimizer"].build(model.parameters())

    # =========================================================================
    # Scenario 2: Override Config Values
    # =========================================================================
    cfg2 = llama3_2_1b_full()

    # PATTERN 1: Simple Component Instantiation
    cfg2["tokenizer"].path = "/new/path"

    # PATTERN 2: Component with Nested Instantiation
    cfg2["model"].attn_config.num_heads = 64

    # PATTERN 3: Component Needing Runtime Args (Partial)
    cfg2["optimizer"].kwargs["lr"] = 1e-4

    model2 = cfg2["model"].build()
    optimizer2 = cfg2["optimizer"].build(model2.parameters())

    # =========================================================================
    # Scenario 3: Config Composition
    # =========================================================================
    def llama3_2_1b_large_lr():
        """Variant with larger learning rate and different model config."""
        base = llama3_2_1b_full()
        # Overrides
        base["optimizer"].kwargs["lr"] = 1e-3
        base["model"].attn_config.num_heads = 64
        return base

    cfg_variant = llama3_2_1b_large_lr()
    model_variant = cfg_variant["model"].build()
    optimizer_variant = cfg_variant["optimizer"].build(model_variant.parameters())
    assert optimizer_variant.param_groups[0]["lr"] == 1e-3
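
`mock_with_config` is not included in this excerpt. Based only on how `ComponentConfig` is used above (constructed with `component_cls` and `kwargs`, mutated via `.kwargs["lr"]`, and built with `.build(model.parameters())`), a minimal sketch of what it might look like:

```python
# Hypothetical sketch of ComponentConfig (mock_with_config is not shown in this excerpt);
# inferred from the usage above, not from the actual module.
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class ComponentConfig:
    """Defers instantiation until runtime args (e.g. model parameters) are available."""

    component_cls: Callable[..., Any]
    kwargs: dict = field(default_factory=dict)

    def build(self, *runtime_args: Any) -> Any:
        # Positional runtime args first (e.g. params for torch.optim.AdamW),
        # then the statically configured keyword arguments.
        return self.component_cls(*runtime_args, **self.kwargs)
```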
Lines changed: 130 additions & 0 deletions

@@ -0,0 +1,130 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Config using plain Python dicts.

Pros:
- Extremely simple
- No dependencies
- Easy to understand
- Flexible

Cons:
- No type hints (cfg["batch_szie"] typo won't be caught)
- No validation (cfg["batch_size"] = "invalid" won't error)
- Very loose, users can pass anything
"""

import torch.optim

from mock import llama3_2_1b, llama3_tokenizer, MultiHeadAttention


def llama3_2_1b_full():
    output_dir = "/tmp/torchtune/llama3_2_1B/full"
    batch_size = 4

    return {
        "output_dir": output_dir,
        # PATTERN 1: Simple Component Instantiation
        "tokenizer": {
            "cls": llama3_tokenizer,
            "kwargs": {
                "path": "/tmp/Llama-3.2-1B-Instruct/original/tokenizer.model",
            },
        },
        # PATTERN 2: Component with Nested Instantiation
        "model": {
            "cls": llama3_2_1b,
            "kwargs": {
                "attn_config": {
                    "cls": MultiHeadAttention,
                    "kwargs": {
                        "num_heads": 32,
                    },
                }
            },
        },
        # PATTERN 3: Component Needing Runtime Args (Partial)
        "optimizer": {
            "cls": torch.optim.AdamW,
            "kwargs": {
                "lr": 2e-5,
            },
        },
        # PATTERN 4: Non-Instantiated Config Block (Plain Data)
        "data_args": {
            "batch_size": batch_size,
            "shuffle": True,
        },
        # PATTERN 5: Plain Top-Level Hyperparameters
        "epochs": 1,
        "gradient_accumulation_steps": 8,
    }


if __name__ == "__main__":
    # =========================================================================
    # Scenario 1: Basic Instantiation
    # =========================================================================
    cfg = llama3_2_1b_full()

    # PATTERN 1: Simple Component Instantiation
    tokenizer = cfg["tokenizer"]["cls"](**cfg["tokenizer"]["kwargs"])

    # PATTERN 2: Component with Nested Instantiation
    attn_config = cfg["model"]["kwargs"]["attn_config"]["cls"](
        **cfg["model"]["kwargs"]["attn_config"]["kwargs"]
    )
    model = cfg["model"]["cls"](attn_config=attn_config)

    # PATTERN 3: Component Needing Runtime Args (Partial)
    optimizer = cfg["optimizer"]["cls"](
        model.parameters(), **cfg["optimizer"]["kwargs"]
    )

    # =========================================================================
    # Scenario 2: Override Config Values
    # =========================================================================
    cfg2 = llama3_2_1b_full()

    # PATTERN 1: Simple Component Instantiation
    cfg2["tokenizer"]["kwargs"]["path"] = "/new/tokenizer"

    # PATTERN 2: Component with Nested Instantiation
    cfg2["model"]["kwargs"]["attn_config"]["kwargs"]["num_heads"] = 64

    # PATTERN 3: Component Needing Runtime Args (Partial)
    cfg2["optimizer"]["kwargs"]["lr"] = 1e-4

    model2 = cfg2["model"]["cls"](
        attn_config=cfg2["model"]["kwargs"]["attn_config"]["cls"](
            **cfg2["model"]["kwargs"]["attn_config"]["kwargs"]
        )
    )
    optimizer2 = cfg2["optimizer"]["cls"](
        model2.parameters(), **cfg2["optimizer"]["kwargs"]
    )

    # =========================================================================
    # Scenario 3: Config Composition
    # =========================================================================
    def llama3_2_1b_large_lr():
        """Variant with larger learning rate."""
        base = llama3_2_1b_full()
        base["optimizer"]["kwargs"]["lr"] = 1e-3
        base["model"]["kwargs"]["attn_config"]["kwargs"]["num_heads"] = 64
        return base

    cfg_variant = llama3_2_1b_large_lr()
    attn_config_variant = cfg_variant["model"]["kwargs"]["attn_config"]["cls"](
        **cfg_variant["model"]["kwargs"]["attn_config"]["kwargs"]
    )
    model_variant = cfg_variant["model"]["cls"](attn_config=attn_config_variant)
    optimizer_variant = cfg_variant["optimizer"]["cls"](
        model_variant.parameters(), **cfg_variant["optimizer"]["kwargs"]
    )
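
The manual instantiation in the `__main__` block above shows how verbose the nested-dict pattern gets. A small recursive helper, not part of the commit (the name `build_component` is made up here), could walk the `{"cls", "kwargs"}` specs instead:

```python
# Hypothetical helper (not in the commit): a recursive builder for the
# {"cls": ..., "kwargs": ...} convention, so nested components such as attn_config
# don't have to be instantiated by hand.
from typing import Any


def build_component(spec: Any, *runtime_args: Any) -> Any:
    """Instantiate a {"cls", "kwargs"} spec, recursing into nested specs."""
    if isinstance(spec, dict) and "cls" in spec:
        kwargs = {
            k: build_component(v) if isinstance(v, dict) and "cls" in v else v
            for k, v in spec.get("kwargs", {}).items()
        }
        return spec["cls"](*runtime_args, **kwargs)
    return spec  # plain data (e.g. data_args) is returned untouched


# Usage sketch against the config above:
#   cfg = llama3_2_1b_full()
#   model = build_component(cfg["model"])                        # nested attn_config built recursively
#   optimizer = build_component(cfg["optimizer"], model.parameters())
```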
