Skip to content

Commit 6693a02

Browse files
committed
Update on "add attention_sink.py"
This PR adds `KVCacheWithAttentionSink`, which is required for `AttentionSink`. It keeps the first `sink_size` tokens as attention sinks and maintains a sliding window of `window_size` for new tokens. Note: I am trying to implement and verify `AttentionSink` in eager mode first, so the current implementation may still have some correctness errors or performance issues. For example, it does not support the case when dynamic shape is disabled. I will leave these problems to be resolved when we are ready to deploy `AttentionSink` to edge. Differential Revision: [D65235798](https://our.internmc.facebook.com/intern/diff/D65235798/) [ghstack-poisoned]
2 parents dbbaa85 + 2c9df8e commit 6693a02

File tree

82 files changed

+2096
-948
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

82 files changed

+2096
-948
lines changed

CONTRIBUTING.md

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -283,10 +283,15 @@ for basics.
283283
- If the reviewers have requests or questions, follow up with them.
284284
- The goal of the reviewer is to ensure that the code in the `main` branch of
285285
the repo is consistent, maintainable, and of high quality.
286-
1. Once approved, your reviewer will import the PR into Meta's internal system
287-
and merge it from there.
288-
- If the PR is approved and not merged within a few business days, please
289-
comment on the PR to ask about its status.
286+
1. Once the PR has been approved,
287+
- If you have the "write permission" in this repo, you can merge it yourself
288+
by clicking the "Squash and merge" button once it is green and all CI
289+
signals are passing.
290+
- If you don't have "write permission" in this repo, the reviewer will take
291+
care of the PR. The reviewer may import the PR into Meta's internal system
292+
to validate it against internal CI.
293+
- If the PR is approved but not merged within 5 business days, please comment
294+
on the PR to ask about its status.
290295
- Note that if the `main` [CI](#continuous-integration) jobs are broken, we
291296
will only merge PRs that fix the broken jobs until all critical jobs are
292297
fixed.

backends/apple/coreml/runtime/delegate/ETCoreMLModelCompiler.mm

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,25 +26,38 @@ + (nullable NSURL *)compileModelAtURL:(NSURL *)modelURL
2626
#else
2727
__block NSError *localError = nil;
2828
__block NSURL *result = nil;
29-
30-
dispatch_semaphore_t sema = dispatch_semaphore_create(0);
31-
[MLModel compileModelAtURL:modelURL completionHandler:^(NSURL * _Nullable tempURL, NSError * _Nullable compilationError) {
32-
result = [tempURL copy];
33-
localError = compilationError;
34-
dispatch_semaphore_signal(sema);
35-
}];
36-
37-
long status = dispatch_semaphore_wait(sema, dispatch_time(DISPATCH_TIME_NOW, (int64_t)(maxWaitTimeInSeconds * NSEC_PER_SEC)));
38-
if (status != 0) {
29+
30+
if (@available(iOS 16, macOS 13, watchOS 9, tvOS 16, *)) {
31+
dispatch_semaphore_t sema = dispatch_semaphore_create(0);
32+
[MLModel compileModelAtURL:modelURL completionHandler:^(NSURL * _Nullable tempURL, NSError * _Nullable compilationError) {
33+
result = [tempURL copy];
34+
localError = compilationError;
35+
dispatch_semaphore_signal(sema);
36+
}];
37+
38+
long status = dispatch_semaphore_wait(sema, dispatch_time(DISPATCH_TIME_NOW, (int64_t)(maxWaitTimeInSeconds * NSEC_PER_SEC)));
39+
if (status != 0) {
40+
ETCoreMLLogErrorAndSetNSError(error,
41+
ETCoreMLErrorCompilationFailed,
42+
"%@: Failed to compile model in %f seconds.",
43+
NSStringFromClass(ETCoreMLModelCompiler.class),
44+
maxWaitTimeInSeconds);
45+
return nil;
46+
}
47+
} else {
48+
result = [MLModel compileModelAtURL:modelURL error:&localError];
49+
}
50+
51+
if (localError) {
3952
ETCoreMLLogErrorAndSetNSError(error,
40-
ETCoreMLErrorCompilationFailed,
41-
"%@: Failed to compile model in %f seconds.",
42-
NSStringFromClass(ETCoreMLModelCompiler.class),
43-
maxWaitTimeInSeconds);
53+
ETCoreMLErrorCompilationFailed,
54+
"%@: Failed to compile model, error: %@",
55+
NSStringFromClass(ETCoreMLModelCompiler.class),
56+
localError);
4457
return nil;
58+
} else {
59+
return result;
4560
}
46-
47-
return result;
4861
#endif
4962
}
5063

backends/apple/coreml/scripts/install_requirements.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ rm -rf "$COREML_DIR_PATH/third-party"
2424
mkdir "$COREML_DIR_PATH/third-party"
2525

2626
echo "${green}ExecuTorch: Cloning coremltools."
27-
git clone --depth 1 --branch 8.0 "https://github.com/apple/coremltools.git" $COREMLTOOLS_DIR_PATH
27+
git clone --depth 1 --branch 8.1 "https://github.com/apple/coremltools.git" $COREMLTOOLS_DIR_PATH
2828
cd $COREMLTOOLS_DIR_PATH
2929

3030
STATUS=$?

backends/apple/coreml/test/test_coreml_partitioner.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,23 +71,15 @@ def test_vit_skip_conv(self):
7171
)
7272
)
7373

