Commit 944cccf

Author: ssjia
Update base for Update on "[ET-VK][ez] Ensure that attn_weight buffers do not exceed GPU buffer numel limit"
Title says it all! To give a concrete example, Llama3.2-1B-Instruct will have attn weights with size `{1, 32, max_seq_len, max_context_len}`. Usually `max_seq_len == max_context_len`, and if `max_context_len = 2048`, then the attention weight tensors will have sizes `{1, 32, 2048, 2048}`, which contain 134217728 elements. The `maxStorageBufferRange` for Adreno 750 is also 134217728 (2^27), so using a context length of 2048 will produce incorrect results on Adreno 750.

In practice, it is unlikely that the prompt sequence length will be equal to the context length, so the solution is to adjust down the `max_seq_len` dim of the attention weight tensors to ensure that the GPU buffer numel limit is not hit.

Differential Revision: [D86443407](https://our.internmc.facebook.com/intern/diff/D86443407/)

[ghstack-poisoned]
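
As an illustration, here is a minimal Python sketch of the arithmetic above and of one way to adjust the `max_seq_len` dim down until the attention weight tensor fits under the buffer numel limit. This is not the actual ExecuTorch implementation; the constant and function names are hypothetical.

```python
# Hypothetical sketch, not ExecuTorch code: demonstrates the numel arithmetic
# from the commit message and a simple way to shrink the seq-len dim so the
# attention weight tensor stays below the GPU buffer numel limit.

# Limit quoted in the commit message for Adreno 750 (maxStorageBufferRange).
MAX_BUFFER_NUMEL = 2**27  # 134217728


def attn_weight_numel(n_heads: int, seq_len: int, context_len: int) -> int:
    # Attention weights are sized {1, n_heads, max_seq_len, max_context_len}.
    return 1 * n_heads * seq_len * context_len


def adjusted_max_seq_len(n_heads: int, seq_len: int, context_len: int) -> int:
    # Shrink only the seq-len dim until the tensor no longer hits the limit.
    while seq_len > 1 and attn_weight_numel(n_heads, seq_len, context_len) >= MAX_BUFFER_NUMEL:
        seq_len -= 1
    return seq_len


# Llama3.2-1B-Instruct example: 1 * 32 * 2048 * 2048 == 134217728 == 2**27,
# which is exactly the Adreno 750 limit, so the seq-len dim must come down.
print(attn_weight_numel(32, 2048, 2048))     # 134217728
print(adjusted_max_seq_len(32, 2048, 2048))  # 2047 under these assumptions
```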
1 parent e33a210 commit 944cccf

File tree: 1 file changed (+0 additions, −12 deletions)


backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 0 additions & 12 deletions
@@ -11,20 +11,14 @@
 from typing import Tuple
 
 import executorch.backends.vulkan.test.utils as test_utils
-
 import torch
-
 from executorch.backends.transforms.convert_dtype_pass import I64toI32
-
 from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
-
 from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend
-
 from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
     get_symmetric_quantization_config,
     XNNPACKQuantizer,
 )
-
 from executorch.exir import (
     EdgeCompileConfig,
     EdgeProgramManager,
@@ -36,11 +30,8 @@
 )
 from executorch.extension.pytree import tree_flatten
 from torch.export import Dim, export, ExportedProgram
-
 from torchao.quantization.granularity import PerGroup
-
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
-
 from torchao.quantization.pt2e.quantizer import Quantizer
 from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
 from torchao.utils import unwrap_tensor_subclass
@@ -69,9 +60,6 @@ def lower_module(
     edge_program = to_edge_transform_and_lower(
         program,
         compile_config=edge_compile_config,
-        transform_passes=[
-            I64toI32(edge_compile_config._skip_dim_order),
-        ],
         partitioner=[VulkanPartitioner(compile_options)],
     )
