Skip to content

Commit 0fc7383

Browse files
larryliu0820 authored and facebook-github-bot committed
Return None as dynamic shape when enable_dynamic_shape is False
Summary: In `LLMEdgeManager`, `self.enable_dynamic_shape = False` means the token dimension is static (always 1). Differential Revision: D72805966
1 parent 060cda3 commit 0fc7383

File tree

5 files changed

+141
-7
lines changed

5 files changed

+141
-7
lines changed

extension/llm/export/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ runtime.python_library(
2222
"//bento/...",
2323
"//bento_kernels/...",
2424
"//executorch/examples/...",
25+
"//executorch/extension/llm/...",
2526
"//meta_intern_odllm/...",
2627
],
2728
deps = [

extension/llm/export/builder.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,13 @@ def _get_dynamic_shape(self) -> Any:
178178
return self.dynamic_shapes
179179

180180
dim = torch.export.Dim("token_dim", max=self.max_seq_len - 1)
181-
182-
if not self.use_kv_cache:
183-
# Only one input argument: tokens
184-
self.dynamic_shapes = ({1: dim},)
185-
elif self.enable_dynamic_shape:
186-
# Two input arguments: tokens and input_pos but input_pos is static shape
187-
self.dynamic_shapes = ({1: dim}, {"input_pos": {0: 1}})
181+
if self.enable_dynamic_shape:
182+
if not self.use_kv_cache:
183+
# Only one input argument: tokens
184+
self.dynamic_shapes = ({1: dim},)
185+
else:
186+
# Two input arguments: tokens and input_pos but input_pos is static shape
187+
self.dynamic_shapes = ({1: dim}, {"input_pos": {0: 1}})
188188
else:
189189
# Two input arguments: tokens and input_pos but both are of static shape
190190
self.dynamic_shapes = None

extension/llm/export/test/TARGETS

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
8+
9+
oncall("executorch")
10+
11+
runtime.python_test(
12+
name = "test_builder",
13+
srcs = ["test_builder.py"],
14+
deps = [
15+
"//executorch/extension/llm/export:export_lib",
16+
"//caffe2:torch",
17+
],
18+
)

extension/llm/export/test/__init__.py

Whitespace-only changes.
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
from unittest.mock import MagicMock, patch
9+
import torch
10+
11+
from executorch.extension.llm.export.builder import LLMEdgeManager, DType
12+
13+
14+
class TestLLMEdgeManager(unittest.TestCase):
15+
def setUp(self):
16+
# Create a mock model
17+
self.mock_model = MagicMock()
18+
self.modelname = "test_model"
19+
self.max_seq_len = 2048
20+
self.dtype = DType.fp32
21+
self.example_inputs = (torch.zeros((1, 10), dtype=torch.long),)
22+
self.example_kwarg_inputs = {"input_pos": torch.tensor([0])}
23+
24+
def test_get_dynamic_shape_with_preset_dynamic_shapes(self):
25+
"""Test that _get_dynamic_shape returns preset dynamic_shapes if available."""
26+
# Create a manager with preset dynamic_shapes
27+
preset_dynamic_shapes = {"preset": "shapes"}
28+
manager = LLMEdgeManager(
29+
model=self.mock_model,
30+
modelname=self.modelname,
31+
max_seq_len=self.max_seq_len,
32+
dtype=self.dtype,
33+
use_kv_cache=False,
34+
example_inputs=self.example_inputs,
35+
dynamic_shapes=preset_dynamic_shapes,
36+
)
37+
38+
# Call _get_dynamic_shape and verify it returns the preset value
39+
result = manager._get_dynamic_shape()
40+
self.assertEqual(result, preset_dynamic_shapes)
41+
42+
def test_get_dynamic_shape_with_dynamic_shape_enabled_no_kv_cache(self):
43+
"""Test _get_dynamic_shape when enable_dynamic_shape=True and use_kv_cache=False."""
44+
# Create a manager with enable_dynamic_shape=True and use_kv_cache=False
45+
manager = LLMEdgeManager(
46+
model=self.mock_model,
47+
modelname=self.modelname,
48+
max_seq_len=self.max_seq_len,
49+
dtype=self.dtype,
50+
use_kv_cache=False,
51+
example_inputs=self.example_inputs,
52+
enable_dynamic_shape=True,
53+
)
54+
55+
# Call _get_dynamic_shape
56+
result = manager._get_dynamic_shape()
57+
58+
# Verify the result has the expected structure
59+
self.assertIsInstance(result, tuple)
60+
self.assertEqual(len(result), 1)
61+
self.assertIsInstance(result[0], dict)
62+
self.assertIn(1, result[0])
63+
# Check that the value at key 1 is a torch.export.Dim with the correct max value
64+
self.assertEqual(result[0][1].max, self.max_seq_len - 1)
65+
66+
def test_get_dynamic_shape_with_dynamic_shape_enabled_with_kv_cache(self):
67+
"""Test _get_dynamic_shape when enable_dynamic_shape=True and use_kv_cache=True."""
68+
# Create a manager with enable_dynamic_shape=True and use_kv_cache=True
69+
manager = LLMEdgeManager(
70+
model=self.mock_model,
71+
modelname=self.modelname,
72+
max_seq_len=self.max_seq_len,
73+
dtype=self.dtype,
74+
use_kv_cache=True,
75+
example_inputs=self.example_inputs,
76+
enable_dynamic_shape=True,
77+
)
78+
79+
# Call _get_dynamic_shape
80+
result = manager._get_dynamic_shape()
81+
82+
# Verify the result has the expected structure
83+
self.assertIsInstance(result, tuple)
84+
self.assertEqual(len(result), 2)
85+
86+
# Check first element (tokens dimension)
87+
self.assertIsInstance(result[0], dict)
88+
self.assertIn(1, result[0])
89+
self.assertEqual(result[0][1].max, self.max_seq_len - 1)
90+
91+
# Check second element (input_pos dimension)
92+
self.assertIsInstance(result[1], dict)
93+
self.assertIn("input_pos", result[1])
94+
self.assertIsInstance(result[1]["input_pos"], dict)
95+
self.assertIn(0, result[1]["input_pos"])
96+
self.assertEqual(result[1]["input_pos"][0], 1)
97+
98+
def test_get_dynamic_shape_with_dynamic_shape_disabled(self):
99+
"""Test _get_dynamic_shape when enable_dynamic_shape=False."""
100+
# Create a manager with enable_dynamic_shape=False
101+
manager = LLMEdgeManager(
102+
model=self.mock_model,
103+
modelname=self.modelname,
104+
max_seq_len=self.max_seq_len,
105+
dtype=self.dtype,
106+
use_kv_cache=True, # Doesn't matter for this test
107+
example_inputs=self.example_inputs,
108+
enable_dynamic_shape=False,
109+
)
110+
111+
# Call _get_dynamic_shape
112+
result = manager._get_dynamic_shape()
113+
114+
# Verify the result is None
115+
self.assertIsNone(result)

0 commit comments

Comments (0)