|
16 | 16 | get_buffer, |
17 | 17 | get_lifted_tensor_constant, |
18 | 18 | get_param, |
19 | | - is_buffer, |
20 | 19 | is_lifted_tensor_constant, |
21 | 20 | is_param, |
22 | 21 | ) |
@@ -78,22 +77,29 @@ def get_data( |
78 | 77 | return None |
79 | 78 |
|
80 | 79 |
|
| 80 | +def is_constant_buffer(program: "ExportedProgram", node: torch.fx.Node) -> bool: |
| 81 | + """Checks if the given node is a constant buffer.""" |
| 82 | + |
| 83 | + if node.target not in program.graph_signature.inputs_to_buffers: |
| 84 | + return False |
| 85 | + fqn = program.graph_signature.inputs_to_buffers[node.target] |
| 86 | + # if the buffer is mutated then record that |
| 87 | + return fqn not in program.graph_signature.buffers_to_mutate.values() |
| 88 | + |
| 89 | + |
81 | 90 | def get_constant_placeholder_dict( |
82 | 91 | exported_program: ExportedProgram, |
83 | 92 | ) -> OrderedDict[torch.fx.Node, torch.Tensor]: |
84 | 93 | """ |
85 | 94 | Returns a dictionary of placeholder node -> constant tensor. |
86 | 95 | """ |
87 | 96 | const_node_to_tensor: OrderedDict[torch.fx.Node, torch.Tensor] = OrderedDict() |
88 | | - for node in exported_program.graph.nodes: |
89 | | - if node.op != "placeholder": |
90 | | - continue |
91 | | - |
| 97 | + for node in exported_program.graph.find_nodes(op="placeholder"): |
92 | 98 | if is_param(exported_program, node): |
93 | 99 | const_node_to_tensor[node] = cast( |
94 | 100 | torch.Tensor, get_param(exported_program, node) |
95 | 101 | ) |
96 | | - elif is_buffer(exported_program, node): |
| 102 | + elif is_constant_buffer(exported_program, node): |
97 | 103 | const_node_to_tensor[node] = cast( |
98 | 104 | torch.Tensor, get_buffer(exported_program, node) |
99 | 105 | ) |
|
0 commit comments