Commit 56f84cc (2 parents: 9d1824d + b3944b3)

Update
[ghstack-poisoned]

12 files changed: +178 −26 lines changed

CMakeLists.txt

Lines changed: 2 additions & 14 deletions

@@ -647,18 +647,14 @@ if(EXECUTORCH_BUILD_EXTENSION_LLM)
   list(APPEND _executorch_extensions tokenizers)
 endif()

-if(EXECUTORCH_BUILD_EXTENSION_LLM_APPLE)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/apple)
-  list(APPEND _executorch_extensions extension_llm_apple)
-endif()
-
 if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/runner)
   list(APPEND _executorch_extensions extension_llm_runner)
 endif()

 if(EXECUTORCH_BUILD_EXTENSION_LLM_APPLE)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/apple)
+  list(APPEND _executorch_extensions extension_llm_apple)
 endif()

 if(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL)

@@ -802,7 +798,6 @@ if(EXECUTORCH_BUILD_VGF)
   list(APPEND _executorch_backends vgf_backend)
 endif()

-
 # Top-level interface targets.

 # A target containing all configured backends.

@@ -869,17 +864,10 @@ else()
 endif()
 target_link_libraries(executorch_kernels INTERFACE ${_executorch_kernels})

-install(
-  TARGETS executorch_backends executorch_extensions executorch_kernels
-  INCLUDES
-  DESTINATION ${_common_include_directories}
-)
-
 if(EXECUTORCH_BUILD_EXECUTOR_RUNNER)
   # Baseline libraries that executor_runner will link against.
   set(_executor_runner_libs executorch extension_evalue_util
-                            extension_runner_util gflags
-                            executorch_backends
+                            extension_runner_util gflags executorch_backends
   )

   if(EXECUTORCH_BUILD_KERNELS_OPTIMIZED)

backends/xnnpack/operators/node_visitor.py

Lines changed: 4 additions & 3 deletions

@@ -622,9 +622,10 @@ def get_serialized_buffer_index(
         )

         external_tag = tensor.meta.get("delegate_constant_tag", None)
-        logging.info(
-            f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
-        )
+        if external_tag is not None:
+            logging.info(
+                f"Adding constant data with name {tensor.name}, key {named_key} and external_tag {external_tag} to named_data_store"
+            )
         self._named_data_store.add_named_data(
             named_key,
             bytes(array),

devtools/etrecord/tests/etrecord_test.py

Lines changed: 11 additions & 0 deletions

@@ -92,6 +92,17 @@ def check_graph_closeness(self, graph_a, graph_b):
             self.assertEqual(
                 node_a.meta.get("debug_handle"), node_b.meta.get("debug_handle")
             )
+            from_node_a = node_a.meta.get("from_node")
+            from_node_b = node_b.meta.get("from_node")
+
+            if from_node_a is None:
+                self.assertIsNone(from_node_b)
+            else:
+                self.assertIsNotNone(from_node_b)
+                for node_source_a, node_source_b in zip(from_node_a, from_node_b):
+                    self.assertEqual(
+                        node_source_a.to_dict(), node_source_b.to_dict()
+                    )

     def test_etrecord_generation(self):
         captured_output, edge_output, et_output = self.get_test_model()

examples/models/llama/export_llama_lib.py

Lines changed: 12 additions & 0 deletions

@@ -239,6 +239,18 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="checkpoint directory. Use with a sharded checkpoint, not for the standard llama2 model. Note, checkpoint_dir takes precedence over checkpoint if both are set.",
     )

+    parser.add_argument(
+        "--adapter_checkpoint",
+        required=False,
+        help="Path to the adapter.pt file from torchtune. Used if the model has trained LoRA adapters. Must provide adapter_config.json",
+    )
+
+    parser.add_argument(
+        "--adapter_config",
+        required=False,
+        help="Path to the adapter_config.json file. Used if the model has trained LoRA adapters. Must provide adapter_checkpoint.",
+    )
+
     parser.add_argument(
         "--use_qnn_sha",
         action="store_true",

examples/models/llama/model.py

Lines changed: 22 additions & 0 deletions

@@ -46,6 +46,13 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
         checkpoint_dir = self.llm_config.base.checkpoint_dir
         params_path = self.llm_config.base.params

+        # Adapter checkpoint and config.
+        adapter_checkpoint_path = self.llm_config.base.adapter_checkpoint
+        adapter_config_path = self.llm_config.base.adapter_config
+        assert (adapter_checkpoint_path is None and adapter_config_path is None) or (
+            adapter_checkpoint_path is not None and adapter_config_path is not None
+        ), "Both adapter_checkpoint_path and adapter_config_path must be specified or neither must be specified."
+
         self.use_kv_cache = self.llm_config.model.use_kv_cache
         self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
         self.generate_full_logits = self.llm_config.debug.generate_full_logits

@@ -129,6 +136,20 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
         with open(params_path, "r") as f:
             params = json.loads(f.read())

+        # Get adapter checkpoint and config.
+        adapter_checkpoint = {}
+        adapter_config = {}
+        if adapter_checkpoint_path:
+            adapter_checkpoint = torch.load(
+                adapter_checkpoint_path, map_location=device, mmap=True
+            )
+            from torchtune.models import convert_weights
+
+            adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)
+            with open(adapter_config_path, "r") as f:
+                adapter_config = json.loads(f.read())
+            checkpoint.update(adapter_checkpoint)
+
         output_prune_map = None
         if self.output_prune_map_path is not None:
             with open(self.output_prune_map_path, "r") as f:

@@ -153,6 +174,7 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
             output_prune_map=output_prune_map,
             enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
+            **adapter_config,
         )

         if model_args.use_scaled_rope:

examples/models/llama/model_args.py

Lines changed: 1 addition & 1 deletion

@@ -59,7 +59,7 @@ class ModelArgs:
     lora_args: Optional[dict] = None

     # LoRA arguments to set up a LoRA inference model.
-    # These arguments come directly from a torchtune LoRA config.
+    # These arguments come directly from a torchtune adapter_config.json file.
     r: Optional[int] = None  # Rank.
     lora_alpha: Optional[int] = None  # Alpha.
     # Eg. q_proj, k_proj, v_proj, output_proj
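
Because model.py splats the parsed adapter_config.json into ModelArgs via **adapter_config, the file's keys must match the field names above. A hypothetical minimal file, shown as the dict it parses to; only r and lora_alpha appear in this diff, and any other keys a real torchtune file carries (such as the projection-module list hinted at by the last comment) are omitted here:

import json

# Hypothetical contents of adapter_config.json; a real torchtune file may
# include additional keys, which must also be valid ModelArgs fields.
adapter_config = json.loads('{"r": 8, "lora_alpha": 16}')
# ModelArgs(..., **params, **adapter_config) then receives r=8, lora_alpha=16.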

examples/models/llama/runner/static_attention_io_manager.h

Lines changed: 38 additions & 0 deletions

@@ -602,6 +602,39 @@ class StaticAttentionIOManager {
     }
   }

+  /**
+   * Prefill helper. Run multiple inferences as needed depending on the length
+   * of the prompt and method's input length. Returns the position in the output
+   * that corresponds to the end of the prompt during the last inference.
+   */
+  template <typename TokenT>
+  size_t prefill(
+      executorch::runtime::Span<TokenT> tokens,
+      executorch::runtime::Span<TokenT> input_buffer,
+      executorch::runtime::Method& method) {
+    size_t input_len = input_buffer.size();
+    get_mask(input_buffer.size()).set_causal_mask();
+
+    size_t batch_len = 0;
+    for (size_t i = 0; i < tokens.size(); i += input_len) {
+      batch_len = std::min(input_len, tokens.size() - i);
+      std::copy(&tokens[i], &tokens[i + batch_len], input_buffer.begin());
+      prepare(method);
+      ET_CHECK(method.execute() == executorch::runtime::Error::Ok);
+      update(
+          method,
+          config_.k_cache_output_indices,
+          config_.v_cache_output_indices,
+          batch_len);
+    }
+    return batch_len - 1;
+  }
+
+  /**
+   * Decode helper. The `sample` argument is called after each inference and
+   * should retrieve the logits from the `method` argument's output and return
+   * the sampled token.
+   */
   template <typename TokenT>
   std::vector<TokenT> decode(
       TokenT prev_tok,

@@ -632,6 +665,11 @@ class StaticAttentionIOManager {
     return generated_tokens;
   }

+  /**
+   * Lookahead decode helper. The `sample` argument is called after each
+   * inference and should retrieve the logits from the `method` argument's
+   * output and return the sampled token for all output positions.
+   */
   template <typename TokenT>
   std::vector<TokenT> lookahead_decode(
       TokenT prev_tok,

exir/serde/serialize.py

Lines changed: 31 additions & 0 deletions

@@ -41,6 +41,7 @@
 )
 from torch._export.verifier import load_verifier
 from torch.fx.experimental import symbolic_shapes
+from torch.fx.traceback import NodeSource

 log: logging.Logger = logging.getLogger(__name__)


@@ -141,8 +142,24 @@ def serialize_metadata(self, node: torch.fx.Node) -> Dict[str, str]:
             debug_handle = node.meta["debug_handle"]
             meta["debug_handle"] = str(debug_handle)

+        if "from_node" in node.meta:
+            from_node = node.meta["from_node"]
+            # Serialize from_node as JSON since it's a complex nested structure
+            meta["from_node"] = json.dumps(self._make_from_node_json_acceptable(from_node))
+
         return meta

+    def _make_from_node_json_acceptable(self, from_node: Optional[List[NodeSource]]):
+        """
+        Serialize from_node metadata from a list of NodeSource objects to a list of dictionaries.
+        """
+        if from_node is None:
+            return None
+
+        json_acceptable_from_node = [node_source.to_dict() for node_source in from_node if isinstance(node_source, NodeSource)]
+
+        return json_acceptable_from_node
+
     def serialize_alloc_inputs(
         self, inputs  # pyre-ignore
     ) -> List[schema.NamedArgument]:

@@ -473,8 +490,22 @@ def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]:
         if debug_handle := metadata.get("debug_handle"):
             res["debug_handle"] = int(debug_handle)

+        if from_node_str := metadata.get("from_node"):
+            res["from_node"] = self._deserialize_from_node(json.loads(from_node_str))
+
         return res

+    def _deserialize_from_node(self, from_node_data: Optional[List[Dict[str, Any]]]) -> Optional[List[NodeSource]]:
+        """
+        Recursively deserialize from_node metadata from JSON data.
+        """
+        if from_node_data is None:
+            return None
+
+        assert isinstance(from_node_data, list)
+
+        return [NodeSource._from_dict(fn_dict) for fn_dict in from_node_data]
+
     # pyre-ignore
     def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]):
         def deserialize_alloc_spec(serialized_alloc_spec: str) -> memory.AllocSpec:
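
For context, a rough round-trip sketch of what the new from_node handling amounts to, built from the same to_dict()/_from_dict() calls used above; the executorch.exir import path and the tiny module are illustrative assumptions, not part of this commit:

# Minimal round-trip sketch of the from_node metadata, not the repo's code.
import json

import torch
from torch.export import export
from torch.fx.traceback import NodeSource

from executorch.exir import to_edge  # assumed import path


class TinyModule(torch.nn.Module):
    def forward(self, x):
        return x + 1


edge = to_edge(export(TinyModule(), (torch.randn(2),), strict=True))
for node in edge.exported_program().graph_module.graph.nodes:
    from_node = node.meta.get("from_node")
    if not from_node:
        continue
    # Encode: NodeSource objects -> list of dicts -> JSON string (what serialize_metadata stores).
    encoded = json.dumps([ns.to_dict() for ns in from_node])
    # Decode: JSON string -> list of dicts -> NodeSource objects (what deserialize_metadata rebuilds).
    decoded = [NodeSource._from_dict(d) for d in json.loads(encoded)]
    assert [ns.to_dict() for ns in decoded] == [ns.to_dict() for ns in from_node]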

exir/tests/test_serde.py

Lines changed: 34 additions & 0 deletions

@@ -275,3 +275,37 @@ def forward(self, x):
         )
         self.assertEqual(metadata[0], metadata_serde[0])
         self.assertEqual(list(metadata[1].keys()), list(metadata_serde[1].keys()))
+
+    def test_meta_debug_handle_and_from_node(self) -> None:
+        class Model(nn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+                self.conv_layer = nn.Conv2d(
+                    in_channels=1, out_channels=64, kernel_size=3, padding=1
+                )
+
+            def forward(self, x):
+                return self.conv_layer(x)
+
+        m = Model()
+        inputs = (torch.randn(1, 1, 32, 32),)
+
+        edge = to_edge(export(m, inputs, strict=True))
+        edge_new = deserialize(serialize(edge.exported_program()))
+        for node, node_new in zip(
+            edge.exported_program().graph_module.graph.nodes,
+            edge_new.graph_module.graph.nodes,
+        ):
+            if node.op not in {"placeholder", "output"}:
+                self.assertIsNotNone(node.meta.get("debug_handle"))
+                self.assertIsNotNone(node.meta.get("from_node"))
+                self.assertEqual(
+                    node.meta.get("debug_handle"), node_new.meta.get("debug_handle")
+                )
+                self.assertEqual(
+                    len(node.meta.get("from_node")), len(node_new.meta.get("from_node"))
+                )
+                for node_source, node_source_new in zip(
+                    node.meta.get("from_node"), node_new.meta.get("from_node")
+                ):
+                    self.assertEqual(node_source.to_dict(), node_source_new.to_dict())

extension/llm/export/config/llm_config.py

Lines changed: 9 additions & 1 deletion

@@ -73,10 +73,16 @@ class BaseConfig:
            if it is a Llama model or the weights will be downloaded from HuggingFace
            if it is a non-Llama model.
        checkpoint_dir: Path to directory containing sharded checkpoint files.
+       adapter_checkpoint: Path to the adapter.pt file from torchtune. Used if
+           the model has trained LoRA adapters. Must provide
+           adapter_config.json.
+       adapter_config: Path to the adapter_config.json file from torchtune.
+           Used if the model has trained LoRA adapters. Must provide adapter.pt.
        tokenizer_path: Path to the tokenizer file.
        metadata: Json string containing metadata information.
            e.g. '"{\"get_bos_id\":128000, \"get_eos_ids\":[128009, 128001]}"'
-       use_lora: Rank of the LoRA, if set to 0 then this means no LoRA. For use with QAT.
+       use_lora: Only for use with QAT. Rank of the LoRA adapter, disabled
+           if set to 0.
        fairseq2: For legacy internal use cases, this is safe to ignore.
        preq_mode: Legacy option to specify how prequantized weights are loaded.
            Going forward, ExecuTorch supports loading weights prequantized through

@@ -90,6 +96,8 @@ class BaseConfig:
     params: Optional[str] = None
     checkpoint: Optional[str] = None
     checkpoint_dir: Optional[str] = None
+    adapter_checkpoint: Optional[str] = None
+    adapter_config: Optional[str] = None
     tokenizer_path: Optional[str] = None
     metadata: Optional[str] = None
     use_lora: int = 0
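
Taken together with examples/models/llama/model.py above, which reads llm_config.base.adapter_checkpoint and llm_config.base.adapter_config, the new fields are set like any other BaseConfig path. A minimal illustration; the import path and default construction of LlmConfig are assumptions, only the field names come from this diff:

# Illustrative sketch only, not part of this commit.
from executorch.extension.llm.export.config.llm_config import LlmConfig  # assumed path

llm_config = LlmConfig()
llm_config.base.checkpoint = "/path/to/consolidated.00.pth"
llm_config.base.params = "/path/to/params.json"
# Set both adapter fields together, or neither (enforced by the assert in model.py).
llm_config.base.adapter_checkpoint = "/path/to/adapter.pt"
llm_config.base.adapter_config = "/path/to/adapter_config.json"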
