77
88import torch
99from executorch .exir .pass_base import ExportPass , PassResult
10+ from torch ._export .utils import is_buffer , is_param
1011
1112
1213class UnsqueezeScalarPlaceholdersPass (ExportPass ):
@@ -19,23 +20,27 @@ def __init__(self, exported_program):
1920 self .exported_program = exported_program
2021 super ().__init__ ()
2122
22- def _is_inputs_to_buffers_or_parameters (self , node ):
23- return (
24- node .name in self .exported_program .graph_signature .inputs_to_buffers
25- or node .name in self .exported_program .graph_signature .inputs_to_parameters
26- )
27-
2823 def call (self , graph_module : torch .fx .GraphModule ):
2924 for node in graph_module .graph .nodes :
3025 if node .op != "placeholder" :
3126 continue
3227 rank = node .meta ["val" ].dim ()
3328 if rank == 0 :
34- if not self ._is_inputs_to_buffers_or_parameters (node ):
29+ if is_buffer (self .exported_program , node ):
30+ name = self .exported_program .graph_signature .inputs_to_buffers [
31+ node .name
32+ ]
33+ elif is_param (self .exported_program , node ):
34+ name = self .exported_program .graph_signature .inputs_to_parameters [
35+ node .name
36+ ]
37+ else :
3538 continue
36- tensor = self .exported_program .state_dict [node .name ]
39+
40+ tensor = self .exported_program .state_dict [name ]
41+
3742 if tensor .dim () == 0 :
38- self .exported_program .state_dict [node . name ] = tensor .unsqueeze (0 )
43+ self .exported_program .state_dict [name ] = tensor .unsqueeze (0 )
3944 node .meta ["val" ] = node .meta ["val" ].fake_mode .from_tensor (
4045 tensor .unsqueeze (0 ), static_shapes = True
4146 )
@@ -53,6 +58,9 @@ def ensures(self, graph_module: torch.fx.GraphModule):
5358 if node .op == "placeholder" :
5459 rank = node .meta ["val" ].dim ()
5560 if rank == 0 :
56- if not self ._is_inputs_to_buffers_or_parameters (node ):
61+ if not (
62+ is_buffer (self .exported_program , node )
63+ or is_param (self .exported_program , node )
64+ ):
5765 continue
5866 raise ValueError ("Placeholders of rank 0 are not supported!" )
0 commit comments