Commit d6b957d

Merge remote-tracking branch 'origin/main' into jni-layer-llama-1

2 parents: b34ce37 + ab44d06

17 files changed: +735 −48 lines

backends/cadence/aot/memory_constraints.py

Lines changed: 32 additions & 0 deletions

@@ -654,6 +654,37 @@ def compute_slice_and_select_loc_constraints(
         ]
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class GenerateIdmaConstraints(PassBase):
+    """Generate constraints for idma ops."""
+
+    def __init__(self, constraint: MemConstraints) -> None:
+        self.constraint = constraint
+
+    def call(self, graph_module: torch.fx.GraphModule) -> Optional[PassResult]:
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.cadence.idma_wait.out
+        ):
+            # This is just an alias op.
+            self.constraint.add_relative_placement_constraint(node.args[0], node)
+
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.cadence.idma_load.out
+        ):
+            # TODO: set correct dtcm bank here.
+            mem_id = 1
+            self.constraint.add_absolute_placement_constraint(node, mem_id, None)
+
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.cadence.idma_store.out
+        ):
+            # TODO: set correct dtcm bank here.
+            mem_id = 1
+            self.constraint.add_absolute_placement_constraint(
+                node.args[0], mem_id, None
+            )
+
+
 # The class to generate all the constraints that will be passed on to the memory
 # planning algorithm.
 class GenerateMemConstraints:
@@ -671,6 +702,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
         constraint_gen_passes: Sequence[ConstraintsGenPass] = cast(
             list[ConstraintsGenPass],
             [
+                GenerateIdmaConstraints,
                 GenerateMemoryViewConstraints,
                 GenerateSliceAndSelectNopConstraints,
                 GenerateCatNopConstraints,
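
Taken out of context, the pattern above is: walk the FX graph with find_nodes and record one placement decision per matching node. Below is a minimal runnable sketch of that pattern, not part of the commit: ToyConstraints is a hypothetical stand-in for MemConstraints, and a stock aten op replaces torch.ops.cadence.idma_load.out, which only exists inside the ExecuTorch Cadence backend.

from typing import Optional

import torch


class ToyConstraints:
    """Hypothetical stand-in for MemConstraints: records placements in a dict."""

    def __init__(self) -> None:
        self.absolute: dict[torch.fx.Node, int] = {}

    def add_absolute_placement_constraint(
        self, node: torch.fx.Node, mem_id: int, offset: Optional[int]
    ) -> None:
        self.absolute[node] = mem_id


def generate_toy_constraints(gm: torch.fx.GraphModule) -> ToyConstraints:
    constraints = ToyConstraints()
    # Same traversal as GenerateIdmaConstraints, with a stock op as the target.
    for node in gm.graph.find_nodes(
        op="call_function", target=torch.ops.aten.relu.default
    ):
        # Pin the node's output to memory bank 1 (DTCM in the real pass);
        # offset None lets the planner pick the address within the bank.
        constraints.add_absolute_placement_constraint(node, 1, None)
    return constraints


class M(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.aten.relu.default(x + 1)


gm = torch.fx.symbolic_trace(M())
print(generate_toy_constraints(gm).absolute)  # {relu node: 1}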

backends/cadence/aot/memory_planning.py

Lines changed: 39 additions & 2 deletions

@@ -9,7 +9,7 @@
 import collections
 import itertools
 import logging
-from typing import Iterable, Optional, Sequence
+from typing import Callable, Iterable, Optional, Sequence, TypeAlias
 
 import torch
 from executorch.backends.cadence.aot.memory_constraints import MemConstraints
@@ -26,6 +26,8 @@
 
 from executorch.exir import ExecutorchProgramManager
 from executorch.exir.memory_planning import collect_specs_from_nodes, Verifier
+from executorch.exir.pass_base import PassBase
+from executorch.exir.pass_manager import PassManager
 from executorch.exir.passes import MemoryPlanningPass
 from executorch.exir.tensor import TensorSpec
 from tabulate import tabulate
@@ -359,6 +361,35 @@ def print_memory_planning_info(
     )
 
 
+class SimplifyIdmaOpsPass(PassBase):
+    """Replace idma_load and idma_store with idma_copy."""
+
+    def call(self, graph_module: torch.fx.GraphModule) -> Optional[PassResult]:
+        modified = False
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.cadence.idma_load.out
+        ):
+            modified = True
+            node.target = torch.ops.cadence.idma_copy.out
+            node.args = (node.args[0], *node.args[2:])
+
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.cadence.idma_store.out
+        ):
+            modified = True
+            node.target = torch.ops.cadence.idma_copy.out
+
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, modified)
+
+
+ConstraintGenPassType: TypeAlias = Callable[
+    [MemConstraints],
+    Callable[[torch.fx.GraphModule], Optional[PassResult]],
+]
+
+
 class CadenceMemoryPlanning:
     def __init__(
         self,
@@ -423,10 +454,16 @@ def run(
         # True.
         mem_planning = MemoryPlanningPass(
             self.algo,
-            allow_lifetime_and_storage_overlap=(self.opt_level >= 2),
+            # Always allow lifetime and storage overlap.
+            # At opt level 0, we need overlap for idma wait.
+            allow_lifetime_and_storage_overlap=True,
             alloc_graph_input=self.alloc_graph_input,
             alloc_graph_output=self.alloc_graph_output,
         )
        mem_planning.run(graph_module, graph_signature)
 
+        graph_module = PassManager(passes=[SimplifyIdmaOpsPass()])(
+            graph_module
+        ).graph_module
+
         return PassResult(graph_module, True)
backends/cadence/aot/ops_registrations.py

Lines changed: 6 additions & 0 deletions

@@ -304,7 +304,13 @@
 # Post memory planning, we check that outputs/inputs for the load/store are in
 # DTCM and replace idma_load/idma_store with idma_copy.
 lib.define("idma_load(Tensor src, int task_num=0, int channel=0) -> Tensor")
+lib.define(
+    "idma_load.out(Tensor src, int task_num=0, int channel=0, *, Tensor(a!) out) -> Tensor(a!)"
+)
 lib.define("idma_store(Tensor src, int task_num=0, int channel=0) -> Tensor")
+lib.define(
+    "idma_store.out(Tensor src, int task_num=0, int channel=0, *, Tensor(a!) out) -> Tensor(a!)"
+)
 
 # Non-blocking iDMA copy.
 lib.define("idma_copy(Tensor src, int task_num=0, int channel=0) -> Tensor")

examples/models/llava/export_llava.py

Lines changed: 3 additions & 3 deletions

@@ -226,11 +226,11 @@ def export_all(llava_model: LlavaModel):
         {
             "image_encoder": image_encoder_ep,
             "token_embedding": token_embedding_ep,
-            "text_model": text_model_ep,
+            "text_decoder": text_model_ep,
         },
         partitioner={
             "image_encoder": [XnnpackPartitioner()],
-            "text_model": [
+            "text_decoder": [
                 # First partition the DQLinear nodes, then partition the rest of the nodes,
                 # to avoid multiple DQLinear nodes in the same partition,
                 # to avoid holding multiple unpacked and packed weight buffers in memory,
@@ -254,7 +254,7 @@ def export_all(llava_model: LlavaModel):
         memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
         sym_shape_eval_pass={
             "image_encoder": ConstraintBasedSymShapeEvalPass(),
-            "text_model": ConstraintBasedSymShapeEvalPass(),
+            "text_decoder": ConstraintBasedSymShapeEvalPass(),
             "token_embedding": HintBasedSymShapeEvalPass(),
         },
     )

examples/models/llava/runner/llava_text_decoder_runner.h

Lines changed: 1 addition & 1 deletion

@@ -89,7 +89,7 @@ class ET_EXPERIMENTAL LlavaTextDecoderRunner
   }
 
   inline static const std::string kTokenEmbeddingMethod = "token_embedding";
-  inline static const std::string kTextModelMethod = "text_model";
+  inline static const std::string kTextModelMethod = "text_decoder";
 };
 
 } // namespace example

examples/models/llava/test/test_llava.py

Lines changed: 4 additions & 4 deletions

@@ -96,7 +96,7 @@ def test_llava_export(self):
             "token_embedding", (prompt_before_image,)
         )[0]
         llava_module.run_method(
-            "text_model",
+            "text_decoder",
             (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img),
         )
 
@@ -107,7 +107,7 @@ def test_llava_export(self):
         # pte prefill image
         pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0]
         llava_module.run_method(
-            "text_model",
+            "text_decoder",
             (
                 torch.tensor([start_pos], dtype=torch.int64),
                 pte_embeds_img,
@@ -122,7 +122,7 @@ def test_llava_export(self):
             "token_embedding", (prompt_after_image,)
         )[0]
         pte_prefill_after_img = llava_module.run_method(
-            "text_model",
+            "text_decoder",
             (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
         )[0]
 
@@ -139,7 +139,7 @@ def test_llava_export(self):
                 "token_embedding", (torch.tensor([[new_tokens[i]]], dtype=torch.int64),)
             )[0]
            logits = llava_module.run_method(
-                "text_model",
+                "text_decoder",
                (torch.tensor([start_pos + i], dtype=torch.int64), token_embeds),
            )[0]
            new_tokens.append(torch.argmax(logits).item())

examples/models/llava/test/test_pte.py

Lines changed: 4 additions & 4 deletions

@@ -47,7 +47,7 @@ def main():
         "token_embedding", (prompt_before_image,)
     )[0]
     pte_prefill_before_img = llava_module.run_method(
-        "text_model",
+        "text_decoder",
         (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img),
     )[0]
     print(pte_prefill_before_img)
@@ -60,7 +60,7 @@ def main():
     logging.warning("Image encoder finished")
     logging.warning("Image token prefill started")
     pte_prefill_img = llava_module.run_method(
-        "text_model",
+        "text_decoder",
         (
             torch.tensor([start_pos], dtype=torch.int64),
             pte_embeds_img,
@@ -77,7 +77,7 @@ def main():
         "token_embedding", (prompt_after_image,)
     )[0]
     pte_prefill_after_img = llava_module.run_method(
-        "text_model",
+        "text_decoder",
         (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
    )[0]
    logging.warning("Text token prefill finished")
@@ -91,7 +91,7 @@ def main():
            "token_embedding", (torch.tensor([[new_tokens[i]]], dtype=torch.int64),)
        )[0]
        logits = llava_module.run_method(
-            "text_model",
+            "text_decoder",
            (torch.tensor([start_pos + i], dtype=torch.int64), token_embeds),
        )[0]
        new_tokens.append(torch.argmax(logits[..., -1, :]).item())
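
These call sites also show why the rename has to land in three places at once: the method-dict key in export_all(), the kTextModelMethod constant in the C++ runner, and every run_method call in the Python tests all name the same exported method. A minimal sketch of the runtime lookup, with a placeholder .pte path:

import torch
from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Placeholder path; test_pte.py takes this from its CLI arguments.
llava_module = _load_for_executorch("/path/to/llava.pte")

token_embeds = llava_module.run_method(
    "token_embedding", (torch.tensor([[42]], dtype=torch.int64),)
)[0]
logits = llava_module.run_method(
    "text_decoder",  # must match the key used at export time
    (torch.tensor([0], dtype=torch.int64), token_embeds),
)[0]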
Lines changed: 99 additions & 0 deletions

@@ -0,0 +1,99 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+#
+# Simple CMake build system for voxtral runner.
+#
+cmake_minimum_required(VERSION 3.24)
+project(voxtral)
+
+set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
+
+include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
+
+if(CMAKE_TOOLCHAIN_FILE MATCHES ".*(iOS|ios\.toolchain)\.cmake$")
+  set(CMAKE_TOOLCHAIN_IOS ON)
+else()
+  set(CMAKE_TOOLCHAIN_IOS OFF)
+endif()
+
+# Let files say "include <executorch/path/to/header.h>"
+set(_common_include_directories ${EXECUTORCH_ROOT}/..)
+
+# Need this for gflags for some reason
+set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags)
+find_package(gflags REQUIRED)
+
+# Find `executorch` libraries, same as for gflags
+list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../..)
+find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH)
+executorch_target_link_options_shared_lib(executorch)
+
+set(LINK_LIBS executorch gflags)
+set(link_libraries ${LINK_LIBS})
+set(_srcs multimodal.cpp)
+
+list(
+  APPEND
+  link_libraries
+  optimized_native_cpu_ops_lib
+  quantized_ops_lib
+  custom_ops
+  cpublas
+  eigen_blas
+)
+executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib)
+executorch_target_link_options_shared_lib(quantized_ops_lib)
+executorch_target_link_options_shared_lib(custom_ops)
+
+# XNNPACK
+if(TARGET xnnpack_backend)
+  set(xnnpack_backend_libs xnnpack_backend XNNPACK xnnpack-microkernels-prod)
+  if(TARGET kleidiai)
+    list(APPEND xnnpack_backend_libs kleidiai)
+  endif()
+  list(APPEND link_libraries ${xnnpack_backend_libs})
+  executorch_target_link_options_shared_lib(xnnpack_backend)
+endif()
+
+# Add LLM runner and extension module
+if(NOT TARGET extension_llm_runner)
+  message(
+    FATAL_ERROR
+      "ExecuTorch must be installed with EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER enabled."
+  )
+endif()
+
+# Needed for cpuinfo where it uses android specific log lib
+if(ANDROID)
+  list(APPEND link_libraries log)
+endif()
+
+# Add the required ExecuTorch extensions for multimodal LLM runner
+list(
+  APPEND
+  link_libraries
+  extension_llm_runner
+  extension_module
+  extension_data_loader
+  extension_tensor
+  extension_flat_tensor
+)
+
+# Add tokenizers
+list(APPEND link_libraries tokenizers::tokenizers)
+
+add_executable(voxtral_runner ${_srcs})
+if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug")
+  target_link_options_gc_sections(voxtral_runner)
+  if(NOT APPLE)
+    target_link_options(voxtral_runner PRIVATE "LINKER:-s")
+  endif()
+endif()
+
+target_include_directories(voxtral_runner PUBLIC ${_common_include_directories})
+target_link_libraries(voxtral_runner PUBLIC ${link_libraries})
+target_compile_options(voxtral_runner PUBLIC ${_common_compile_options})
