Commit 30b31ac

optimize scale_factor and feedforward, and change tokens to int32 to improve the embedding op

1 parent 37c471d commit 30b31ac

7 files changed: +62 -14 lines

backends/qualcomm/_passes/i64_to_i32.py (1 addition, 6 deletions)

@@ -61,12 +61,7 @@ def _cast_to_int32(self, graph_module: torch.fx.GraphModule):
                 to_dst_node.meta["val"] = node_val.to(torch.int32)

                 # Replace usage of the src dtype result with the dst dtype result.
-                if n.name != "tokens":
-                    n.replace_all_uses_with(to_dst_node)
-                else:
-                    for user in n.users.copy():
-                        if user.name != "quantized_decomposed_embedding_4bit_dtype":
-                            user.replace_input_with(n, to_dst_node)
+                n.replace_all_uses_with(to_dst_node)
                 to_dst_node.args = (n,)

     def call(self, graph_module: torch.fx.GraphModule):
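The deleted branch special-cased the "tokens" placeholder so the int64 input could still feed quantized_decomposed_embedding_4bit_dtype directly. With tokens now created as int32 at the source (see examples/models/llama/model.py below), every user can be rewired uniformly. A minimal sketch of the cast-and-rewire pattern on a torch.fx graph; the function name and the placeholder filter are illustrative, not taken from this commit:

import torch
import torch.fx as fx

def cast_i64_placeholders_to_i32(gm: fx.GraphModule) -> fx.GraphModule:
    for n in gm.graph.nodes:
        if n.op != "placeholder":
            continue
        val = n.meta.get("val")
        if val is None or val.dtype != torch.int64:
            continue
        # Insert an int32 cast right after the int64 placeholder.
        with gm.graph.inserting_after(n):
            to_i32 = gm.graph.call_function(torch.ops.aten.to.dtype, (n, torch.int32))
        to_i32.meta["val"] = val.to(torch.int32)
        # Rewire every consumer to the int32 result. This also rewrites the cast's
        # own input, so restore it afterwards (the diff's `to_dst_node.args = (n,)`
        # serves the same purpose).
        n.replace_all_uses_with(to_i32)
        to_i32.args = (n, torch.int32)
    gm.recompile()
    return gm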

backends/qualcomm/_passes/layout_transform.py (1 addition, 0 deletions)

@@ -62,6 +62,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.prelu.default,
         exir_ops.edge.aten.relu.default,
         exir_ops.edge.aten._softmax.default,  # TODO: Need to find a new solution to do "axis_order" to transform axis.
+        exir_ops.edge.aten.sigmoid.default,
         exir_ops.edge.aten.sqrt.default,
         exir_ops.edge.aten.sub.Tensor,
         exir_ops.edge.aten.sum.dim_IntList,
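Sigmoid is elementwise, so it commutes with the axis-order permutes this pass manages, which is why it can join this layout-agnostic list. A quick sanity check of that property (illustrative, not from the commit):

import torch

x = torch.randn(2, 3, 4, 5)           # NCHW
nhwc = x.permute(0, 2, 3, 1)          # same data, NHWC axis order
assert torch.allclose(torch.sigmoid(x).permute(0, 2, 3, 1), torch.sigmoid(nhwc))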

backends/qualcomm/utils/utils.py (2 additions, 2 deletions)

@@ -166,8 +166,8 @@ def __init__(self, weight, bias=None):
         super().__init__()
         use_bias = bias is not None
         self.conv = torch.nn.Conv2d(
-            in_channels=weight.shape[0],
-            out_channels=weight.shape[1],
+            in_channels=weight.shape[1],
+            out_channels=weight.shape[0],
             kernel_size=1,
             padding=0,
             bias=use_bias,
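This fixes a swapped channel mapping: torch.nn.Linear.weight has shape (out_features, in_features), so a 1x1 Conv2d built from it must take in_channels from weight.shape[1] and out_channels from weight.shape[0]. The old code only worked when the two dimensions happened to be equal. A small equivalence check of the Linear-as-1x1-conv trick (toy sizes, my own sketch):

import torch

lin = torch.nn.Linear(8, 16, bias=False)   # weight: (16, 8) = (out, in)
conv = torch.nn.Conv2d(in_channels=8, out_channels=16, kernel_size=1, bias=False)
conv.weight = torch.nn.Parameter(lin.weight.reshape(16, 8, 1, 1))

x = torch.randn(4, 8)                      # (batch, in_features)
y_lin = lin(x)
y_conv = conv(x.reshape(4, 8, 1, 1)).reshape(4, 16)
assert torch.allclose(y_lin, y_conv, atol=1e-6)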

examples/models/llama/export_llama_lib.py (3 additions, 0 deletions)

@@ -66,6 +66,7 @@
     replace_causal_mask,
     replace_kv_cache_with_coreml_kv_cache,
     replace_kv_cache_with_simple_kv_cache,
+    replace_feedforward_to_conv2d,
     replace_sdpa_with_coreml_sdpa,
     replace_sdpa_with_custom_op,
     replace_sdpa_with_flex_sdpa,

@@ -961,6 +962,7 @@ def _get_source_transforms(  # noqa
             transforms.append(replace_attention_to_attention_sha)
             transforms.append(replace_causal_mask)
             transforms.append(replace_rms_norm_with_native_rms_norm)
+            transforms.append(replace_feedforward_to_conv2d)
             transforms.append(convert_linear_to_conv2d)
         else:
             transforms.append(replace_kv_cache_with_simple_kv_cache)

@@ -972,6 +974,7 @@ def _get_source_transforms(  # noqa
                 transforms.append(
                     get_model_with_r1_r2(args.optimized_rotation_path)
                 )
+            transforms.append(replace_feedforward_to_conv2d)
             transforms.append(convert_linear_to_conv2d)

     elif args.mps:
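In both branches replace_feedforward_to_conv2d is appended before convert_linear_to_conv2d, and the order matters: once the Linears inside a FeedForward have been rewritten as Conv2d wrappers, the feedforward fusion can no longer read their (out_features, in_features) weights. A sketch of how the returned transforms compose, assuming (as the list-building code here suggests) they are applied in order; `model` is a hypothetical nn.Module:

transforms = [
    replace_feedforward_to_conv2d,  # fuse FeedForward blocks while their Linears are intact
    convert_linear_to_conv2d,       # then convert whatever Linears remain
]
for transform in transforms:
    model = transform(model)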

examples/models/llama/model.py (1 addition, 1 deletion)

@@ -245,7 +245,7 @@ def get_example_inputs_kvcache_sdpa(self):
         else:
             return (
                 torch.tensor(
-                    [[1]], dtype=torch.long
+                    [[1]], dtype=torch.int32
                 ),  # tokens, with kv cache our input token length is always just 1 token.
                 torch.tensor(
                     [0], dtype=torch.long

examples/models/llama/source_transformation/sdpa.py (53 additions, 4 deletions)

@@ -12,8 +12,9 @@
 from typing import Tuple, Union

 import torch
+import torch.nn.functional as F

-from executorch.examples.models.llama.llama_transformer import KVCache, SDPA
+from executorch.examples.models.llama.llama_transformer import KVCache, SDPA, FeedForward
 from executorch.examples.models.llama.source_transformation.quantized_kv_cache import (
     QuantizedKVCache,
 )

@@ -171,12 +172,14 @@ def __init__(
         self,
         kv_cache: KVCache,
         dim: int,
+        head_dim: int,
         n_rep: int,
     ):
         super().__init__()
         self.kv_cache = kv_cache
         self.dim = dim
         self.n_rep = n_rep
+        self.scale_factor = math.sqrt(head_dim)

     def forward(
         self,

@@ -195,8 +198,7 @@ def forward(
         v = repeat_kv(v, self.n_rep)
         attn_mask = mask[input_pos]

-        scale_factor = 1 / math.sqrt(q.size(-1))
-        attn_weight = q @ k.transpose(-2, -1) * scale_factor
+        attn_weight = q @ k.transpose(-2, -1) / self.scale_factor
         attn_weight += attn_mask
         attn_weight = torch.softmax(attn_weight, dim=-1)
         y = attn_weight @ v

@@ -223,7 +225,7 @@ def replace_sdpa_with_flex_sdpa(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPAFlex(child.kv_cache, child.dim, child.n_rep),
+                SDPAFlex(child.kv_cache, child.dim, child.head_dim, child.n_rep),
             )
         else:
             replace_sdpa_with_flex_sdpa(child)
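Hoisting the scale into __init__ removes a sqrt from every forward call, and dividing by sqrt(head_dim) is algebraically the same as multiplying by 1 / sqrt(q.size(-1)) as long as the head dimension is static. A quick check with toy shapes (not from the commit):

import math
import torch

head_dim = 64
q = torch.randn(1, 8, 1, head_dim)
k = torch.randn(1, 8, 128, head_dim)
old = q @ k.transpose(-2, -1) * (1 / math.sqrt(q.size(-1)))
new = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
assert torch.allclose(old, new, atol=1e-6)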
@@ -428,3 +430,50 @@ def replace_causal_mask(module: torch.nn.Module):
     for _, child in module.named_children():
         replace_causal_mask(child)
     return module
+
+
+class FeedForwardConv2D(torch.nn.Module):
+    def __init__(self, w1: torch.nn.Linear, w2: torch.nn.Linear, w3: torch.nn.Linear):
+        super().__init__()
+        self.w1_conv = torch.nn.Conv2d(
+            in_channels=w1.weight.shape[1],
+            out_channels=w1.weight.shape[0],
+            kernel_size=1,
+            padding=0,
+            bias=False,
+        )
+        self.w2_conv = torch.nn.Conv2d(
+            in_channels=w2.weight.shape[1],
+            out_channels=w2.weight.shape[0],
+            kernel_size=1,
+            padding=0,
+            bias=False,
+        )
+        self.w3_conv = torch.nn.Conv2d(
+            in_channels=w3.weight.shape[1],
+            out_channels=w3.weight.shape[0],
+            kernel_size=1,
+            padding=0,
+            bias=False,
+        )
+
+        self.w1_conv.weight = torch.nn.Parameter(w1.weight.reshape(*w1.weight.shape, 1, 1))
+        self.w2_conv.weight = torch.nn.Parameter(w2.weight.reshape(*w2.weight.shape, 1, 1))
+        self.w3_conv.weight = torch.nn.Parameter(w3.weight.reshape(*w3.weight.shape, 1, 1))
+
+    def forward(self, x):
+        rank = x.dim()
+        x = x.unsqueeze(-1) if rank == 3 else x.reshape(1, *x.shape, 1)
+        x = torch.transpose(x, 1, 2)
+        res = self.w2_conv(F.silu(self.w1_conv(x)) * self.w3_conv(x))
+        res = torch.transpose(res, 1, 2)
+        res = res.squeeze(-1) if rank == 3 else res.reshape(*res.shape[1:3])
+        return res
+
+
+def replace_feedforward_to_conv2d(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, FeedForward):
+            setattr(module, name, FeedForwardConv2D(child.w1, child.w2, child.w3))
+        else:
+            replace_feedforward_to_conv2d(child)
+    return module
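FeedForwardConv2D reshapes (batch, seq, dim) activations to NCHW with dim as channels, so each 1x1 conv reproduces the matching Linear of the SwiGLU feedforward w2(silu(w1(x)) * w3(x)), mirroring the existing convert_linear_to_conv2d trick at the module level. A hedged equivalence check against the Linear form, with toy dimensions of my choosing:

import torch
import torch.nn.functional as F

dim, hidden = 16, 32
w1 = torch.nn.Linear(dim, hidden, bias=False)
w3 = torch.nn.Linear(dim, hidden, bias=False)
w2 = torch.nn.Linear(hidden, dim, bias=False)

x = torch.randn(2, 7, dim)                 # (batch, seq, dim)
ref = w2(F.silu(w1(x)) * w3(x))
out = FeedForwardConv2D(w1, w2, w3)(x)
assert torch.allclose(ref, out, atol=1e-5)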

extension/llm/export/builder.py (1 addition, 1 deletion)

@@ -237,7 +237,7 @@ def calibrate_template(
     with torch.no_grad():
         while token_list[-1] != tokenizer.eos_id and pos < max_len:
             logits = module(
-                torch.full((1, 1), token_list[pos]),
+                torch.full((1, 1), token_list[pos], dtype=torch.int32),
                 torch.tensor((pos,)),
             )
             pos += 1
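torch.full infers int64 for a Python integer fill value, so without the explicit dtype the calibration tokens would no longer match the model's new int32 token input. A minimal demonstration:

import torch

assert torch.full((1, 1), 5).dtype == torch.int64
assert torch.full((1, 1), 5, dtype=torch.int32).dtype == torch.int32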
