 import torch
 from executorch.exir.pass_base import ExportPass, PassResult
+from torch._export.utils import is_buffer, is_param
 
 
 class UnsqueezeScalarPlaceholdersPass(ExportPass):
@@ -19,23 +20,27 @@ def __init__(self, exported_program):
         self.exported_program = exported_program
         super().__init__()
 
-    def _is_inputs_to_buffers_or_parameters(self, node):
-        return (
-            node.name in self.exported_program.graph_signature.inputs_to_buffers
-            or node.name in self.exported_program.graph_signature.inputs_to_parameters
-        )
-
     def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
             if node.op != "placeholder":
                 continue
             rank = node.meta["val"].dim()
             if rank == 0:
-                if not self._is_inputs_to_buffers_or_parameters(node):
+                if is_buffer(self.exported_program, node):
+                    name = self.exported_program.graph_signature.inputs_to_buffers[
+                        node.name
+                    ]
+                elif is_param(self.exported_program, node):
+                    name = self.exported_program.graph_signature.inputs_to_parameters[
+                        node.name
+                    ]
+                else:
                     continue
-                tensor = self.exported_program.state_dict[node.name]
+
+                tensor = self.exported_program.state_dict[name]
+
                 if tensor.dim() == 0:
-                    self.exported_program.state_dict[node.name] = tensor.unsqueeze(0)
+                    self.exported_program.state_dict[name] = tensor.unsqueeze(0)
                     node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
                         tensor.unsqueeze(0), static_shapes=True
                     )
@@ -53,6 +58,9 @@ def ensures(self, graph_module: torch.fx.GraphModule):
             if node.op == "placeholder":
                 rank = node.meta["val"].dim()
                 if rank == 0:
-                    if not self._is_inputs_to_buffers_or_parameters(node):
+                    if not (
+                        is_buffer(self.exported_program, node)
+                        or is_param(self.exported_program, node)
+                    ):
                         continue
                     raise ValueError("Placeholders of rank 0 are not supported!")
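
For context, a minimal sketch (module and buffer names are illustrative, not from this PR) of why the lookup now goes through graph_signature instead of node.name: after torch.export, a rank-0 buffer's placeholder gets a graph name that generally differs from the fully qualified name keying state_dict, and inputs_to_buffers / inputs_to_parameters provide that mapping.

# Hypothetical example, not part of the patch.
import torch
from torch.export import export
from torch._export.utils import is_buffer


class Mod(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Rank-0 buffer: the case UnsqueezeScalarPlaceholdersPass unsqueezes to rank 1.
        self.register_buffer("scale", torch.tensor(2.0))

    def forward(self, x):
        return x * self.scale


ep = export(Mod(), (torch.ones(3),))
for node in ep.graph.nodes:
    if node.op == "placeholder" and is_buffer(ep, node):
        # node.name is the placeholder's graph name, while state_dict is keyed
        # by the buffer's fully qualified name taken from the graph signature.
        fqn = ep.graph_signature.inputs_to_buffers[node.name]
        print(node.name, "->", fqn, ep.state_dict[fqn].shape)

On recent PyTorch versions this prints something like "b_scale -> scale torch.Size([])"; exact placeholder naming may vary by version.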