9
9
from executorch .exir .pass_base import ExportPass , PassResult
10
10
from executorch .exir .passes import dead_code_elimination_pass
11
11
12
- from torch .fx .passes .utils .source_matcher_utils import get_source_partitions
13
-
14
12
15
13
class FixedLinearKeepDim (ExportPass ):
16
14
"""
@@ -24,61 +22,58 @@ def __init__(self):
24
22
super (FixedLinearKeepDim , self ).__init__ ()
25
23
26
24
def _fixed_keep_dim(self, graph_module: torch.fx.GraphModule) -> None:
    """Flatten inputs/outputs of high-rank linear nodes with view_copy ops.

    QNN has no "keep dims" behavior for its linear op, so for every linear
    node whose output rank is >= 3 this pass inserts a view_copy ("squeeze")
    that collapses the input to 2D before the linear, and a view_copy
    ("unsqueeze") after it that restores the original output shape.
    Mutates ``graph_module.graph`` in place; returns nothing.

    NOTE(review): ``self.linear`` and ``self.view_copy`` are attributes
    declared elsewhere on this class (not visible in this chunk) —
    presumably the target linear / view_copy operator overloads; confirm
    against the class definition.
    """
    for node in graph_module.graph.nodes:
        # Only linear nodes are rewritten; everything else passes through.
        if node.target != self.linear:
            continue

        linear_node = node
        input_node = linear_node.args[0]
        # Since QNN has no keep dims for linear op, we will need to add
        # squeeze and unsqueeze around linear node
        # TODO: Find a more general conditional statement.
        linear_output = linear_node.meta["val"]
        if linear_output.dim() >= 3:
            # --- squeeze: collapse the linear's input to rank 2 ---
            with graph_module.graph.inserting_after(input_node):
                # Snapshot users before rerouting, since replace_input_with
                # mutates the users mapping while we iterate.
                input_users = list(input_node.users.keys())
                input_tensor = input_node.meta["val"]
                # Collapse all leading dims into one; keep the last
                # (feature) dim intact.
                squeeze_dim = (-1, input_tensor.shape[-1])
                squeeze_node = graph_module.graph.create_node(
                    "call_function",
                    self.view_copy,
                    (
                        input_node,
                        squeeze_dim,
                    ),
                )
                # meta needs to be copied elementwisely for fake-tensor
                # to be updated correctly and not affect meta of input_node
                for k, v in input_node.meta.items():
                    squeeze_node.meta[k] = v
                # Overwrite the copied "val" with the flattened fake tensor.
                squeeze_node.meta["val"] = input_tensor.reshape(squeeze_dim)
                # Reroute only the linear node onto the squeezed input;
                # any other users keep reading the original input.
                for user in input_users:
                    if user == linear_node:
                        user.replace_input_with(input_node, squeeze_node)

            # --- unsqueeze: restore the original output shape ---
            with graph_module.graph.inserting_after(linear_node):
                # Snapshot consumers of the linear before rerouting them.
                output_users = list(linear_node.users.keys())
                # Target shape is the linear's original (pre-flatten)
                # output shape.
                unsqueeze_dim = linear_output.shape
                unsqueeze_node = graph_module.graph.create_node(
                    "call_function",
                    self.view_copy,
                    (
                        linear_node,
                        unsqueeze_dim,
                    ),
                )
                # meta needs to be copied elementwisely for fake-tensor
                # to be updated correctly and not affect meta of unsqueeze_node
                for k, v in linear_node.meta.items():
                    unsqueeze_node.meta[k] = v
                # update linear node's shape: the linear itself now
                # produces a 2D result (flattened batch x out features).
                linear_node.meta["val"] = linear_output.reshape(
                    (squeeze_node.meta["val"].shape[0], linear_output.shape[-1])
                )
                # All original consumers of the linear now read from the
                # unsqueeze node instead.
                for user in output_users:
                    user.replace_input_with(linear_node, unsqueeze_node)
82
77
83
78
def call (self , graph_module : torch .fx .GraphModule ):
84
79
self ._fixed_keep_dim (graph_module )
0 commit comments