Skip to content

Commit 5154a77

Browse files
ai-edge-bot authored and copybara-github committed
Define its own class per Gen AI example as a useful debugging info.
PiperOrigin-RevId: 704539650
1 parent e029f9b commit 5154a77

File tree

14 files changed

+103
-81
lines changed

14 files changed

+103
-81
lines changed

ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,16 @@
1717

1818
import ai_edge_torch.generative.layers.model_config as cfg
1919
from ai_edge_torch.generative.utilities import model_builder
20+
from torch import nn
2021

2122
TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
2223

2324

25+
class AmdLlama(model_builder.DecoderOnlyModel):
26+
"""An AMD-Llama model built from the Edge Generative API layers."""
27+
pass
28+
29+
2430
def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
2531
"""Returns the model config for an AMD-Llama-135m model.
2632
@@ -72,11 +78,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
7278
return config
7379

7480

75-
def build_model(
76-
checkpoint_path: str, **kwargs
77-
) -> model_builder.DecoderOnlyModel:
81+
def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
7882
return model_builder.build_decoder_only_model(
7983
checkpoint_path=checkpoint_path,
8084
config=get_model_config(**kwargs),
8185
tensor_names=TENSOR_NAMES,
86+
model_class=AmdLlama
8287
)

ai_edge_torch/generative/examples/gemma/gemma1.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import ai_edge_torch.generative.layers.model_config as cfg
1919
from ai_edge_torch.generative.utilities import model_builder
2020
import ai_edge_torch.generative.utilities.loader as loading_utils
21+
from torch import nn
2122

2223
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
2324
ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -33,6 +34,11 @@
3334
)
3435

3536

37+
class Gemma1(model_builder.DecoderOnlyModel):
38+
"""A Gemma1 model built from the Edge Generative API layers."""
39+
pass
40+
41+
3642
def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
3743
"""Returns the model config for a Gemma 2B model.
3844
@@ -91,11 +97,10 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
9197
return config
9298

9399

94-
def build_2b_model(
95-
checkpoint_path: str, **kwargs
96-
) -> model_builder.DecoderOnlyModel:
100+
def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
97101
return model_builder.build_decoder_only_model(
98102
checkpoint_path=checkpoint_path,
99103
config=get_model_config_2b(**kwargs),
100104
tensor_names=TENSOR_NAMES,
105+
model_class=Gemma1,
101106
)

ai_edge_torch/generative/examples/gemma/gemma2.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
from ai_edge_torch.generative.layers import kv_cache as kv_utils
2323
import ai_edge_torch.generative.layers.attention_utils as attn_utils
2424
import ai_edge_torch.generative.layers.model_config as cfg
25+
from ai_edge_torch.generative.utilities import model_builder
2526
import ai_edge_torch.generative.utilities.loader as loading_utils
26-
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
2727
import torch
2828
from torch import nn
2929

@@ -133,7 +133,7 @@ def forward(
133133
tokens: torch.Tensor,
134134
input_pos: torch.Tensor,
135135
kv_cache: kv_utils.KVCache,
136-
export_config: Optional[ExportConfig] = None,
136+
export_config: Optional[model_builder.ExportConfig] = None,
137137
) -> dict[torch.Tensor, kv_utils.KVCache]:
138138
_, seq_len = tokens.size()
139139
assert self.config.max_seq_len >= seq_len, (
@@ -259,11 +259,9 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
259259

260260

261261
def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
262-
config = get_model_config_2b(**kwargs)
263-
model = Gemma2(config)
264-
loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
265-
# Since embedding and lm-head use the same weight, we need to set strict
266-
# to False.
267-
loader.load(model, strict=False)
268-
model.eval()
269-
return model
262+
return model_builder.build_decoder_only_model(
263+
checkpoint_path=checkpoint_path,
264+
config=get_model_config_2b(**kwargs),
265+
tensor_names=TENSOR_NAMES,
266+
model_class=Gemma2,
267+
)

ai_edge_torch/generative/examples/llama/llama.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
import ai_edge_torch.generative.layers.model_config as cfg
2222
from ai_edge_torch.generative.utilities import model_builder
23-
import ai_edge_torch.generative.utilities.loader as loading_utils
2423
import torch
2524

2625
TENSOR_NAMES = model_builder.TENSOR_NAMES
@@ -177,23 +176,18 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
177176

178177
def _build_model(
179178
checkpoint_path: str, config: cfg.ModelConfig
180-
) -> model_builder.DecoderOnlyModel:
181-
model = Llama(config)
182-
loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
183-
# Since embedding and lm-head use the same weight, we need to set strict
184-
# to False.
185-
loader.load(model, strict=False)
186-
model.eval()
187-
return model
188-
189-
190-
def build_1b_model(
191-
checkpoint_path: str, **kwargs
192-
) -> model_builder.DecoderOnlyModel:
179+
) -> torch.nn.Module:
180+
return model_builder.build_decoder_only_model(
181+
checkpoint_path=checkpoint_path,
182+
config=config,
183+
tensor_names=TENSOR_NAMES,
184+
model_class=Llama,
185+
)
186+
187+
188+
def build_1b_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
193189
return _build_model(checkpoint_path, get_1b_model_config(**kwargs))
194190

