Commit 7ee03fc

Merge remote-tracking branch 'origin/main' into use-executorch-core
2 parents: cd26efa + e60958a

15 files changed: +267 additions, -28 deletions

examples/models/llama/llama_transformer.py

Lines changed: 60 additions & 13 deletions
@@ -13,6 +13,7 @@
 import torch.nn.functional as F
 
 from executorch.examples.models.llama.attention import (
+    Attention,
     ATTENTION_REGISTRY,
     ForwardOptions,
 )
@@ -83,26 +84,46 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
+    def __init__(self, args: ModelArgs, attention: Attention):
+        """
+        Transformer block with support for pre-norm and post-norm.
+        Args:
+            args (ModelArgs): model configuration parameters.
+            attention (Attention): attention object to use in the transformer
+                block. See `attention.py` for types of attention. Make sure
+                the attention type is registered in the ATTENTION_REGISTRY.
+        """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
-        if args.attention_type not in ATTENTION_REGISTRY:
-            raise ValueError(
-                f"Unknown attention type: {args.attention_type}. "
-                f"Available: {list(ATTENTION_REGISTRY.keys())}"
-            )
-        cls = ATTENTION_REGISTRY[args.attention_type]
-        self.attention = cls(args, layer_id, rope)
+        self.attention = attention
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
             self.feed_forward = FeedForward(args)
         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
 
+    @classmethod
+    def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
+        """
+        Create a TransformerBlock with the legacy constructor.
+        Args:
+            layer_id (int): the index of the layer.
+            args (ModelArgs): model configuration parameters.
+            rope (Rope): the rope object to use for rotary embeddings.
+        """
+        if args.attention_type not in ATTENTION_REGISTRY:
+            raise ValueError(
+                f"Unknown attention type: {args.attention_type}. "
+                f"Available: {list(ATTENTION_REGISTRY.keys())}"
+            )
+        cls = ATTENTION_REGISTRY[args.attention_type]
+        attention = cls(args, layer_id, rope)
+        return TransformerBlock(args, attention)
+
     def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
         h, attn_options_update = self.attention.forward(
             self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
@@ -117,7 +138,15 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x:
 
 
 class Transformer(nn.Module):
-    def __init__(self, params: ModelArgs):
+    def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope):
+        """
+        Transformer model.
+        Args:
+            params (ModelArgs): model configuration parameters.
+            layers (nn.ModuleList): list of transformer blocks - see the
+                `TransformerBlock` type above.
+            rope (Rope): the rope object to use for rotary embeddings.
+        """
         super().__init__()
         self.params = params
         self.vocab_size = params.vocab_size
@@ -130,10 +159,8 @@ def __init__(self, params: ModelArgs):
             if self.apply_embedding
             else None
         )
-        self.rope = Rope(params)
-        self.layers = torch.nn.ModuleList()
-        for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params, self.rope))
+        self.layers = layers
+        self.rope = rope
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = (
             nn.Linear(params.dim, params.vocab_size, bias=False)
@@ -212,3 +239,23 @@ def forward(
             return logits, attn_options_update
 
         return logits
+
+
+def construct_transformer(model_args: ModelArgs) -> Transformer:
+    """
+    Construct a Transformer model from the given model arguments.
+    """
+    rope = Rope(model_args)
+    if model_args.attention_type not in ATTENTION_REGISTRY:
+        raise ValueError(
+            f"Unknown attention type: {model_args.attention_type}. "
+            f"Available: {list(ATTENTION_REGISTRY.keys())}"
+        )
+    layers = torch.nn.ModuleList()
+    cls = ATTENTION_REGISTRY[model_args.attention_type]
+    for layer_id in range(model_args.n_layers):
+        attention = cls(model_args, layer_id, rope)
+        transformer_block = TransformerBlock(model_args, attention)
+        layers.append(transformer_block)
+
+    return Transformer(model_args, layers, rope)
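
Net effect of this file's changes: attention construction moves out of TransformerBlock.__init__ and into the new construct_transformer factory, with TransformerBlock.from_type kept as the legacy-style path. A minimal usage sketch of the new entry points; the ModelArgs field values below are illustrative only, not prescribed by this commit:

from executorch.examples.models.llama.llama_transformer import (
    TransformerBlock,
    construct_transformer,
)
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope

args = ModelArgs(n_layers=4, vocab_size=128)  # illustrative values

# Preferred path: build the whole model, including per-layer attention, in one call.
model = construct_transformer(args).eval()

# Legacy-style path: build a single block; from_type looks up args.attention_type
# in ATTENTION_REGISTRY and constructs the attention object internally.
rope = Rope(args)
block = TransformerBlock.from_type(layer_id=0, args=args, rope=rope)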

examples/models/llama/model.py

Lines changed: 3 additions & 2 deletions
@@ -15,9 +15,10 @@
     get_checkpoint_dtype,
     get_default_model_resource_dir,
 )
-from executorch.examples.models.llama.llama_transformer import Transformer
 
+from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.examples.models.llama.rope import Rope
 from torchao.utils import TorchAOBaseTensor
 
 try:
@@ -174,7 +175,7 @@ def __init__(self, **kwargs):
         # They possess all other metadata a tensor carries such as size, stride, requires_grad.
         with torch.device("meta"):
             # Model itself is loaded in default dtype, fp32.
-            self.model_ = Transformer(model_args)
+            self.model_ = construct_transformer(model_args)
         # Get checkpoint dtype.
         if checkpoint:
             self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
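
The call-site pattern above (constructing under torch.device("meta")) keeps parameters as metadata-only tensors until a real checkpoint is loaded. A hedged sketch of that pattern in isolation; ModelArgs() defaults stand in for the real configuration that model.py derives elsewhere:

import torch
from executorch.examples.models.llama.llama_transformer import construct_transformer
from executorch.examples.models.llama.model_args import ModelArgs

model_args = ModelArgs()  # placeholder config for illustration
with torch.device("meta"):
    # Parameters carry size/stride/dtype metadata but no real storage; the
    # actual weights are filled in later when the checkpoint state_dict loads.
    model = construct_transformer(model_args)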

examples/models/llama/tests/test_pre_quantization_transforms.py

Lines changed: 5 additions & 2 deletions
@@ -7,7 +7,10 @@
 import unittest
 
 import torch
-from executorch.examples.models.llama.llama_transformer import Transformer
+from executorch.examples.models.llama.llama_transformer import (
+    construct_transformer,
+    Transformer,
+)
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.source_transformation.pre_quantization import (
     sanitize_checkpoint_from_pre_quantization,
@@ -39,7 +42,7 @@ def _prepare_dummy_model(self) -> Transformer:
             vocab_size=32000,
         )
 
-        model = Transformer(model_args)
+        model = construct_transformer(model_args)
 
         return model
 

examples/models/llama/tests/test_static_attention.py

Lines changed: 3 additions & 3 deletions
@@ -2,7 +2,7 @@
 
 import torch
 from executorch.examples.models.llama.attention import AttentionMHA, ForwardOptions
-from executorch.examples.models.llama.llama_transformer import Transformer
+from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope
 from executorch.examples.models.llama.static_attention import (
@@ -160,10 +160,10 @@ def test_within_transformer(self):
             n_layers=4,
             vocab_size=128,
         )
-        mha_transformer = Transformer(config).eval()
+        mha_transformer = construct_transformer(config).eval()
 
         config.attention_type = "static"
-        static_transformer = Transformer(config).eval()
+        static_transformer = construct_transformer(config).eval()
         static_transformer.load_state_dict(mha_transformer.state_dict(), strict=False)
         for mha_layer, static_layer in zip(
             mha_transformer.layers, static_transformer.layers

examples/models/llava/model.py

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@
 
 import requests
 import torch
-from executorch.examples.models.llama.llama_transformer import Transformer
+from executorch.examples.models.llama.llama_transformer import construct_transformer
 from executorch.examples.models.llama.model_args import ModelArgs
 
 from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
@@ -66,7 +66,7 @@ def __init__(
             use_hf_rope=True,
             max_seq_len=max_seq_len,
         )
-        self.text_model = Transformer(self.text_model_args)
+        self.text_model = construct_transformer(self.text_model_args)
         # use custom op for SDPA.
         if use_sdpa_with_kv_cache_op:
             self.text_model = replace_kv_cache_with_custom_kv_cache(self.text_model)

exir/backend/test/demos/rpc/ExecutorBackend.cpp

Lines changed: 8 additions & 2 deletions
@@ -18,6 +18,7 @@
 #include <executorch/runtime/core/error.h>
 #include <executorch/runtime/core/evalue.h>
 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
+#include <executorch/runtime/core/named_data_map.h>
 #include <executorch/runtime/executor/method.h>
 #include <executorch/runtime/executor/program.h>
 
@@ -37,6 +38,7 @@ using ::executorch::runtime::MemoryAllocator;
 using ::executorch::runtime::MemoryManager;
 using ::executorch::runtime::Method;
 using ::executorch::runtime::MethodMeta;
+using ::executorch::runtime::NamedDataMap;
 using ::executorch::runtime::Program;
 using ::executorch::runtime::Result;
 using ::executorch::runtime::Span;
@@ -156,9 +158,13 @@ class ExecutorBackend final : public ::executorch::runtime::BackendInterface {
     new (client_memory_manager)
         MemoryManager(client_method_allocator, client_planned_memory);
 
+    const NamedDataMap* named_data_map = context.get_named_data_map();
     // Construct the client Method
-    Result<Method> method_res =
-        client_program->load_method("forward", client_memory_manager);
+    Result<Method> method_res = client_program->load_method(
+        "forward",
+        client_memory_manager,
+        /*event_tracer=*/nullptr,
+        named_data_map);
     if (!method_res.ok()) {
       ET_LOG(
           Error,

exir/backend/test/demos/rpc/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ runtime.python_library(
     ],
     visibility = [
         "//executorch/exir/backend/test/...",
+        "//executorch/test/...",
     ],
     deps = [
         "//caffe2:torch",

exir/backend/test/demos/rpc/executor_backend_preprocess.py

Lines changed: 7 additions & 1 deletion
@@ -8,6 +8,8 @@
 
 from typing import final, List
 
+from executorch.exir import ExecutorchBackendConfig
+
 from executorch.exir.backend.backend_details import (
     BackendDetails,
     ExportedProgram,
@@ -24,10 +26,14 @@ def preprocess(
         edge_program: ExportedProgram,
         compile_specs: List[CompileSpec],
     ) -> PreprocessResult:
+        config = ExecutorchBackendConfig()
+        for spec in compile_specs:
+            if spec.key == "external_constants":
+                config.external_constants = True
         return PreprocessResult(
             processed_bytes=EdgeProgramManager(
                 edge_programs=edge_program,
             )
-            .to_executorch()
+            .to_executorch(config)
             .buffer,
         )
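
For context, the "external_constants" compile spec consumed above would be supplied by whoever lowers a program to this demo backend. A hedged sketch, assuming the usual CompileSpec(key, value) schema from executorch.exir.backend; the preprocess only checks the key, so the value bytes are not interpreted:

from executorch.exir.backend.compile_spec_schema import CompileSpec

# Ask the demo backend to emit constants externally; only the key matters here.
compile_specs = [CompileSpec("external_constants", b"")]

# Hypothetical call site (backend id and program variable are placeholders):
# lowered = to_backend("ExecutorBackend", edge_program, compile_specs)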

exir/backend/test/demos/rpc/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ def define_common_targets():
         ],
         visibility = [
             "//executorch/exir/backend/test/...",
+            "//executorch/runtime/executor/test/...",
         ],
         deps = [
             ":executor_backend",

runtime/executor/method.cpp

Lines changed: 3 additions & 0 deletions
@@ -329,6 +329,8 @@ Result<size_t> Method::get_num_external_constants() {
 }
 
 Error Method::parse_external_constants(const NamedDataMap* named_data_map) {
+  ET_CHECK_OR_RETURN_ERROR(
+      named_data_map != nullptr, InvalidState, "named_data_map is null");
   auto flatbuffer_values = serialization_plan_->values();
   size_t n_value = flatbuffer_values->size();
 
@@ -372,6 +374,7 @@ Error Method::parse_external_constants(const NamedDataMap* named_data_map) {
   Result<const TensorLayout> tensor_layout =
       named_data_map->get_metadata(key);
   if (!tensor_layout.ok()) {
+    ET_LOG(Info, "Failed to get metadata for key %s", key);
     return tensor_layout.error();
   }
   // Check external tensor compatibility.
