@@ -268,13 +268,42 @@ def lift_constant_tensor_pass(ep):
268268 buffers = list (graph_signature .buffers )
269269
270270 fake_mode = list (ep .graph .nodes )[0 ].meta ["val" ].fake_mode
271- first_user_input = None
271+ insert_before_node = None
272272 lifted_constants = []
273273 for node in ep .graph .nodes :
274274 if node .op == "placeholder" and node .name in graph_signature .user_inputs :
275- first_user_input = node
275+ insert_before_node = node # first user input
276276 break
277277
278+ if insert_before_node is None :
279+ # we have no user inputs, find the node after the last buffer
280+ # (that we will insert the lifted constants before).
281+ # this is a bit hacky, but I am not certain of what the contract is for
282+ # node ordering. is the first non-placeholder node guranteed to be the
283+ # first node after input paramters? what if there is no op, and it is
284+ # just placeholders? Easier to just find the last buffer, and insert after.
285+
286+ # also error if we have no buffers and no user inputs... if that is an issue, fix it later?
287+ last_buffer = None
288+ for node in ep .graph .nodes :
289+ node_buffer_fqn = graph_signature .inputs_to_buffers .get (node .name , None )
290+ # not sure if both cases are needed, if is it possible to encounter a
291+ # buffer that is not a user input?
292+ if (
293+ node_buffer_fqn is not None
294+ and node_buffer_fqn in graph_signature .buffers
295+ ):
296+ last_buffer = node
297+ continue
298+ if node .op == "placeholder" and node .name in graph_signature .buffers :
299+ last_buffer = node
300+ continue
301+ # we have our last buffer, grab the node after it, to insert the lifted constants before.
302+ insert_before_node = last_buffer .next
303+
304+ if insert_before_node is None :
305+ raise ValueError ("No user inputs and no buffers found. Cannot lift constants." )
306+
278307 for node in ep .graph .nodes :
279308 if node .op == "get_attr" :
280309 constant_tensor = getattr (ep .graph_module , node .target )
@@ -283,7 +312,7 @@ def lift_constant_tensor_pass(ep):
283312
284313 constant_tensor_fqn = f"_lifted_tensor_constant{ len (buffers )} "
285314
286- with ep .graph .inserting_before (first_user_input ):
315+ with ep .graph .inserting_before (insert_before_node ):
287316 # Insert the constant node before the first user input
288317 const_placeholder_node = ep .graph .placeholder (constant_tensor_fqn )
289318 for k , v in node .meta .items ():
@@ -316,6 +345,9 @@ def lift_constant_tensor_pass(ep):
316345 new_input_specs .extend (lifted_constants )
317346 lifted_constants .clear ()
318347 new_input_specs .append (s )
348+ # Add remaining lifted constants if no user inputs exist.
349+ if len (lifted_constants ) > 0 :
350+ new_input_specs .extend (lifted_constants )
319351 ep .graph_signature .input_specs = new_input_specs
320352 ep .graph_module .recompile ()
321353 return ep
0 commit comments