Skip to content

Commit 38384a2

Browse files
authored
Add support for splitting in_features in linear layers (#8715)
init
1 parent 7ce47fc commit 38384a2

File tree

5 files changed

+190
-58
lines changed

5 files changed

+190
-58
lines changed

examples/apple/coreml/llama/export.py

Lines changed: 14 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2-
3-
# pyre-strict
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
46

57
import argparse
68

@@ -24,55 +26,7 @@
2426

2527
sys.path.insert(0, ".")
2628
from llama_transformer import InputManager, load_model
27-
28-
29-
class SplitLinearModule(torch.nn.Module):
30-
def __init__(self, in_features, out_features, target_split_size, max_splits):
31-
super(SplitLinearModule, self).__init__()
32-
num_splits = max(out_features // target_split_size, 1)
33-
if num_splits > max_splits:
34-
num_splits = max_splits
35-
36-
self.split_size = out_features // num_splits
37-
self.split_remainder = out_features % num_splits
38-
self.splits = torch.nn.ModuleList(
39-
[torch.nn.Linear(in_features, self.split_size) for _ in range(num_splits)]
40-
)
41-
print(
42-
f"Splitting out_features={out_features} into {num_splits} of size {self.split_size}"
43-
)
44-
if self.split_remainder > 0:
45-
print(
46-
f"Warning: remainder {self.split_remainder} after splitting out_features={out_features} into {num_splits} of size {self.split_size}"
47-
)
48-
self.splits.append(torch.nn.Linear(in_features, self.split_remainder))
49-
50-
def split_sizes(self):
51-
return [split.out_features for split in self.splits]
52-
53-
def forward(self, x):
54-
return torch.cat([split(x) for split in self.splits], dim=-1)
55-
56-
57-
def replace_linear_with_split_linear(model, target_split_size, max_splits):
58-
for name, module in model.named_children():
59-
if isinstance(module, torch.nn.Linear):
60-
new_module = SplitLinearModule(
61-
module.in_features, module.out_features, target_split_size, max_splits
62-
)
63-
split_sizes = new_module.split_sizes()
64-
if module.bias is not None:
65-
split_bias = module.bias.split(split_sizes)
66-
split_weights = module.weight.split(split_sizes, dim=0)
67-
for i, split in enumerate(new_module.splits):
68-
split.weight = torch.nn.Parameter(split_weights[i])
69-
if module.bias is not None:
70-
split.bias = torch.nn.Parameter(split_bias[i])
71-
else:
72-
split.bias = None
73-
setattr(model, name, new_module)
74-
else:
75-
replace_linear_with_split_linear(module, target_split_size, max_splits)
29+
from utils import replace_linear_with_split_linear
7630

7731

7832
def main() -> None:
@@ -175,7 +129,13 @@ def main() -> None:
175129

176130
if export_args.target_split_size is not None:
177131
replace_linear_with_split_linear(
178-
model, export_args.target_split_size, export_args.max_splits
132+
model,
133+
out_target_split_size=export_args.target_split_size,
134+
out_max_splits=export_args.max_splits,
135+
# I have not found splitting on in_features to be beneficial,
136+
# and it often leads to OOM so I set in_max_splits to 1
137+
in_target_split_size=1,
138+
in_max_splits=1,
179139
)
180140

181141
model.eval()
@@ -241,6 +201,7 @@ def main() -> None:
241201
ep,
242202
preserve_ops=[
243203
torch.ops.aten.scaled_dot_product_attention.default,
204+
# preserve norm op for numerical stability
244205
torch.ops.aten.linalg_vector_norm.default,
245206
],
246207
compile_config=EdgeCompileConfig(

examples/apple/coreml/llama/readme.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,9 @@ The runner can also be used to run an eager model model to compare with CoreML n
3838

3939
We are actively experimenting with different settings. But here are ones that we've found work well for Llama1B on iPhone 15 Pro:
4040

41-
* Set use_cache_list
42-
* Split linear layers with target_split_size=1024, max_splits=8
43-
* Use seq_length=32 or seq_length=64, both of which offer reasonable tradeoffs for prefill and decode performance. seq_length=32 is better at decode and seq_length=64 is better at prefill.
44-
45-
In our tests, we set max_seq_length=1024, but if your application allows for it, performance can improve with max_seq_length=512 or by keeping max_seq_length=1024 and setting cache_size=512-seq_length.
41+
* Set use_cache_list.
42+
* Use seq_length = 32, which offers a good balance between prefill/decode performance.
43+
* Split out_features in linear layers with target_split_size=1024, max_splits=8.
44+
* For ANE, set dtype = fp16, coreml-quantize = c4w. The requires doing QAT on Llama1B for good accuracy.
45+
* Set embedding-quantize to "4,32".
46+
* Set max_seq_length to 128, 256, 512, 1024, and 2048, depending on needed context. Note that performance drops with max_seq_length. More specifically, performance drops with cache_size, and the best experience may require a good cache eviction policy. The python runner in run.py uses a last-in-last-out policy when cache_size is specified.

examples/apple/coreml/llama/run.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
17
import argparse
28
import sys
39

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import sys
8+
9+
sys.path.insert(0, ".")
10+
import copy
11+
12+
import torch
13+
from utils import replace_linear_with_split_linear
14+
15+
16+
def get_split_model(
17+
model,
18+
out_target_split_size=1,
19+
out_max_splits=1,
20+
in_target_split_size=1,
21+
in_max_splits=1,
22+
):
23+
model_copy = copy.deepcopy(model)
24+
replace_linear_with_split_linear(
25+
model_copy,
26+
out_target_split_size,
27+
out_max_splits,
28+
in_target_split_size,
29+
in_max_splits,
30+
)
31+
return model_copy
32+
33+
34+
def test_split_model():
35+
inputs = torch.randn(10, 5, 1, 512)
36+
37+
model = torch.nn.Sequential(*[torch.nn.Linear(512, 1024, bias=False)])
38+
model1 = get_split_model(model, 64, 2, 64, 1000)
39+
model2 = get_split_model(model, 64, 2, 64, 1)
40+
model3 = get_split_model(model, 64, 1, 64, 1000)
41+
42+
assert torch.allclose(model(inputs), model1(inputs), atol=1e-5)
43+
assert torch.allclose(model(inputs), model2(inputs), atol=1e-5)
44+
assert torch.allclose(model(inputs), model3(inputs), atol=1e-5)
45+
46+
47+
if __name__ == "__main__":
48+
test_split_model()
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
9+
10+
class SplitLinearModule(torch.nn.Module):
11+
def __init__(
12+
self,
13+
in_features,
14+
out_features,
15+
out_target_split_size=1,
16+
out_max_splits=1,
17+
in_target_split_size=1,
18+
in_max_splits=1,
19+
):
20+
super(SplitLinearModule, self).__init__()
21+
self.out_split_sizes = self._get_split_sizes(
22+
out_features, out_target_split_size, out_max_splits
23+
)
24+
self.in_split_sizes = self._get_split_sizes(
25+
in_features, in_target_split_size, in_max_splits
26+
)
27+
print(
28+
f"Splitting out_features={out_features} into {len(self.out_split_sizes)} of size {self.out_split_sizes[0]}."
29+
)
30+
print(
31+
f"Splitting in_features={in_features} into {len(self.in_split_sizes)} of size {self.in_split_sizes[0]}."
32+
)
33+
34+
# self.ops contains a list of linear ops for different pieces of the output matrix
35+
# The index of an op at (in_idx, out_idx) is given by self.op_index(in_idx, out_idx)
36+
self.ops = torch.nn.ModuleList()
37+
for idx_out, s_out in enumerate(self.out_split_sizes):
38+
for idx_in, s_in in enumerate(self.in_split_sizes):
39+
assert len(self.ops) == self.op_index(idx_in, idx_out)
40+
self.ops.append(torch.nn.Linear(s_in, s_out, bias=False))
41+
42+
def op_index(self, in_index, out_index):
43+
idx = out_index * len(self.in_split_sizes) + in_index
44+
return idx
45+
46+
def _get_split_sizes(self, n_features, target_split_size, max_splits):
47+
num_splits = max(n_features // target_split_size, 1)
48+
if num_splits > max_splits:
49+
num_splits = max_splits
50+
51+
split_size = n_features // num_splits
52+
split_remainder = n_features % num_splits
53+
if split_remainder > 0:
54+
raise ValueError(
55+
f"Cannot split {n_features} with target_split_size={target_split_size} and max_splits={max_splits} because it leaves a remainder of {split_remainder}."
56+
)
57+
58+
ret = [split_size for _ in range(num_splits)]
59+
return ret
60+
61+
def set_params(self, weight):
62+
split_weights = []
63+
for w_out in weight.split(self.out_split_sizes, dim=0):
64+
for w in w_out.split(self.in_split_sizes, dim=1):
65+
split_weights.append(w)
66+
67+
for i, split in enumerate(self.ops):
68+
split.weight = torch.nn.Parameter(split_weights[i])
69+
70+
def forward(self, x):
71+
if len(self.in_split_sizes) == 1:
72+
out_chunks = [op(x) for op in self.ops]
73+
else:
74+
x_splits = x.split(self.in_split_sizes, dim=-1)
75+
out_chunks = [
76+
torch.sum(
77+
torch.stack(
78+
[
79+
self.ops[self.op_index(in_idx, out_idx)].forward(
80+
x_splits[in_idx]
81+
)
82+
for in_idx in range(len(self.in_split_sizes))
83+
],
84+
),
85+
dim=0,
86+
)
87+
for out_idx in range(len(self.out_split_sizes))
88+
]
89+
90+
return torch.concat(out_chunks, dim=-1)
91+
92+
93+
def replace_linear_with_split_linear(
94+
model, out_target_split_size, out_max_splits, in_target_split_size, in_max_splits=1
95+
):
96+
for name, module in model.named_children():
97+
if isinstance(module, torch.nn.Linear):
98+
assert module.bias is None, "SplitLinearModule does not support bias"
99+
new_module = SplitLinearModule(
100+
module.in_features,
101+
module.out_features,
102+
out_target_split_size,
103+
out_max_splits,
104+
in_target_split_size,
105+
in_max_splits,
106+
)
107+
new_module.set_params(module.weight)
108+
setattr(model, name, new_module)
109+
else:
110+
replace_linear_with_split_linear(
111+
module,
112+
out_target_split_size,
113+
out_max_splits,
114+
in_target_split_size,
115+
in_max_splits,
116+
)

0 commit comments

Comments
 (0)