@@ -13,35 +13,9 @@ class InitMutableBufferPass(ExportPass):
1313 def __init__ (self ) -> None :
1414 super ().__init__ ()
1515
16- def update_placeholder_tensor_specs (
17- self ,
18- exported_program : torch .export .ExportedProgram ,
19- graph_module : torch .fx .GraphModule ,
20- ) -> None :
21- """
22- Update the tensor specs for all placeholder nodes such that
23- placeholders that are parameters are marked as constant.
24- """
25- for node in graph_module .graph .nodes :
26- if node .op != "placeholder" :
27- continue
28- if "spec" not in node .meta :
29- raise RuntimeError (f"Placeholder node { node } missing meta['spec']" )
30- # print(node)
31- spec = node .meta ["spec" ]
32- if (isinstance (node .target , str ) and
33- node .target in exported_program .graph_signature .inputs_to_buffers and exported_program .graph_signature .inputs_to_buffers [node .target ] in exported_program .state_dict ):
34- # print(f"Setting {node.target}.const = True")
35- # breakpoint()
36- # print(exported_program.state_dict[exported_program.graph_signature.inputs_to_buffers[node.target]])
37- spec .const = True
38-
39- # pyre-ignore
4016 def placeholder (self , name : str , arg , meta ):
41- # print(name)
42- meta ["spec" ] = make_spec (arg , const = meta .data ['spec' ].const )
43- # if name == "b_kv_cache_cache_pos":
44- # print("breakpoint")
45- # breakpoint()
46-
17+ if "cache_pos" in name :
18+ meta ["et_init_buffer" ] = True
19+
4720 return super ().placeholder (name , arg , meta )
21+
0 commit comments