Commit 8e53f94

protobird-git authored and copybara-github committed
Apply mask_as_input and transpose_kv_cache flags to OpenELM and AMD Llama
- It's to make them compatible with other models
- OpenELM is working on CPU regardless of this CL or flags
- OpenELM is NOT working on GPU regardless of this CL or flags
- AMD Llama is NOT working regardless of this CL or flags

PiperOrigin-RevId: 757836926
1 parent 869f6ad commit 8e53f94

4 files changed: +6, -6 lines

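The converter-side mechanics of the change: the two convert_to_tflite.py scripts below stop constructing a bare ExportConfig() and instead call export_config.get_from_flags(), so the mask_as_input and transpose_kv_cache conversion flags actually reach the export. A minimal sketch of what get_from_flags() plausibly does; the two flag names come from the commit title, but the flag definitions and ExportConfig fields below are illustrative assumptions, not code from this repo:

import dataclasses

from absl import flags

# Assumed flag registrations; in the repo these live behind
# converter.define_conversion_flags().
flags.DEFINE_bool(
    "mask_as_input", False,
    "Pass the attention mask to the model as an explicit input.")
flags.DEFINE_bool(
    "transpose_kv_cache", False,
    "Use a transposed KV-cache layout in the exported model.")


@dataclasses.dataclass
class ExportConfig:
  # Illustrative fields; the real ExportConfig in
  # ai_edge_torch.generative.utilities.export_config may differ.
  mask_as_input: bool = False
  transpose_kv_cache: bool = False


def get_from_flags() -> ExportConfig:
  # Mirror the registered conversion flags into the export config.
  return ExportConfig(
      mask_as_input=flags.FLAGS.mask_as_input,
      transpose_kv_cache=flags.FLAGS.transpose_kv_cache,
  )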
ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py
Lines changed: 3 additions & 1 deletion

@@ -49,7 +49,9 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
       intermediate_size=2048,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM, enable_hlfb=True
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,

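The functional change in this file is enable_hlfb=True on the norm config. HLFB (high-level function boundary) marking presumably lets the converter lower the whole RMS norm as a single fused composite op rather than a chain of elementwise ops, matching how the other example models are configured. For reference, the textbook RMSNorm that NormalizationType.RMS_NORM denotes, as a standalone sketch (not this repo's implementation):

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
  # Normalize by the root mean square over the feature dimension,
  # then apply a learned per-feature scale.
  rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + epsilon)
  return x * rms * weight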
ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py
Lines changed: 1 addition & 2 deletions

@@ -21,7 +21,6 @@
 from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags("amd-llama-135m")
-ExportConfig = export_config.ExportConfig
 
 
 def main(_):
@@ -35,7 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=ExportConfig(),
+      export_config=export_config.get_from_flags(),
   )

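For context, main() in these example converters follows roughly the shape below, reassembled from the diff context; build_model and every flag not visible in the diff above are assumed names, not confirmed by this commit. The OpenELM script in the next file changes in exactly the same way:

def main(_):
  # checkpoint_path and kv_cache_max_len are assumed flag names.
  pytorch_model = amd_llama_135m.build_model(
      flags.FLAGS.checkpoint_path,
      kv_cache_max_len=flags.FLAGS.kv_cache_max_len,
  )
  converter.convert_to_tflite(
      pytorch_model,
      output_path=flags.FLAGS.output_path,  # assumed
      prefill_seq_len=flags.FLAGS.prefill_seq_lens,
      quantize=flags.FLAGS.quantize,
      lora_ranks=flags.FLAGS.lora_ranks,
      export_config=export_config.get_from_flags(),  # was ExportConfig()
  )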
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
Lines changed: 1 addition & 2 deletions

@@ -21,7 +21,6 @@
 from ai_edge_torch.generative.utilities import export_config
 
 flags = converter.define_conversion_flags("openelm")
-ExportConfig = export_config.ExportConfig
 
 
 def main(_):
@@ -35,7 +34,7 @@ def main(_):
       prefill_seq_len=flags.FLAGS.prefill_seq_lens,
       quantize=flags.FLAGS.quantize,
       lora_ranks=flags.FLAGS.lora_ranks,
-      export_config=ExportConfig(),
+      export_config=export_config.get_from_flags(),
   )

ai_edge_torch/generative/examples/openelm/openelm.py
Lines changed: 1 addition & 1 deletion

@@ -51,7 +51,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
     The model config for an OpenELM model.
   """
   norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6
+      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, enable_hlfb=True
   )
   num_heads = [12] * 4 + [16] * 14 + [20] * 12 + [24] * 6
   num_query_groups = [3] * 4 + [4] * 14 + [5] * 12 + [6] * 6

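The surrounding head lists (unchanged here) encode OpenELM's layer-wise scaling, one entry per transformer block. A quick standalone sanity check of their shape, plain Python with no repo dependencies:

num_heads = [12] * 4 + [16] * 14 + [20] * 12 + [24] * 6
num_query_groups = [3] * 4 + [4] * 14 + [5] * 12 + [6] * 6

# One entry per block: 4 + 14 + 12 + 6 = 36 blocks.
assert len(num_heads) == len(num_query_groups) == 36
# Grouped-query attention: every block keeps 4 query heads per KV group.
assert all(h % g == 0 and h // g == 4 for h, g in zip(num_heads, num_query_groups))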