41 | 41 | get_parameter, |
42 | 42 | is_graph_input, |
43 | 43 | is_graph_output, |
| 44 | + is_mutable_buffer_input, |
| 45 | + is_mutable_buffer_output, |
44 | 46 | is_parameter, |
45 | 47 | ) |
46 | 48 |
@@ -307,7 +309,9 @@ def get_tensor_type( |
307 | 309 | node: torch.fx.Node, |
308 | 310 | tensor_type: PyQnnWrapper.Qnn_TensorType_t, |
309 | 311 | ) -> PyQnnWrapper.Qnn_TensorType_t: |
310 | | - is_input = is_graph_input(node, self.edge_program) |
| 312 | + is_input = is_graph_input(node, self.edge_program) or is_mutable_buffer_input( |
| 313 | + node, self.edge_program |
| 314 | + ) |
311 | 315 | is_output = is_graph_output(node) |
312 | 316 | # handle logic for input/output tensors |
313 | 317 | if is_input or is_output: |
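
With this change, a placeholder that carries a mutable buffer is treated as an ordinary external graph input when its QNN tensor type is resolved. The helper itself is not shown in this diff; a minimal sketch of what it plausibly checks, assuming it sits next to `is_graph_input` in the utils module and reads the same graph-signature fields that `get_tensor_name` consults below:

    def is_mutable_buffer_input(
        node: torch.fx.Node, edge_program: torch.export.ExportedProgram
    ) -> bool:
        # A placeholder feeds a mutable buffer when the buffer it maps to
        # also appears as a mutation target in the graph signature.
        sig = edge_program.graph_signature
        if node.op != "placeholder" or node.target not in sig.inputs_to_buffers:
            return False
        return sig.inputs_to_buffers[node.target] in sig.buffers_to_mutate.values()
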
@@ -352,6 +356,33 @@ def get_dynamic_dimension(self, dims): |
352 | 356 |
353 | 357 | return dynamic_dims if any(dynamic_dims) else [], nominal_dims |
354 | 358 |
| 359 | + def get_tensor_name( |
| 360 | + self, |
| 361 | + node: torch.fx.Node, |
| 362 | + wrapper_idx: int = 0, |
| 363 | + ): |
| 364 | + tensor_name = f"{node.name}_{wrapper_idx}" |
| 365 | + # The `input_{id}` prefix is used for sorting inputs at runtime. Because of the multiple passes in qnn_preprocess, |
| 366 | + # the input order between QNN and the original graph's forward function may differ. |
| 367 | + # The `mutbuf_{id}` prefix is used at runtime to map the inputs and outputs of a mutable buffer to each other. |
| 368 | + # The `output_` prefix marks a tensor as a graph output at runtime, so it is not confused with per_tensor_dump. |
| 369 | + if is_mutable_buffer_input(node, self.edge_program): |
| 370 | + fqn = self.edge_program.graph_signature.inputs_to_buffers[node.target] |
| 371 | + position_index = list( |
| 372 | + self.edge_program.graph_signature.buffers_to_mutate.values() |
| 373 | + ).index(fqn) |
| 374 | + tensor_name = f"input_{str(self.external_ids[node])}_mutbuf_{str(position_index)}_{tensor_name}" |
| 375 | + elif is_graph_input(node, self.edge_program): |
| 376 | + tensor_name = f"input_{str(self.external_ids[node])}_{tensor_name}" |
| 377 | + elif is_mutable_buffer_output(node, self.edge_program): |
| 378 | + position_index = list( |
| 379 | + self.edge_program.graph_signature.buffers_to_mutate.keys() |
| 380 | + ).index(node.name) |
| 381 | + tensor_name = f"output_mutbuf_{position_index}_{tensor_name}" |
| 382 | + elif is_graph_output(node): |
| 383 | + tensor_name = f"output_{tensor_name}" |
| 384 | + return tensor_name |
| 385 | + |
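
For reference (not part of the diff), the names this method produces take the following shapes. The node names here are invented for illustration; the point is that a buffer which is both read and written gets the same `mutbuf_{id}` on its input and output tensors, because `buffers_to_mutate.values()` (indexed on the input side) and `buffers_to_mutate.keys()` (indexed on the output side) enumerate the same entries in the same order:

    input_0_x_0                      # ordinary graph input, external id 0
    input_1_mutbuf_0_b_state_0       # read side of the mutable buffer at position 0
    output_mutbuf_0_copy__default_0  # write side of the same buffer, same position
    output_add_0                     # ordinary graph output
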
355 | 386 | def define_custom_tensor_wrapper( |
356 | 387 | self, |
357 | 388 | node_name: str, |
@@ -413,16 +444,7 @@ def define_tensor( |
413 | 444 | if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None): |
414 | 445 | return cached |
415 | 446 |
416 | | - tensor_name = f"{tensor_source_node.name}_{wrapper_idx}" |
417 | | - if is_graph_input(tensor_source_node, self.edge_program): |
418 | | - tensor_name = ( |
419 | | - "input_" |
420 | | - + str(self.external_ids[tensor_source_node]) |
421 | | - + "_" |
422 | | - + tensor_name |
423 | | - ) |
424 | | - if is_graph_output(tensor_source_node): |
425 | | - tensor_name = "output_" + tensor_name |
| 447 | + tensor_name = self.get_tensor_name(tensor_source_node, wrapper_idx) |
426 | 448 | dims = torch.Size([1]) if len(tensor.size()) == 0 else tensor.size() |
427 | 449 | dynamic_dims, nominal_dims = self.get_dynamic_dimension(dims) |
428 | 450 | tensor_type = self.get_tensor_type(tensor_source_node, tensor_type) |
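
The comments in `get_tensor_name` describe how the runtime consumes these prefixes; the runtime-side parsing is not part of this diff. Purely as a hypothetical illustration, the metadata encoded in a name could be recovered with a pattern like:

    import re

    # input_{id}[_mutbuf_{id}]_{node_name}_{wrapper_idx}
    # output[_mutbuf_{id}]_{node_name}_{wrapper_idx}
    _NAME_RE = re.compile(
        r"^(?:input_(?P<input_id>\d+)|output)(?:_mutbuf_(?P<mutbuf_id>\d+))?_"
    )

    def split_tensor_name(name: str):
        m = _NAME_RE.match(name)
        if m is None:
            return None  # intermediate tensor: no I/O prefix attached
        return m.group("input_id"), m.group("mutbuf_id"), name[m.end():]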