74-
conv_block = ["aten.convolution.default", "executorch_call_delegate"]
75-
safe_softmax_block = [
76-
"getitem",
77-
"getitem",
78-
"getitem",
79-
"getitem",
80-
"aten.any.dim",
81-
"executorch_call_delegate",
82-
]
83-
final_block = ["getitem"]
84-
total = conv_block + 12 * safe_softmax_block + final_block
85-
8674
assert [
8775
node.target.__name__
8876
for node in delegated_program_manager.exported_program().graph.nodes
8977
if node.op == "call_function"
90-
] == total
78+
] == [
79+
"aten.convolution.default",
80+
"executorch_call_delegate",
81+
"getitem",
82+
]
9183

9284
def test_buffer(self):
9385
embedding_dim = 3

backends/arm/test/common.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class arm_test_options(Enum):
2929
corstone300 = auto()
3030
dump_path = auto()
3131
date_format = auto()
32+
fast_fvp = auto()
3233

3334

3435
_test_options: dict[arm_test_options, Any] = {}
@@ -41,6 +42,7 @@ def pytest_addoption(parser):
4142
parser.addoption("--arm_run_corstone300", action="store_true")
4243
parser.addoption("--default_dump_path", default=None)
4344
parser.addoption("--date_format", default="%d-%b-%H:%M:%S")
45+
parser.addoption("--fast_fvp", action="store_true")
4446

4547

4648
def pytest_configure(config):
@@ -63,6 +65,7 @@ def pytest_configure(config):
6365
f"Supplied argument 'default_dump_path={dump_path}' that does not exist or is not a directory."
6466
)
6567
_test_options[arm_test_options.date_format] = config.option.date_format
68+
_test_options[arm_test_options.fast_fvp] = config.option.fast_fvp
6669
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
6770

6871

backends/arm/test/runner_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import numpy as np
1818
import torch
1919

20+
from executorch.backends.arm.test.common import arm_test_options, is_option_enabled
21+
2022
from torch.export import ExportedProgram
2123
from torch.fx.node import Node
2224

