127 changes: 127 additions & 0 deletions examples/models/llama/export_kvio.py
@@ -0,0 +1,127 @@
import coremltools as ct

from llama_transformer_kvio import Transformer, ModelArgs

# Define model
import json
import torch

params_path = "/Users/scroy/models/stories110M/params.json"
checkpoint_path = "/Users/scroy/models/stories110M/stories110M.pt"
output_path = "/Users/scroy/Desktop/exported2.pte"

with open(params_path, "r") as f:
    params = json.loads(f.read())

model_args: ModelArgs = ModelArgs(
    max_seq_len=512,
    max_batch_size=1,
    use_kv_cache=False,  # caches are handled externally, as explicit model inputs/outputs below
    use_sdpa_with_kv_cache_op=False,
    generate_full_logits=False,
    input_prune_map=None,
    output_prune_map=None,
    enable_dynamic_shape=False,
    **params,
)
checkpoint = torch.load(checkpoint_path, map_location="cpu", mmap=True)

with torch.no_grad():
    model = Transformer(model_args)
    model.eval()
    model.load_state_dict(
        checkpoint,
        strict=False,
        assign=True,
    )

# [n_layers, bs, n_heads, seq_len, head_dim]
cache_shape = (
    model_args.n_layers,
    model_args.max_batch_size,
    model_args.n_heads,
    model_args.max_seq_len,
    model_args.dim // model_args.n_heads,
)
k_caches = torch.zeros(cache_shape, dtype=torch.float16, device="cpu")
v_caches = torch.zeros(cache_shape, dtype=torch.float16, device="cpu")
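# All layers' caches are passed to the model as explicit inputs (and, per the
# "kvio" naming, presumably returned as outputs) rather than mutated in-place
# as module buffers, which keeps the exported graph functional.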


# example_inputs = (
#     torch.ones(size=(1, model_args.max_seq_len), dtype=torch.long),
#     torch.tensor([0], dtype=torch.long),  # start_pos
#     k_caches,
#     v_caches,
# )

example_inputs = (
    torch.ones(size=(1, 1), dtype=torch.long),
    torch.tensor([0], dtype=torch.long),  # start_pos
    k_caches,
    v_caches,
)
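# The commented-out block above is the prefill variant (a full max_seq_len
# window of tokens); the active inputs trace the single-token decode step:
# one new token plus the externally managed caches.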

exported_model = torch.export.export(model, example_inputs, strict=False)


print('Exported model', exported_model)

from executorch.exir import to_edge
from executorch.exir.capture._config import EdgeCompileConfig

edge_config = EdgeCompileConfig(
    _check_ir_validity=False,  # presumably needed because the custom coreml::sdpa op is not core ATen
    _skip_dim_order=True,
)
edge_program_manager = to_edge(exported_model, compile_config=edge_config)
print('Edge program', edge_program_manager.exported_program())

from executorch.extension.llm.export.partitioner_lib import get_coreml_partitioner
partitioner = get_coreml_partitioner(ios=18)
delegated_edge_program_manager = edge_program_manager.to_backend(partitioner)
print('Delegated edge program', delegated_edge_program_manager.exported_program())
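# Optional sanity check (illustrative, not required for export): count how many
# partitions were handed to the Core ML backend by scanning the graph for
# executorch_call_delegate calls.
num_delegated = sum(
    1
    for node in delegated_edge_program_manager.exported_program().graph.nodes
    if node.op == "call_function" and "executorch_call_delegate" in str(node.target)
)
print("Delegated partitions:", num_delegated)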

executorch_program = delegated_edge_program_manager.to_executorch()
with open(output_path, "wb") as file:
    executorch_program.write_to_file(file)
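# Hypothetical smoke test (assumes the ExecuTorch pybindings are built with the
# Core ML backend linked in; otherwise loading the delegated program will fail):
#
#     from executorch.extension.pybindings.portable_lib import _load_for_executorch
#     module = _load_for_executorch(output_path)
#     outputs = module.forward(list(example_inputs))
#     print(outputs[0].shape)  # presumably the logits for the single decoded token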


# # Convert to Core ML program using the Unified Conversion API.
# model_from_trace = ct.convert(
#     traced_model,
#     inputs=[ct.TensorType(shape=exi.shape) for exi in example_inputs],
#     minimum_deployment_target=ct.target.iOS18,
# )
# model_from_trace.save("/Users/scroy/Desktop/traced_model.mlpackage")


# model_from_export = ct.convert(
#     exported_model,
#     minimum_deployment_target=ct.target.iOS18,
# )
# model_from_export.save("/Users/scroy/Desktop/exported_model.mlpackage")


# mlpackage = ct.convert(exported_model, minimum_deployment_target=ct.target.iOS18)

# print(mlpackage)
# mlpackage.save("/Users/scroy/Desktop/model.mlpackage")

# desc = ct.utils.MultiFunctionDescriptor()

# path = "/Users/scroy/repos/executorch/extracted_coreml_models/model_1/lowered_module"

# desc.add_function(
#     f"{path}/model_prefill.mlpackage",
#     src_function_name="main",
#     target_function_name="prefill",
# )
# desc.add_function(
#     f"{path}/model_kv.mlpackage",
#     src_function_name="main",
#     target_function_name="gen",
# )

# desc.default_function_name = "prefill"
# ct.utils.save_multifunction(desc, f"{path}/combined.mlpackage")
20 changes: 19 additions & 1 deletion examples/models/llama/llama_transformer.py
@@ -23,6 +23,23 @@

from torch import nn

@torch.library.custom_op("coreml::sdpa", mutates_args=())
def sdpa(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
) -> torch.Tensor:
    """Same as F.scaled_dot_product_attention, but with custom op to avoid lowering during dialect conversion."""
    return torch.ops.aten.scaled_dot_product_attention.default(
        q, k, v, attn_mask=attn_mask
    )

@torch.library.register_fake("coreml::sdpa")
def _(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor
) -> torch.Tensor:
    """Fake implementation with the right output shape, which is required for torch.compile/export/fx tracing."""
    expected_shape = list(q.shape)
    expected_shape[-1] = v.shape[-1]
    return q.new_empty(expected_shape)
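# Illustrative only (not used at runtime): with the fake kernel registered,
# torch.export can trace through coreml::sdpa and record it as a single graph
# node instead of decomposing the attention math, e.g.:
#
#     class _Attn(torch.nn.Module):
#         def forward(self, q, k, v, mask):
#             return torch.ops.coreml.sdpa(q, k, v, mask)
#
#     ep = torch.export.export(
#         _Attn(),
#         (torch.randn(1, 8, 4, 64), torch.randn(1, 8, 4, 64),
#          torch.randn(1, 8, 4, 64), torch.zeros(4, 4)),
#     )
#     # ep.graph now contains a coreml.sdpa call rather than aten attention ops.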

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
@@ -351,7 +368,8 @@ def forward(

        mask = self.mask[:seqlen, :seqlen]

        # output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
        output = torch.ops.coreml.sdpa(q, k, v, mask)

        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
