Skip to content

Commit 3106fea

Browse files
haozha111 authored and copybara-github committed
Clean up export configs.
PiperOrigin-RevId: 744102468
1 parent 9898978 commit 3106fea

33 files changed

+126
-95
lines changed

ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515

1616
"""Example of converting AMD-Llama-135m model to multi-signature tflite model."""
1717

18-
import os
1918
from absl import app
20-
from absl import flags
2119
from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
2220
from ai_edge_torch.generative.utilities import converter
23-
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
21+
from ai_edge_torch.generative.utilities import export_config
2422

2523
flags = converter.define_conversion_flags("amd-llama-135m")
24+
ExportConfig = export_config.ExportConfig
25+
2626

2727
def main(_):
2828
pytorch_model = amd_llama_135m.build_model(

ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,13 @@
1515

1616
"""Example of converting DeepSeek R1 distilled models to tflite model."""
1717

18-
import os
19-
import pathlib
20-
2118
from absl import app
22-
from absl import flags
2319
from ai_edge_torch.generative.examples.deepseek import deepseek
2420
from ai_edge_torch.generative.utilities import converter
25-
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
21+
from ai_edge_torch.generative.utilities import export_config
2622

2723
flags = converter.define_conversion_flags("deepseek")
24+
ExportConfig = export_config.ExportConfig
2825

2926
def main(_):
3027
pytorch_model = deepseek.build_model(

ai_edge_torch/generative/examples/experimental/gemma/convert_gemma2_gpu_to_tflite.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,15 @@
1515

1616
"""Example of converting a Gemma2 model to multi-signature tflite model."""
1717

18-
import os
19-
import pathlib
20-
2118
from absl import app
22-
from absl import flags
2319
from ai_edge_torch.generative.examples.experimental.gemma import gemma2_gpu
2420
from ai_edge_torch.generative.layers.experimental import kv_cache
2521
from ai_edge_torch.generative.utilities import converter
26-
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
22+
from ai_edge_torch.generative.utilities import export_config
2723
import torch
2824

2925
flags = converter.define_conversion_flags('gemma2-2b')
26+
ExportConfig = export_config.ExportConfig
3027

3128
def _create_mask(mask_len, kv_cache_max_len):
3229
mask = torch.full(

ai_edge_torch/generative/examples/experimental/gemma/gemma2_gpu.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
3131
import ai_edge_torch.generative.layers.model_config as cfg
3232
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
33+
from ai_edge_torch.generative.utilities import export_config as export_cfg
3334
from ai_edge_torch.generative.utilities import model_builder
3435
import ai_edge_torch.generative.utilities.loader as loading_utils
3536
import torch
@@ -152,7 +153,7 @@ def forward(
152153
input_pos: torch.Tensor,
153154
kv_cache: kv_utils.KVCacheBase,
154155
mask: Optional[torch.Tensor] = None,
155-
export_config: Optional[model_builder.ExportConfig] = None,
156+
export_config: Optional[export_cfg.ExportConfig] = None,
156157
) -> dict[torch.Tensor, kv_utils.KVCacheBase]:
157158
_, seq_len = tokens.size()
158159
assert self.config.max_seq_len >= seq_len, (
@@ -185,7 +186,7 @@ def _forward_with_embeds(
185186
mask: torch.Tensor | List[torch.Tensor],
186187
input_pos: torch.Tensor,
187188
kv_cache: kv_utils.KVCacheBase,
188-
export_config: Optional[model_builder.ExportConfig] = None,
189+
export_config: Optional[export_cfg.ExportConfig] = None,
189190
) -> dict[torch.Tensor, kv_utils.KVCacheBase]:
190191
"""Forwards the model with input embeddings."""
191192
assert len(self.transformer_blocks) == len(kv_cache.caches), (

ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515

1616
"""Example of converting a Gemma1 model to multi-signature tflite model."""
1717

18-
import os
1918
from absl import app
2019
from ai_edge_torch.generative.examples.gemma import gemma1
2120
from ai_edge_torch.generative.utilities import converter
22-
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
21+
from ai_edge_torch.generative.utilities import export_config
2322

2423
flags = converter.define_conversion_flags("gemma-2b")
24+
ExportConfig = export_config.ExportConfig
25+
2526

2627
def main(_):
2728
pytorch_model = gemma1.build_2b_model(

ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515

1616
"""Example of converting a Gemma2 model to multi-signature tflite model."""
1717

18-
import os
1918
from absl import app
20-
from absl import flags
2119
from ai_edge_torch.generative.examples.gemma import gemma2
2220
from ai_edge_torch.generative.utilities import converter
23-
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
21+
from ai_edge_torch.generative.utilities import export_config
2422

2523
flags = converter.define_conversion_flags("gemma2-2b")
24+
ExportConfig = export_config.ExportConfig
25+
2626

2727
def main(_):
2828
pytorch_model = gemma2.build_2b_model(

ai_edge_torch/generative/examples/gemma/gemma2.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import ai_edge_torch.generative.layers.attention_utils as attn_utils
2424
import ai_edge_torch.generative.layers.model_config as cfg
2525
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
26+
from ai_edge_torch.generative.utilities import export_config as export_cfg
2627
from ai_edge_torch.generative.utilities import model_builder
2728
import ai_edge_torch.generative.utilities.loader as loading_utils
2829
import torch
@@ -151,7 +152,7 @@ def forward(
151152
input_pos: torch.Tensor,
152153
kv_cache: kv_utils.KVCache,
153154
mask: Optional[torch.Tensor] = None,
154-
export_config: Optional[model_builder.ExportConfig] = None,
155+
export_config: Optional[export_cfg.ExportConfig] = None,
155156
) -> dict[torch.Tensor, kv_utils.KVCache]:
156157
_, seq_len = tokens.size()
157158
assert self.config.max_seq_len >= seq_len, (
@@ -184,7 +185,7 @@ def _forward_with_embeds(
184185
mask: torch.Tensor | List[torch.Tensor],
185186
input_pos: torch.Tensor,
186187
kv_cache: kv_utils.KVCache,
187-
export_config: Optional[model_builder.ExportConfig] = None,
188+
export_config: Optional[export_cfg.ExportConfig] = None,
188189
) -> dict[torch.Tensor, kv_utils.KVCache]:
189190
"""Forwards the model with input embeddings."""
190191
assert len(self.transformer_blocks) == len(kv_cache.caches), (

ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,16 @@
1515

1616
"""Example of converting a Gemma3 model to multi-signature tflite model."""
1717

18-
import os
1918
from absl import app
20-
from absl import flags
2119
from ai_edge_torch.generative.examples.gemma3 import gemma3
2220
from ai_edge_torch.generative.layers.experimental import kv_cache
2321
from ai_edge_torch.generative.utilities import converter
24-
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
22+
from ai_edge_torch.generative.utilities import export_config
2523
import torch
2624

2725
flags = converter.define_conversion_flags('gemma3-1b')
26+
ExportConfig = export_config.ExportConfig
27+
2828

2929
_MODEL_SIZE = flags.DEFINE_string(
3030
'model_size',

ai_edge_torch/generative/examples/gemma3/decoder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
2424
import ai_edge_torch.generative.layers.model_config as cfg
2525
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
26+
from ai_edge_torch.generative.utilities import export_config as export_cfg
2627
from ai_edge_torch.generative.utilities import model_builder
2728
import ai_edge_torch.generative.utilities.loader as loading_utils
2829
import torch
@@ -244,7 +245,7 @@ def forward(
244245
input_embeds: Optional[torch.Tensor] = None,
245246
mask: Optional[torch.Tensor] = None,
246247
image_indices: Optional[torch.Tensor] = None,
247-
export_config: Optional[model_builder.ExportConfig] = None,
248+
export_config: Optional[export_cfg.ExportConfig] = None,
248249
) -> dict[torch.Tensor, kv_utils.KVCacheBase]:
249250

250251
pixel_mask = None
@@ -288,7 +289,7 @@ def _forward_with_embeds(
288289
input_pos: torch.Tensor,
289290
kv_cache: kv_utils.KVCacheBase,
290291
pixel_mask: Optional[torch.Tensor] = None,
291-
export_config: Optional[model_builder.ExportConfig] = None,
292+
export_config: Optional[export_cfg.ExportConfig] = None,
292293
) -> dict[torch.Tensor, kv_utils.KVCacheBase]:
293294
"""Forwards the model with input embeddings."""
294295
assert len(self.transformer_blocks) == len(kv_cache.caches), (

ai_edge_torch/generative/examples/gemma3/gemma3.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from ai_edge_torch.generative.layers import builder
2525
from ai_edge_torch.generative.layers import kv_cache as kv_utils
2626
import ai_edge_torch.generative.layers.model_config as cfg
27-
from ai_edge_torch.generative.utilities import model_builder
27+
from ai_edge_torch.generative.utilities import export_config as export_cfg
2828
import ai_edge_torch.generative.utilities.loader as loading_utils
2929
import torch
3030
from torch import nn
@@ -83,7 +83,7 @@ def forward(
8383
image_indices: Optional[torch.Tensor] = None,
8484
image_feat_indices: Optional[torch.Tensor] = None,
8585
pixel_values: torch.Tensor = None,
86-
export_config: Optional[model_builder.ExportConfig] = None,
86+
export_config: Optional[export_cfg.ExportConfig] = None,
8787
) -> dict[torch.Tensor, kv_utils.KVCache]:
8888
_, seq_len = tokens.size()
8989
assert self.config.decoder_config.max_seq_len >= seq_len, (
@@ -150,6 +150,7 @@ def forward(
150150
export_config=export_config,
151151
)
152152

153+
153154
def get_fake_model_config(**kwargs) -> Gemma3MMConfig:
154155
return Gemma3MMConfig(
155156
image_encoder_config=image_encoder.get_fake_image_encoder_config(),

0 commit comments

Comments (0)