@@ -36,11 +36,11 @@ def process_call_function(
3636 tosa_spec : TosaSpecification ,
3737):
3838 # Unpack arguments and convert
39- inputs = getNodeArgs (node )
39+ inputs = getNodeArgs (node , tosa_spec )
4040
4141 # Convert output (this node itself)
4242 try :
43- output = TosaArg (node )
43+ output = TosaArg (node , tosa_spec )
4444 except ValueError as e :
4545 raise ValueError (
4646 f"Failed processing call_function: { node .name } . "
@@ -78,7 +78,7 @@ def process_inputs(
7878 f"Expected dim_order: { tuple (range (meta .dim ()))} , but got: { meta .dim_order ()} for node { node .name } "
7979 )
8080 try :
81- tosa_arg = TosaArg (node )
81+ tosa_arg = TosaArg (node , tosa_spec )
8282 except ValueError as e :
8383 raise ValueError (
8484 f"Failed processing input placeholder: { node .name } . "
@@ -112,7 +112,7 @@ def process_inputs_to_parameters(
112112):
113113 """Serialize bias and non-quantized weights"""
114114 try :
115- tosa_arg = TosaArg (node )
115+ tosa_arg = TosaArg (node , tosa_spec )
116116 except ValueError as e :
117117 raise ValueError (
118118 f"Failed processing parameter placeholder: { node .name } . "
@@ -137,10 +137,11 @@ def process_inputs_to_buffers(
137137 node : torch .fx .Node ,
138138 tosa_graph : Any ,
139139 edge_program : ExportedProgram ,
140+ tosa_spec : TosaSpecification ,
140141):
141142 """Serialize quantized weights"""
142143 try :
143- tosa_arg = TosaArg (node )
144+ tosa_arg = TosaArg (node , tosa_spec )
144145 except ValueError as e :
145146 raise ValueError (
146147 f"Failed processing buffer placeholder: { node .name } . "
@@ -165,9 +166,10 @@ def process_inputs_to_lifted_tensor_constants(
165166 node : torch .fx .Node ,
166167 tosa_graph : Any ,
167168 edge_program : ExportedProgram ,
169+ tosa_spec : TosaSpecification ,
168170):
169171 try :
170- tosa_arg = TosaArg (node )
172+ tosa_arg = TosaArg (node , tosa_spec )
171173 except ValueError as e :
172174 raise ValueError (
173175 f"Failed processing lifted tensor constant placeholder: { node .name } . "
@@ -196,9 +198,11 @@ def process_placeholder(
196198 elif is_param (edge_program , node ):
197199 process_inputs_to_parameters (node , tosa_graph , edge_program , tosa_spec )
198200 elif is_buffer (edge_program , node ):
199- process_inputs_to_buffers (node , tosa_graph , edge_program )
201+ process_inputs_to_buffers (node , tosa_graph , edge_program , tosa_spec )
200202 elif is_lifted_tensor_constant (edge_program , node ):
201- process_inputs_to_lifted_tensor_constants (node , tosa_graph , edge_program )
203+ process_inputs_to_lifted_tensor_constants (
204+ node , tosa_graph , edge_program , tosa_spec
205+ )
202206 elif node .name in edge_program .graph_signature .inputs_to_lifted_custom_objs :
203207 raise NotImplementedError (
204208 "Placeholder is of type 'lifted custom object' which is not supported."
0 commit comments