195191

196-
def build_3b_model(
197-
checkpoint_path: str, **kwargs
198-
) -> model_builder.DecoderOnlyModel:
192+
def build_3b_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
199193
return _build_model(checkpoint_path, get_3b_model_config(**kwargs))

ai_edge_torch/generative/examples/openelm/openelm.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import ai_edge_torch.generative.layers.model_config as cfg
1919
from ai_edge_torch.generative.utilities import model_builder
2020
import ai_edge_torch.generative.utilities.loader as loading_utils
21+
from torch import nn
2122

2223
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
2324
ff_up_proj="transformer.layers.{}.ffn.proj_1",
@@ -34,6 +35,11 @@
3435
)
3536

3637

38+
class OpenELM(model_builder.DecoderOnlyModel):
39+
"""An OpenELM model built from the Edge Generative API layers."""
40+
pass
41+
42+
3743
def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
3844
"""Returns the model config for an OpenELM model.
3945
@@ -112,11 +118,10 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
112118
return config
113119

114120

115-
def build_model(
116-
checkpoint_path: str, **kwargs
117-
) -> model_builder.DecoderOnlyModel:
121+
def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
118122
return model_builder.build_decoder_only_model(
119123
checkpoint_path=checkpoint_path,
120124
config=get_model_config(**kwargs),
121125
tensor_names=TENSOR_NAMES,
126+
model_class=OpenELM,
122127
)

ai_edge_torch/generative/examples/paligemma/decoder.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,10 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
130130
return config
131131

132132

133-
def build_decoder(
134-
checkpoint_path: str, **kwargs
135-
) -> model_builder.DecoderOnlyModel:
136-
decoder = Decoder(get_decoder_config(**kwargs))
137-
loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
138-
# Loose the strictness because only decoder is being loaded.
139-
loader.load(decoder, strict=False)
140-
decoder.eval()
141-
return decoder
133+
def build_decoder(checkpoint_path: str, **kwargs) -> torch.nn.Module:
134+
return model_builder.build_decoder_only_model(
135+
checkpoint_path=checkpoint_path,
136+
config=get_decoder_config(**kwargs),
137+
tensor_names=TENSOR_NAMES,
138+
model_class=Decoder,
139+
)

ai_edge_torch/generative/examples/phi/phi2.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import ai_edge_torch.generative.layers.model_config as cfg
1919
from ai_edge_torch.generative.utilities import model_builder
2020
import ai_edge_torch.generative.utilities.loader as loading_utils
21+
from torch import nn
2122

2223
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
2324
ff_up_proj="model.layers.{}.mlp.fc1",
@@ -33,6 +34,11 @@
3334
)
3435

3536

37+
class Phi2(model_builder.DecoderOnlyModel):
38+
"""A Phi-2 model built from the Edge Generative API layers."""
39+
pass
40+
41+
3642
def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
3743
"""Returns the model config for a Phi-2 model.
3844
@@ -92,11 +98,10 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
9298
return config
9399

94100

