Commit 0392ea4

author
ssjia
committed
Update on "[ET-VK] Allocate memory for weight and activation tensors lazily"
Summary:

* Allocate memory for weight tensors right before the prepacking shader is dispatched, rather than while building the graph
* Move allocation of shared objects (i.e. memory for intermediate tensors) to occur after prepacking

## Motivation

Prevent screen blackout (Llama 3.2 1B) / device crash (Llama 3.2 3B) when running Llama 3.2 models on Samsung Galaxy S24. This behaviour is related to high peak memory usage when loading the model.

## Full Context

During model loading, the Vulkan delegate needs to store 3 copies of constant data in memory at various points:

* source data obtained from loading the model
* staging buffer
* GPU texture/buffer

The general rationale of this change is to allocate memory for each copy only when necessary, minimizing the "overlap" when all 3 exist at once.

### Current order of operations

Legend:

* `W` represents total weight nbytes
* `w` represents weight nbytes for one tensor
* `A` represents total activations nbytes
* `M` represents approximation of total memory footprint

First, the model file is loaded.

Then, when building the compute graph, for each weight tensor:

1. Weight data is loaded from NamedDataMap (`M = W`)
2. GPU texture/buffer for weight is initialized + memory allocated (`M = 2W`)
3. After building the graph, `graph->prepare()` is called, which currently allocates memory for the activation tensors as well (`M = 2W + A`)

Then, during the prepacking stage, each weight tensor is copied individually:

1. Staging buffer initialized (`M = 2W + A + w`)
2. Copy CPU weight data to staging + CPU weight data is freed (`M = 2W + A`)
3. Compute shader dispatch to copy staging to GPU texture/buffer + free staging buffer (`M = 2W + A - w`)

The peak usage in mainline will be `M = 2W + A + w`.

### Revised order of operations

This change revises the order of operations:

1. Weight data is loaded from NamedDataMap (`M = W`)
2. GPU texture/buffer for weight is initialized, but **memory is not allocated** (`M = W`)

Then, during the prepacking stage, each weight tensor is copied individually:

1. Staging buffer initialized (`M = W + w`)
2. **Memory allocated for GPU texture/buffer** (`M = W + 2w`)
3. Copy CPU weight data to staging + CPU weight data is freed (`M = W + w`)
4. Compute shader dispatch to copy staging to GPU texture/buffer + free staging buffer (`M = W`)

**Only after all prepacking operations complete is activation memory allocated** (`M = W + A`)

Under this scheme, peak memory is reduced to `M = W + A` (or alternatively `M = W + 2w` if `2w > A`), which is at, or at least very close to, the theoretical minimum.

Test Plan:

## Logging Memory Usage

Using

```
#include <cstdint>
#include <fstream>
#include <string>

// Resident set size of this process in KB, read from /proc/self/status.
uint64_t getVmRssInKB() {
  std::ifstream statusFile("/proc/self/status");
  std::string l, num;
  while (std::getline(statusFile, l)) {
    if (l.substr(0, 5) == "VmRSS") {
      size_t pos = l.find_first_of("0123456789");
      num = l.substr(pos);
      break;
    }
  }
  // stoull parses the leading digits and ignores the trailing " kB".
  uint64_t vmRssInKB = std::stoull(num);
  return vmRssInKB;
}

// Total bytes held in VMA memory blocks, converted to KB.
uint64_t getVmaStatsInKB() {
  auto stats =
      vkcompute::api::context()->adapter_ptr()->vma().get_memory_statistics();
  uint64_t vmaBlockInKB = stats.total.statistics.blockBytes >> 10;
  return vmaBlockInKB;
}
```

to log the memory footprint at various points of inference when running the llama_runner binary with Llama 3.2 1B, we can compare the memory footprint with and without these changes.

With changes: P1908051860 (Meta only)

```
Memory usage before model compilation: 1115760 KB (VmRSS), 0 KB (VMA)
Memory usage after graph building: 1924832 KB (VmRSS), 17920 KB (VMA)
Memory usage after graph preparation: 1935312 KB (VmRSS), 17920 KB (VMA)
Memory usage prepack start: 1935312 KB, VMA Block: 17920 KB
Memory usage after prepack operations: 1372376 KB (VmRSS), 2330528 KB (VMA)
Memory usage before execute: 1372804 KB (VmRSS), 2330528 KB (VMA)
Memory usage at end of execute: 1376916 KB (VmRSS), 2330528 KB (VMA)
```

Without changes: P1908054759 (Meta only)

```
Memory usage before model compilation: 1114784 KB (VmRSS), 0 KB (VMA)
Memory usage after graph building: 1924432 KB (VmRSS), 962464 KB (VMA)
Memory usage after graph preparation: 1922916 KB (VmRSS), 2326432 KB (VMA)
Memory usage prepack start: 1922916 KB, VMA Block: 2326432 KB
Memory usage after prepack operations: 1359180 KB (VmRSS), 2330528 KB (VMA)
Memory usage before execute: 1359492 KB (VmRSS), 2330528 KB (VMA)
Memory usage at end of execute: 1363636 KB (VmRSS), 2330528 KB (VMA)
```

The logs show how these changes reduce peak memory. With the changes, the VMA footprint stays at 17920 KB until prepacking starts and only grows during prepacking, while VmRSS shrinks as CPU weight data is freed. Without the changes, the VMA footprint is already at 2326432 KB, nearly its final value, after graph preparation, while VmRSS is still at its highest.

Visually, it can also be verified that Samsung Galaxy S24's screen no longer blacks out while loading the model.

Differential Revision: [D80460033](https://our.internmc.facebook.com/intern/diff/D80460033)

[ghstack-poisoned]
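The memory accounting in the commit message can be checked with a small simulation. This is only a sketch of the bookkeeping described above, not code from the Vulkan delegate: the function names `peak_current` and `peak_revised` and the unitless tensor sizes are assumptions made for illustration.

```python
def peak_current(weights, activations):
    """Peak footprint when GPU and activation memory are allocated up front."""
    total = sum(weights)
    m = total          # source weight data loaded (M = W)
    m += total         # GPU memory for every weight allocated (M = 2W)
    m += activations   # graph->prepare() allocates activations (M = 2W + A)
    peak = m
    for w in weights:
        m += w         # staging buffer initialized
        peak = max(peak, m)
        m -= w         # CPU weight data freed after copy to staging
        m -= w         # staging buffer freed after the shader dispatch
    return peak


def peak_revised(weights, activations):
    """Peak footprint when GPU and activation memory are allocated lazily."""
    m = sum(weights)   # source weight data loaded, no GPU memory yet (M = W)
    peak = m
    for w in weights:
        m += w         # staging buffer initialized
        m += w         # GPU memory for this tensor allocated
        peak = max(peak, m)
        m -= w         # CPU weight data freed after copy to staging
        m -= w         # staging buffer freed after the shader dispatch
    m += activations   # activations allocated only after prepacking
    return max(peak, m)


weights = [4, 2, 2]    # hypothetical per-tensor sizes, W = 8
print(peak_current(weights, 3), peak_revised(weights, 3))
```

With these made-up sizes the current order peaks at `2W + A + w` for the first tensor, while the revised order peaks at `W + 2w` because `2w > A`. With many small tensors, the revised peak becomes `W + A` instead.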
2 parents eaa165b + cf39517 commit 0392ea4

45 files changed, with 727 additions and 1435 deletions

backends/arm/_passes/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -33,7 +33,6 @@
 from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass # noqa
 from .decompose_cosh_pass import DecomposeCoshPass # noqa
 from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
-from .decompose_cumsum_pass import DecomposeCumsumPass # noqa
 from .decompose_div_pass import DecomposeDivPass # noqa
 from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa
 from .decompose_expm1_pass import DecomposeExpm1Pass # noqa

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 27 additions & 3 deletions
@@ -14,12 +14,36 @@
 from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
+from torch.library import impl, Library
+
+# Define lib with passthrough operators. The operators have no real meaning in edge IR
+# except for argument validaiton and a passthrough output. The operators will be used
+# when lowering to TOSA, e.g. a passthrough_to_tosa._transpose will not affect
+# the edge IR graph but will be lowered to a TOSA-TRANSPOSE.
+lib = Library("passthrough_to_tosa", "DEF")
+# For certain operators we need the data in a specific data format. Changing tosa_dim_order
+# is not sufficient as we also need transpose the data.
+# By utilizing an edge IR passthrough operator we can keep the edge program in
+# channels-first/contiguous and get the desired behavior in the TOSA lowering.
+lib.define("_transpose(Tensor self, int[] dim_order) -> Tensor")
+
+
+@impl(lib, "_transpose")
+def _transpose_impl(*args, **kwargs):
+    # Validate length of dim_order array
+    dim = args[1]
+    if len(dim) != 4 and len(dim) != 5:
+        raise ValueError(
+            f"Dim order length must be either 4 or 5, got {len(dim)}: {dim}"
+        )
+    # Pass-through in edge-IR
+    return args[0]


 class AnnotateChannelsLastDimOrder(ExportPass):
     """
     Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
-    that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts backend.tosa.TRANSPOSE
+    that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts passthrough_to_tosa._transpose
     when a transition between 3D and 4D/5D tensors happen.
     The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
     """
@@ -95,7 +119,7 @@ def insert_input_transpose(node, input_node, graph_module):
         with graph_module.graph.inserting_before(node):
             permute_node = create_node(
                 graph_module.graph,
-                exir_ops.backend.tosa.TRANSPOSE.default,
+                torch.ops.passthrough_to_tosa._transpose.default,
                 args=(
                     input_node,
                     list(
@@ -117,7 +141,7 @@ def insert_output_transpose(node, graph_module):
         with graph_module.graph.inserting_after(node):
             permute_node = create_node(
                 graph_module.graph,
-                exir_ops.backend.tosa.TRANSPOSE.default,
+                torch.ops.passthrough_to_tosa._transpose.default,
                 args=(
                     node,
                     list(
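The docstring in this file says the annotated `tosa_dim_order` is used to permute a node's shape into a TOSA-compliant one. A minimal pure-Python sketch of that permutation follows. The helper name `permute_shape` is hypothetical and is not part of the Arm backend; it only illustrates what applying a dim order such as `(0, 2, 3, 1)` to a channels-first shape means.

```python
def permute_shape(shape, dim_order):
    """Reorder `shape` so that output axis i is input axis dim_order[i]."""
    if sorted(dim_order) != list(range(len(shape))):
        raise ValueError(
            f"dim_order {dim_order} is not a permutation of the axes of {shape}"
        )
    return tuple(shape[d] for d in dim_order)


# Hypothetical NCHW activation: batch 1, 3 channels, 224 x 224 pixels.
nchw = (1, 3, 224, 224)
nhwc = permute_shape(nchw, (0, 2, 3, 1))
print(nhwc)  # (1, 224, 224, 3)
```

Permuting only the shape metadata does not move any data, which is why the comments in the diff note that changing `tosa_dim_order` alone is not sufficient where a real transpose is required.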

backends/arm/_passes/arm_pass_manager.py

Lines changed: 0 additions & 3 deletions
@@ -38,7 +38,6 @@
     DecomposeBatchNormNoStatsPass,
     DecomposeCoshPass,
     DecomposeCosineSimilarityPass,
-    DecomposeCumsumPass,
     DecomposeDivPass,
     DecomposeEmbeddingPass,
     DecomposeExpm1Pass,
@@ -149,7 +148,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
         self.add_pass(DecomposeSumPass())
-        self.add_pass(DecomposeCumsumPass(exported_program))
         self.add_pass(Conv1dUnsqueezePass())
         self.add_pass(DecomposeMaxPool2DPass())
         self.add_pass(SizeAdjustInputPass())
@@ -229,7 +227,6 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
         self.add_pass(DecomposeSumPass())
-        self.add_pass(DecomposeCumsumPass(exported_program))
         self.add_pass(Conv1dUnsqueezePass())
         self.add_pass(DecomposeMaxPool2DPass())
         self.add_pass(SizeAdjustInputPass())

backends/arm/_passes/decompose_cumsum_pass.py

Lines changed: 0 additions & 142 deletions
This file was deleted.

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ def call(self, graph_module):
         for node in graph_module.graph.nodes:
             if node.op != "call_function":
                 continue
-            if node.target == exir_ops.backend.tosa.TABLE.default:
+            if node.target == torch.ops.tosa._table.default:
                 continue

             input_nodes = node.all_input_nodes

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 50 additions & 5 deletions
@@ -3,25 +3,70 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import logging
 from copy import copy
 from typing import cast

+import torch
 from executorch.backends.arm._passes.arm_pass_utils import create_node
 from executorch.backends.arm._passes.quant_args import QuantArgs
 from executorch.backends.arm.constants import DQ_OPS, Q_OPS
-from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
+from torch import Tensor
 from torch.fx import GraphModule, Node
+from torch.library import custom_op, register_fake
+
+logger = logging.getLogger(__name__)
+
+
+@custom_op("tosa::_rescale", mutates_args=()) # type: ignore[misc]
+def rescale(
+    x: Tensor, dtype: torch.dtype, scale: float, in_zp: int, out_zp: int
+) -> Tensor:
+    logger.warning(
+        "Ran default implementation of tosa::_rescale."
+        "This op is meant to always be inserted inside a partition and a correct default implementation is not implemented."
+    )
+    # Clone is needed to not return reference when rescaling to same dtype.
+    # This is a neccessary requirement for non-mutating custom ops.
+    return x.to(dtype=dtype).clone()
+
+
+@register_fake("tosa::_rescale") # type: ignore[misc]
+def rescale_fake(
+    x: Tensor, dtype: torch.dtype, scale: float, in_zp: int, out_zp: int
+) -> Tensor:
+    """Casts the input tensor to dtype `dtype` to produce the correct tensor meta for a _rescale op.
+    Additionally validates TOSA constraints of a RESCALE op.
+    """
+    if dtype not in (torch.int32, torch.int8, torch.int16):
+        raise NotImplementedError(
+            f"tosa::rescale currently only supports int32, int16 and int8, not {dtype}"
+        )
+    if dtype in (torch.int32, torch.int16) and out_zp != 0:
+        raise ValueError(
+            f"TOSA requires output_zp to be zero when the output dtype is {dtype}."
+        )
+    if x.dtype in (torch.int32, torch.int16) and in_zp != 0:
+        raise ValueError(
+            f"TOSA requires input_zp to be zero when the input dtype is {dtype}"
+        )
+    if x.dtype == torch.int8 and not -128 <= in_zp <= 127:
+        raise ValueError(f"{in_zp=} outside valid range (-128,127) for int8.")
+    if dtype == torch.int8 and not -128 <= out_zp <= 127:
+        raise ValueError(f"{out_zp=} outside valid range (-128,127) for int8.")
+
+    return x.to(dtype=dtype).clone()


 class InsertRescalePass(ExportPass):
     """Finds patterns of dq -> q, and replaces them
-    with backend dialect tosa::RESCALE op.
+    with passthrough_to_tosa::rescales.

-    Does not guarantee that the dtypes and zero points are valid
+    Does not garantuee that the dtypes and zero points are valid
     in TOSA, that is the job of the quantization annotator that
     produced the dq and q nodes. The TOSA constraints are validated
-    in the fake implementation of.
+    in the fake implementation of passthrough_to_tosa:rescale.
     """

     def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule):
@@ -32,7 +77,7 @@ def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule
         with graph_module.graph.inserting_before(node):
             rescale_node = create_node(
                 graph_module.graph,
-                exir_ops.backend.tosa.RESCALE.default,
+                torch.ops.tosa._rescale.default,
                 (
                     node.all_input_nodes[0],
                     q_args.dtype,
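The zero-point rules that `rescale_fake` enforces in this diff can be restated without a torch dependency. The sketch below is an illustration under assumptions: dtypes are plain strings rather than `torch.dtype` values, and `check_rescale_args` is a hypothetical name that does not exist in the backend.

```python
ZERO_ZP_ONLY = ("int32", "int16")   # dtypes whose zero point must be 0
SUPPORTED_OUT = ("int32", "int16", "int8")


def check_rescale_args(in_dtype, out_dtype, in_zp, out_zp):
    """Raise if the arguments break the constraints checked in the diff."""
    if out_dtype not in SUPPORTED_OUT:
        raise NotImplementedError(f"unsupported output dtype {out_dtype}")
    if out_dtype in ZERO_ZP_ONLY and out_zp != 0:
        raise ValueError(f"output_zp must be zero for output dtype {out_dtype}")
    if in_dtype in ZERO_ZP_ONLY and in_zp != 0:
        raise ValueError(f"input_zp must be zero for input dtype {in_dtype}")
    if in_dtype == "int8" and not -128 <= in_zp <= 127:
        raise ValueError(f"{in_zp=} outside valid range (-128,127) for int8.")
    if out_dtype == "int8" and not -128 <= out_zp <= 127:
        raise ValueError(f"{out_zp=} outside valid range (-128,127) for int8.")


# An int8 -> int32 rescale may carry an input zero point but not an output one.
check_rescale_args("int8", "int32", in_zp=-3, out_zp=0)
```

Calling it with, for example, `in_dtype="int8"`, `out_dtype="int32"` and `out_zp=5` raises `ValueError`, mirroring the second check in `rescale_fake`.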