@@ -249,6 +251,10 @@ def run_corstone(
249251
for input_path in input_paths:
250252
cmd_line += f" -i {input_path}"
251253

254+
ethos_u_extra_args = ""
255+
if is_option_enabled(arm_test_options.fast_fvp):
256+
ethos_u_extra_args = ethos_u_extra_args + "--fast"
257+
252258
command_args = {
253259
"corstone-300": [
254260
"FVP_Corstone_SSE-300_Ethos-U55",
@@ -267,6 +273,8 @@ def run_corstone(
267273
"-C",
268274
"cpu0.semihosting-stack_base=0",
269275
"-C",
276+
f"ethosu.extra_args='{ethos_u_extra_args}'",
277+
"-C",
270278
"cpu0.semihosting-heap_limit=0",
271279
"-C",
272280
f"cpu0.semihosting-cmd_line='{cmd_line}'",
@@ -282,6 +290,8 @@ def run_corstone(
282290
"-C",
283291
"mps4_board.visualisation.disable-visualisation=1",
284292
"-C",
293+
"vis_hdlcd.disable_visualisation=1",
294+
"-C",
285295
"mps4_board.telnetterminal0.start_telnet=0",
286296
"-C",
287297
"mps4_board.uart0.out_file='-'",
@@ -296,6 +306,8 @@ def run_corstone(
296306
"-C",
297307
"mps4_board.subsystem.cpu0.semihosting-heap_limit=0",
298308
"-C",
309+
f"mps4_board.subsystem.ethosu.extra_args='{ethos_u_extra_args}'",
310+
"-C",
299311
f"mps4_board.subsystem.cpu0.semihosting-cmd_line='{cmd_line}'",
300312
"-a",
301313
elf_path,

backends/cadence/aot/compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,12 +235,12 @@ def quantize_and_export_to_cadence(
235235
def export_to_executorch_gen_etrecord(
236236
model: torch.nn.Module,
237237
inputs: tuple[object, ...],
238-
dump_graphs: bool = False,
239238
output_dir: Optional[str] = None,
240239
opt_level: int = 1,
240+
dump_graphs: bool = False,
241241
) -> ExecutorchProgramManager:
242-
edge_prog_manager = export_to_edge(model, inputs)
243242
cadence_passes = get_cadence_passes(opt_level)
243+
edge_prog_manager = export_to_edge(model, inputs, dump_graphs)
244244

245245
# Run a couple required passes for quant/dequant ops
246246
cadence_prog_manager = edge_prog_manager.transform(

backends/cadence/aot/fuse_ops.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,9 @@ def fuse_quantized_batch_norm_with_conv(
426426
# Note: there is a quantized.conv2d.new operator in the resulting graph
427427
# that takes a torch.classes.quantized.Conv2dPackedParamsBase as one of the input
428428
# this prevents us to directly call graph_module.recompile().
429+
# pyre-fixme[16]: `GraphModule` has no attribute `_code`.
430+
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
431+
# `python_code`.
429432
graph_module._code = graph_module._graph.python_code(root_module="self").src
430433

431434
def __init__(self):

backends/cadence/aot/quantizer/patterns.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def partition_types(self) -> List[OpOverload]:
7575
def get_anchors(
7676
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
7777
) -> PartitionAnchors:
78+
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
7879
addmm_node = fused_partition[0].nodes[-1]
7980

8081
bias_qspec = DerivedQuantizationSpec(
@@ -107,6 +108,7 @@ def partition_types(self) -> List[OpOverload]:
107108
def get_anchors(
108109
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
109110
) -> PartitionAnchors:
111+
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
110112
bmm_node = fused_partition[0].nodes[-1]
111113

112114
return PartitionAnchors(
@@ -127,6 +129,7 @@ def partition_types(self) -> List[OpOverload]:
127129
def get_anchors(
128130
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
129131
) -> PartitionAnchors:
132+
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
130133
conv1d_node = fused_partition[0].nodes[-1]
131134

132135
bias_qspec = DerivedQuantizationSpec(
@@ -165,6 +168,7 @@ def partition_types(self) -> List[OpOverload]:
165168
def get_anchors(
166169
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
167170
) -> PartitionAnchors:
171+
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
168172
conv2d_node = fused_partition[0].nodes[-1]
169173

170174
bias_qspec = DerivedQuantizationSpec(
@@ -203,6 +207,7 @@ def partition_types(self) -> List[OpOverload]:
203207
def get_anchors(
204208
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
205209
) -> PartitionAnchors:
210+
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
206211
layer_norm_node = fused_partition[0].nodes[-1]
207212

208213
others = [(layer_norm_node, 1)]
@@ -237,6 +242,7 @@ def partition_types(self) -> List[OpOverload]:
237242
def get_anchors(
238243
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
239244
) -> PartitionAnchors:
245+
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
240246
linear_node = fused_partition[0].nodes[-1]
241247

242248
bias_qspec = DerivedQuantizationSpec(
@@ -275,6 +281,7 @@ def partition_types(self) -> List[OpOverload]:
275281
def get_anchors(
276282
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
277283
) -> PartitionAnchors:
284+
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
278285
matmul_node = fused_partition[0].nodes[-1]
279286

280287
return PartitionAnchors(
@@ -297,6 +304,7 @@ def partition_types(self) -> List[OpOverload]:
297304
def get_anchors(
298305
self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
299306
) -> PartitionAnchors:
307+
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
300308
relu_node = fused_partition[0].nodes[-1]
301309

302310
return PartitionAnchors(

backends/cadence/aot/remove_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,7 @@ class Subgraph:
561561
exir_ops.edge.aten.mul.Tensor,
562562
exir_ops.edge.aten.mean.dim,
563563
exir_ops.edge.aten.cat.default,
564+
exir_ops.edge.aten.hardtanh.default,
564565
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
565566
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
566567
}

0 commit comments

Comments (0)