95-
def build_model(
96-
checkpoint_path: str, **kwargs
97-
) -> model_builder.DecoderOnlyModel:
101+
def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
98102
return model_builder.build_decoder_only_model(
99103
checkpoint_path=checkpoint_path,
100104
config=get_model_config(**kwargs),
101105
tensor_names=TENSOR_NAMES,
106+
model_class=Phi2,
102107
)

ai_edge_torch/generative/examples/phi/phi3.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -207,13 +207,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
207207
return config
208208

209209

210-
def build_model(
211-
checkpoint_path: str, **kwargs
212-
) -> model_builder.DecoderOnlyModel:
210+
def build_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
213211
"""Instantiates the model instance and load checkpoint if provided."""
214-
config = get_model_config(**kwargs)
215-
model = Phi3_5Mini(config)
216-
loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
217-
loader.load(model)
218-
model.eval()
219-
return model
212+
return model_builder.build_decoder_only_model(
213+
checkpoint_path=checkpoint_path,
214+
config=get_model_config(**kwargs),
215+
tensor_names=TENSOR_NAMES,
216+
model_class=Phi3_5Mini,
217+
)

ai_edge_torch/generative/examples/qwen/qwen.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,16 @@
1717

1818
import ai_edge_torch.generative.layers.model_config as cfg
1919
from ai_edge_torch.generative.utilities import model_builder
20+
from torch import nn
2021

2122
TENSOR_NAMES = model_builder.TENSOR_NAMES
2223

2324

25+
class Qwen(model_builder.DecoderOnlyModel):
26+
"""A Qwen model built from the Edge Generative API layers."""
27+
pass
28+
29+
2430
def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
2531
"""Returns the model config for a Qwen 2.5 3B model.
2632
@@ -101,31 +107,28 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
101107
return config
102108

103109

104-
def build_3b_model(
105-
checkpoint_path: str, **kwargs
106-
) -> model_builder.DecoderOnlyModel:
110+
def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
107111
return model_builder.build_decoder_only_model(
108112
checkpoint_path=checkpoint_path,
109113
config=get_3b_model_config(**kwargs),
110114
tensor_names=TENSOR_NAMES,
115+
model_class=Qwen,
111116
)
112117

113118

114-
def build_1_5b_model(
115-
checkpoint_path: str, **kwargs
116-
) -> model_builder.DecoderOnlyModel:
119+
def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
117120
return model_builder.build_decoder_only_model(
118121
checkpoint_path=checkpoint_path,
119122
config=get_1_5b_model_config(**kwargs),
120123
tensor_names=TENSOR_NAMES,
124+
model_class=Qwen,
121125
)
122126

123127

124-
def build_0_5b_model(
125-
checkpoint_path: str, **kwargs
126-
) -> model_builder.DecoderOnlyModel:
128+
def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
127129
return model_builder.build_decoder_only_model(
128130
checkpoint_path=checkpoint_path,
129131
config=get_0_5b_model_config(**kwargs),
130132
tensor_names=TENSOR_NAMES,
133+
model_class=Qwen,
131134
)

ai_edge_torch/generative/examples/smollm/smollm.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,16 @@
1717

1818
import ai_edge_torch.generative.layers.model_config as cfg
1919
from ai_edge_torch.generative.utilities import model_builder
20+
from torch import nn
2021

2122
TENSOR_NAMES = model_builder.TENSOR_NAMES
2223

2324

25+
class SmolLM(model_builder.DecoderOnlyModel):
26+
"""A SmolLM model built from the Edge Generative API layers."""
27+
pass
28+
29+
2430
def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
2531
"""Returns the model config for a SmolLM 135M model.
2632
@@ -72,11 +78,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
7278
return config
7379

7480

75-
def build_model(
76-
checkpoint_path: str, **kwargs
77-
) -> model_builder.DecoderOnlyModel:
81+
def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
7882
return model_builder.build_decoder_only_model(
7983
checkpoint_path=checkpoint_path,
8084
config=get_model_config(**kwargs),
8185
tensor_names=TENSOR_NAMES,
86+
model_class=SmolLM,
8287
)

0 commit comments

Comments (0)