     get_parameter,
     is_graph_input,
     is_graph_output,
+    is_mutable_buffer_input,
+    is_mutable_buffer_output,
     is_parameter,
 )
 
@@ -214,7 +216,9 @@ def get_tensor_type(
         node: torch.fx.Node,
         tensor_type: PyQnnWrapper.Qnn_TensorType_t,
     ) -> PyQnnWrapper.Qnn_TensorType_t:
-        is_input = is_graph_input(node, self.edge_program)
+        is_input = is_graph_input(node, self.edge_program) or is_mutable_buffer_input(
+            node, self.edge_program
+        )
         is_output = is_graph_output(node)
         # handle logic for input/output tensors
         if is_input or is_output:
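
For context, a mutable buffer here is a module buffer that `forward()` mutates in place: `torch.export` lifts it to a placeholder that also appears in `graph_signature.buffers_to_mutate`, so the QNN graph must treat it as a real input (with a matching output). A minimal sketch of what that looks like, assuming a toy module and current `torch.export` behavior (exact placeholder and node names such as `b_cache` vary by version):

```python
import torch
from torch.export import export

# Toy module with a buffer that forward() mutates in place.
class Cache(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("cache", torch.zeros(4))

    def forward(self, x):
        self.cache.add_(x)  # in-place mutation -> "mutable buffer"
        return self.cache + 1

ep = export(Cache(), (torch.ones(4),))
# The lifted buffer placeholder shows up on both sides of the signature.
print(ep.graph_signature.inputs_to_buffers)   # e.g. {'b_cache': 'cache'}
print(ep.graph_signature.buffers_to_mutate)   # e.g. {'add': 'cache'} (names vary)
```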
@@ -247,6 +251,33 @@ def get_data_type(
 
         return QNN_TENSOR_TYPE_MAP[tensor.dtype]
 
+    def get_tensor_name(
+        self,
+        node: torch.fx.Node,
+        wrapper_idx: int = 0,
+    ):
+        tensor_name = f"{node.name}_{wrapper_idx}"
+        # The `input_{id}` prefix is used to sort inputs at runtime, since the multiple passes
+        # in qnn_preprocess may reorder inputs relative to the original graph's forward() signature.
+        # The `mutbuf_{id}` prefix maps a mutable buffer's input to its output at runtime.
+        # The `output_` prefix marks a graph output at runtime, to avoid confusion with per_tensor_dump.
+        if is_mutable_buffer_input(node, self.edge_program):
+            fqn = self.edge_program.graph_signature.inputs_to_buffers[node.target]
+            position_index = list(
+                self.edge_program.graph_signature.buffers_to_mutate.values()
+            ).index(fqn)
+            tensor_name = f"input_{str(self.external_ids[node])}_mutbuf_{str(position_index)}_{tensor_name}"
+        elif is_graph_input(node, self.edge_program):
+            tensor_name = f"input_{str(self.external_ids[node])}_{tensor_name}"
+        elif is_mutable_buffer_output(node, self.edge_program):
+            position_index = list(
+                self.edge_program.graph_signature.buffers_to_mutate.keys()
+            ).index(node.name)
+            tensor_name = f"output_mutbuf_{position_index}_{tensor_name}"
+        elif is_graph_output(node):
+            tensor_name = f"output_{tensor_name}"
+        return tensor_name
+
     def define_custom_tensor_wrapper(
         self,
         node_name: str,
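
To make the naming scheme concrete, here is a simplified re-statement with made-up values; in the real flow the id, position, and wrapper index come from `self.external_ids`, the graph signature, and `wrapper_idx`:

```python
# Hypothetical values for illustration only.
external_id, mutbuf_pos, node_name, wrapper_idx = 2, 1, "kv_cache", 0

mutable_input  = f"input_{external_id}_mutbuf_{mutbuf_pos}_{node_name}_{wrapper_idx}"
plain_input    = f"input_{external_id}_{node_name}_{wrapper_idx}"
mutable_output = f"output_mutbuf_{mutbuf_pos}_{node_name}_{wrapper_idx}"
plain_output   = f"output_{node_name}_{wrapper_idx}"

print(mutable_input)   # input_2_mutbuf_1_kv_cache_0
print(mutable_output)  # output_mutbuf_1_kv_cache_0
```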
@@ -307,11 +338,7 @@ def define_tensor(
         if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
             return cached
 
-        tensor_name = f"{node.name}_{wrapper_idx}"
-        if is_graph_input(node, self.edge_program):
-            tensor_name = "input_" + str(self.external_ids[node]) + "_" + tensor_name
-        if is_graph_output(node):
-            tensor_name = "output_" + tensor_name
+        tensor_name = self.get_tensor_name(node, wrapper_idx)
         dims = [1] if len(tensor.size()) == 0 else tensor.size()
         tensor_type = self.get_tensor_type(node, tensor_type)
         quant_encoding, quant_configs = self.get_quant_encoding_conf(
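
On the consumer side the prefixes are plain string metadata; a hypothetical helper (not part of this PR or the ExecuTorch runtime) illustrates how the external input id and mutable-buffer position could be recovered from such a name:

```python
import re
from typing import Optional, Tuple

# Hypothetical parser, only to illustrate the name layout produced above.
_INPUT_NAME = re.compile(r"^input_(\d+)(?:_mutbuf_(\d+))?_")

def parse_input_name(name: str) -> Optional[Tuple[int, Optional[int]]]:
    m = _INPUT_NAME.match(name)
    if m is None:
        return None
    mutbuf = int(m.group(2)) if m.group(2) is not None else None
    return int(m.group(1)), mutbuf

assert parse_input_name("input_2_mutbuf_1_kv_cache_0") == (2, 1)
assert parse_input_name("input_0_x_0") == (0, None)
assert parse_input_name("output_add_0") is None
```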
@@ -383,7 +410,9 @@ def generate_node_to_external_map(
         # The order in which we visit the placeholder node is same as the *args
         # order for the forward(*args) signature for this gm. Using the order of
         # the nodes as external_id to extract the right arg from *args at runtime
-        if is_graph_input(node, edge_program):
+        if is_graph_input(node, edge_program) or is_mutable_buffer_input(
+            node, edge_program
+        ):
             node_to_external_map[node] = len(node_to_external_map)
     for node in edge_program.graph_module.graph.nodes:
         if is_graph_output(node):
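
The `is_mutable_buffer_input` / `is_mutable_buffer_output` helpers themselves come from the backend's utils and are not shown in this diff; a plausible sketch of what the input-side check amounts to (an assumption, not the actual implementation) is a placeholder whose target maps to a buffer that the graph also mutates:

```python
import torch
from torch.export import ExportedProgram

# Sketch only: approximates the check, not copied from the utils module.
def _looks_like_mutable_buffer_input(node: torch.fx.Node, prog: ExportedProgram) -> bool:
    if node.op != "placeholder":
        return False
    sig = prog.graph_signature
    fqn = sig.inputs_to_buffers.get(node.target)  # placeholder target -> buffer FQN
    return fqn is not None and fqn in sig.buffers_to_mutate.values()
```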