Commit 5c9dbfb

Update on "Fix Cuda out of memory issue for eager runner"
This PR updates the eager runner to disable gradient tracking, which reduces memory usage. It also updates the prompt format so that it no longer includes the BOS token. Differential Revision: [D65962743](https://our.internmc.facebook.com/intern/diff/D65962743/) [ghstack-poisoned]
2 parents ebcd34b + 3debc5c commit 5c9dbfb

23 files changed: +975 -608 lines changed
.ci/scripts/gather_test_models.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@
     "resnet50": "linux.12xlarge",
     "llava": "linux.12xlarge",
     "llama3_2_vision_encoder": "linux.12xlarge",
-    "llama3_2_text_decoder": "linux.12xlarge",
+    # "llama3_2_text_decoder": "linux.12xlarge",  # TODO: re-enable test when Huy's change is in / model gets smaller.
     # This one causes timeout on smaller runner, the root cause is unclear (T161064121)
     "dl3": "linux.12xlarge",
     "emformer_join": "linux.12xlarge",

backends/arm/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ python_library(
     typing = True,
     deps = [
         ":arm_backend",
+        "//executorch/backends/arm/operator_support:operator_support",
         "//executorch/backends/arm/_passes:passes",
         "//executorch/exir:lib",
     ],
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+python_library(
+    name = "operator_support",
+    srcs = glob(["*.py"]),
+    typing = True,
+    deps = [
+        "//executorch/backends/xnnpack/_passes:xnnpack_passes",
+        "//executorch/exir:lib",
+        "//executorch/backends/arm:tosa_specification"
+    ],
+)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 4 additions & 1 deletion
@@ -6,6 +6,7 @@
 # pyre-unsafe
 
 import operator
+from typing import Type
 
 import torch.fx as fx
 from executorch.backends.arm.tosa_specification import TosaSpecification
@@ -31,7 +32,9 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool
 
 
 # container for all SupportedTosaOperatorCheck classes
-_tosa_spec_dicts: dict[TosaSpecification, dict[str, SupportedTOSAOperatorCheck]] = {
+_tosa_spec_dicts: dict[
+    TosaSpecification, dict[str, Type[SupportedTOSAOperatorCheck]]
+] = {
     TosaSpecification.create_from_string("TOSA-0.80.0+BI"): {},
     TosaSpecification.create_from_string("TOSA-0.80.0+MI"): {},
 }
