11import torch
2+ import inspect
23
34
45def parse_sole_graph_module (module , inputs ):
@@ -14,12 +15,46 @@ def my_backend(gm, sample_inputs):
1415
1516 torch .compile (module , backend = my_backend )(* inputs )
1617 assert traced_module is not None
17- assert all (id (a ) == id (b ) for a , b in zip (inputs , traced_sample_inputs ))
1818 for node in traced_module .graph .nodes :
1919 if node .op != "placeholder" :
2020 continue
2121 assert node .target [:2 ] == "L_" or node .target [:2 ] == "l_" , f"{ node .target = } "
2222 node .target = node .target [2 :]
23+ if node .target [0 ] == "l" :
24+ node .target = "L" + node .target [1 :]
2325 assert node .name [:2 ] == "L_" or node .name [:2 ] == "l_" , f"{ node .name = } "
2426 node .name = node .name [2 :]
27+ if node .name [0 ] == "l" :
28+ node .name = "L" + node .name [1 :]
29+
30+ def get_input_names_from_signature ():
31+ return inspect .signature (module .forward ).parameters
32+
33+ def get_input_names_from_placeholder ():
34+ return [
35+ node .name for node in traced_module .graph .nodes if node .op == "placeholder"
36+ ]
37+
38+ def get_diff_input_names ():
39+ placeholder_names = set (get_input_names_from_placeholder ())
40+ return [
41+ (i , name )
42+ for i , name in enumerate (get_input_names_from_signature ())
43+ if name not in placeholder_names
44+ ]
45+
46+ if len (inputs ) == len (traced_sample_inputs ) + 1 :
47+ diff_input_names = get_diff_input_names ()
48+ assert len (diff_input_names ) == 1 , f"{ diff_input_names = } "
49+ pos , name = diff_input_names [0 ]
50+ for i , node in enumerate (traced_module .graph .nodes ):
51+ if i < pos :
52+ assert node .op == "placeholder"
53+ elif i == pos :
54+ with traced_module .graph .inserting_before (node ):
55+ traced_module .graph .placeholder (name )
56+ else :
57+ break
58+ traced_module .recompile ()
59+ assert len (get_diff_input_names ()) == 0 , f"{ get_diff_input_names ()= } "
2560 return traced_module
0 commit comments