Skip to content

Commit 8684660

Browse files
committed
[Custom model] add document and example
1 parent 03d3739 commit 8684660

File tree

18 files changed

+2607
-12
lines changed

18 files changed

+2607
-12
lines changed

README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,26 @@ To set up multi-node training:
474474

475475
`torchtitan` provides a [Slurm script](https://github.com/pytorch/torchtitan/blob/main/multinode_trainer.slurm) for multi-node training, which you can use as a reference or starting point.
476476

477+
## Custom models
478+
479+
`flame` supports custom model architectures through seamless integration with the Hugging Face `transformers` library. To add your own model:
480+
481+
1. Create a new model directory under `custom_models/` (see `custom_models/sba` for a complete example)
482+
2. Implement your model classes and configuration:
483+
- Define a config class inheriting from `PretrainedConfig` (see `custom_models/sba/config_sba.py` for an example)
484+
- Create model classes inheriting from `PreTrainedModel` (see `custom_models/sba/modeling_sba.py` for an example)
485+
3. Register your models in `__init__.py`:
486+
- Import your model classes and config classes
487+
- Register your models with the `AutoModelForCausalLM`, `AutoModel` and `AutoConfig` classes (see `custom_models/sba/__init__.py` for an example)
488+
4. Create a config file for your custom model; you only need to set `model_type` to the name you registered for your custom model (example: `configs/sba_340m.json`).
489+
5. Training is straightforward: simply use the `flame.train.py` script to train your custom model.
490+
491+
492+
493+
494+
495+
496+
477497
## Citation
478498

479499
If you find `flame` helpful for your work, please consider citing it.

configs/sba_340m.json

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
{
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "hidden_act": "swish",
  "hidden_size": 1024,
  "initializer_range": 0.006,
  "max_position_embeddings": 8192,
  "model_type": "sba",
  "num_heads": 16,
  "num_hidden_layers": 24,
  "norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "use_cache": true,
  "vocab_size": 32000
}

custom_models/sba/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""Register the custom SBA architecture with the Hugging Face Auto* factories.

Importing this package makes ``model_type='sba'`` resolvable through
``AutoConfig`` / ``AutoModel`` / ``AutoModelForCausalLM``, so a config file
whose ``model_type`` is ``"sba"`` loads the classes defined here.
"""

# Third-party imports precede local ones (PEP 8 import grouping).
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from .config_sba import SBAConfig
from .modeling_sba import SBAForCausalLM, SBAModel

__all__ = ['SBAConfig', 'SBAForCausalLM', 'SBAModel']

AutoConfig.register('sba', SBAConfig)
AutoModel.register(SBAConfig, SBAModel)
AutoModelForCausalLM.register(SBAConfig, SBAForCausalLM)

custom_models/sba/config_sba.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# -*- coding: utf-8 -*-
2+
3+
from typing import Optional
4+
5+
from transformers.configuration_utils import PretrainedConfig
6+
7+
class SBAConfig(PretrainedConfig):
    """Configuration for the SBA model family.

    Stores every hyper-parameter consumed by the SBA modeling code verbatim
    as instance attributes, and forwards the token-id / embedding-tying
    options to ``PretrainedConfig.__init__``. Registered under
    ``model_type='sba'`` for the ``transformers`` Auto* APIs.

    Fix vs. original: ``num_kv_heads`` and ``pad_token_id`` default to
    ``None`` but were annotated as plain ``int``; they are now correctly
    ``Optional[int]``. Defaults are unchanged.
    """

    model_type = 'sba'
    # Cached past key/values are runtime state, not part of the config hash.
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        hidden_size: int = 2048,
        num_hidden_layers: int = 24,
        num_heads: int = 32,
        num_kv_heads: Optional[int] = None,
        qkv_bias: bool = False,
        window_size: Optional[int] = None,
        rope_theta: Optional[float] = 10000.,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        initializer_range: float = 0.006,
        elementwise_affine: Optional[bool] = True,
        norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: Optional[int] = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        vocab_size: int = 32000,
        **kwargs,
    ):
        # Backbone / attention geometry.
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.qkv_bias = qkv_bias
        self.window_size = window_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # MLP sizing. NOTE(review): when intermediate_size is None it is
        # presumably derived from hidden_size * hidden_ratio in the modeling
        # code — confirm against modeling_sba.py.
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act

        # Initialization and normalization settings.
        self.initializer_range = initializer_range
        self.elementwise_affine = elementwise_affine
        self.norm_eps = norm_eps
        self.use_cache = use_cache

        # Fused-kernel switches honored by the modeling code.
        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.vocab_size = vocab_size

        # Token ids and embedding tying are handled by the base class so that
        # generation utilities and serialization see them in the usual place.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

0 commit comments

Comments
 (0)