|
9 | 9 | # See the License for the specific language governing permissions and |
10 | 10 | # limitations under the License. |
11 | 11 |
|
12 | | -from typing import Callable, Dict, List, Optional, Tuple, Union |
| 12 | +from typing import Callable, Dict, List, Optional, Union |
13 | 13 |
|
14 | 14 | import numpy as np |
15 | 15 | import onnx |
@@ -254,30 +254,6 @@ def get_bias_tensor_port_id(self, node: onnx.NodeProto) -> int: |
254 | 254 | return weight_definitions.bias_port_id |
255 | 255 | raise RuntimeError(f"The node {node} does not have bias_port_id attribute") |
256 | 256 |
|
257 | | - def _get_weight_tensor_with_reshape(self, node: onnx.NodeProto) -> Tuple[str, np.ndarray]: |
258 | | - """ |
259 | | - Returns node's weight tensor name and its value in the case when reshape node is placed after the weight. |
260 | | - The returned weight tensor will be reshaped according to a shape attribute of the reshape node. |
261 | | -
|
262 | | - :param node: Reshape node, whose input is weight tensor. |
263 | | - :return: The weight tensor name and its value with applied the reshape operation. |
264 | | - """ |
265 | | - tensor_name = node.output[0] |
266 | | - shape = self.get_initializers_value(node.input[1]) |
267 | | - tensor_value = self.get_initializers_value(node.input[0]) |
268 | | - reshaped_tensor_value = tensor_value.reshape(shape) |
269 | | - return tensor_name, reshaped_tensor_value |
270 | | - |
271 | | - def _get_tensor_from_zero_input(self, node: onnx.NodeProto) -> Tuple[str, np.ndarray]: |
272 | | - """ |
273 | | - Returns the weight tensor name and its value, which is located on the 0-index input port of the node. |
274 | | -
|
275 | | - :param node: Node, which takes on the 0-index input port id the weight tensor. |
276 | | - :return: The weight tensor name and its value. |
277 | | - """ |
278 | | - tensor_name = self.get_initializer(node.input[0]).name |
279 | | - return tensor_name, self.get_initializers_value(tensor_name) |
280 | | - |
281 | 257 | def get_node_index(self, node_name: str) -> int: |
282 | 258 | """ |
283 | 259 | Returns the node index in the model. |
|
0 commit comments