Skip to content

Commit a255e2f

Browse files
author
Sanggyu Lee
committed
Add wq,wk,wv,wo and remove_unused_input pass
1 parent 16a07af commit a255e2f

File tree

5 files changed

+86
-4
lines changed

5 files changed

+86
-4
lines changed

test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/layer.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ def populate_args(args_dict, filter):
8585
@torch.library.impl("circle::attention.llama", "CPU")
8686
def attention_llama_cpu(
8787
hidden_states,
88+
q_proj,
89+
k_proj,
90+
v_proj,
91+
o_proj,
8892
position_cos,
8993
position_sin,
9094
attention_mask,
@@ -100,6 +104,10 @@ def attention_llama_cpu(
100104
def attention_llama(*args, **kwargs):
101105
(
102106
hidden_states,
107+
q_proj,
108+
k_proj,
109+
v_proj,
110+
o_proj,
103111
position_cos,
104112
position_sin,
105113
attention_mask,
@@ -131,6 +139,10 @@ def forward_adapter(
131139
return (
132140
torch.ops.circle.attention.llama(
133141
hidden_states,
142+
self.q_proj.weight,
143+
self.k_proj.weight,
144+
self.v_proj.weight,
145+
self.o_proj.weight,
134146
position_embeddings[0], # cos
135147
position_embeddings[1], # sin
136148
attention_mask,
@@ -155,4 +167,4 @@ def forward_adapter(
155167
model = AutoModelForCausalLM.from_pretrained(model_name)
156168
model.eval()
157169
circle_model = tico.convert(model.model.layers[0], captured_input)
158-
circle_model.save(f"tinyllama.attn.circle")
170+
circle_model.save(f"tinyllama.layer.attn.circle")

test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention/model.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ def populate_args(args_dict, filter):
8787
@torch.library.impl("circle::attention.llama", "CPU")
8888
def attention_llama_cpu(
8989
hidden_states,
90+
q_proj,
91+
k_proj,
92+
v_proj,
93+
o_proj,
9094
position_cos,
9195
position_sin,
9296
attention_mask,
@@ -102,6 +106,10 @@ def attention_llama_cpu(
102106
def attention_llama(*args, **kwargs):
103107
(
104108
hidden_states,
109+
q_proj,
110+
k_proj,
111+
v_proj,
112+
o_proj,
105113
position_cos,
106114
position_sin,
107115
attention_mask,
@@ -133,6 +141,10 @@ def forward_adapter(
133141
return (
134142
torch.ops.circle.attention.llama(
135143
hidden_states,
144+
self.q_proj.weight,
145+
self.k_proj.weight,
146+
self.v_proj.weight,
147+
self.o_proj.weight,
136148
position_embeddings[0], # cos
137149
position_embeddings[1], # sin
138150
attention_mask,
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import TYPE_CHECKING
16+
17+
if TYPE_CHECKING:
18+
import torch.fx
19+
import torch
20+
from torch.export import ExportedProgram
21+
22+
from tico.passes import ops
23+
from tico.utils import logging
24+
from tico.utils.passes import PassBase, PassResult
25+
from tico.utils.trace_decorators import trace_graph_diff_on_pass
26+
27+
28+
@trace_graph_diff_on_pass
class RemoveUnusedInput(PassBase):
    """
    Remove dead graph inputs.

    Erases every ``placeholder`` node that has no users from the exported
    program's FX graph, so the serialized model does not carry inputs that
    nothing consumes (e.g. weights folded into a fused custom op).

    Returns a ``PassResult`` whose flag is True iff at least one input was
    erased.

    NOTE(review): erasing a placeholder does not update
    ``exported_program.graph_signature`` — presumably a later step (or the
    serializer) tolerates/rebuilds the input specs; confirm against callers.
    """

    def __init__(self):
        super().__init__()

    def call(self, exported_program: ExportedProgram) -> PassResult:
        logger = logging.getLogger(__name__)

        graph_module = exported_program.graph_module
        graph = graph_module.graph
        modified = False
        # Iterate over a snapshot: erase_node() mutates the graph's node
        # list, and removing nodes while iterating graph.nodes directly can
        # skip entries.
        for node in list(graph.nodes):
            if node.op == "placeholder" and len(node.users) == 0:
                graph.erase_node(node)
                modified = True
                logger.debug("RemoveUnusedInput: erased %s", node.name)

        graph.lint()
        graph_module.recompile()

        return PassResult(modified)

tico/serialize/operators/op_circle_attention.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@
3232
"""
3333
attention.llama(
3434
Tensor hidden_states,
35+
Tensor wq,
36+
Tensor wk,
37+
Tensor wv,
38+
Tensor wo,
3539
Tensor position_cos,
3640
Tensor position_sin,
3741
Tensor? attention_mask,
@@ -59,6 +63,10 @@ def define_node(
5963
) -> circle.Operator.OperatorT:
6064
(
6165
hidden_states,
66+
wq,
67+
wk,
68+
wv,
69+
wo,
6270
position_cos,
6371
position_sin,
6472
attention_mask,
@@ -68,9 +76,6 @@ def define_node(
6876
layer_idx,
6977
) = node.args
7078

71-
inputs = node.args
72-
outputs = [node]
73-
7479
op_index = get_op_index(
7580
circle.BuiltinOperator.BuiltinOperator.ATTENTION, self._op_codes
7681
)

tico/utils/convert.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from tico.passes.lower_to_slice import passes as LowerToSlicePasses
6464
from tico.passes.merge_consecutive_cat import MergeConsecutiveCat
6565
from tico.passes.remove_nop import RemoveNop
66+
from tico.passes.remove_unused_inputs import RemoveUnusedInput
6667
from tico.passes.remove_redundant_assert_nodes import RemoveRedundantAssertionNodes
6768
from tico.passes.remove_redundant_expand import RemoveRedundantExpand
6869
from tico.passes.remove_redundant_permute import passes as RemoveRedundantPermutePasses
@@ -251,6 +252,7 @@ def convert_exported_module_to_circle(
251252
ConvertConv1dToConv2d(),
252253
*LowerToSlicePasses(),
253254
FuseLeadingUnsqueezeReshape(),
255+
RemoveUnusedInput(),
254256
]
255257
)
256258
circle_legalize.run(exported_program)

0 commit comments

Comments
 (0)