Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.hessian import HessianScoresRequest
from model_compression_toolkit.core.pytorch.utils import is_tuple_of_tensors
from model_compression_toolkit.logger import Logger


Expand Down Expand Up @@ -85,6 +86,9 @@ def unfold_tensors_list(tensors_to_unfold: Any) -> List[Any]:
"""
unfold_tensors = []
for tensor in tensors_to_unfold:
if is_tuple_of_tensors(tensor):
tensor = list(tensor) # converts named tuple to list

if isinstance(tensor, List):
unfold_tensors += tensor
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ def substitute(self, graph: Graph, attention_node: FunctionalNode) -> Graph:
:param attention_node: the node to replace
:return: A graph after the substitution
"""
print("In scale_dot_product_attention substitution@@@@@@@@")
input_nodes = self._get_attention_input_nodes(graph, attention_node)
q_node, k_node, v_node = input_nodes["q"], input_nodes["k"], input_nodes["v"]
transpose_k_node = self._get_transpose_k_node(attention_node.name, k_node)
Expand Down
35 changes: 28 additions & 7 deletions model_compression_toolkit/core/pytorch/reader/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,40 @@
# limitations under the License.
# ==============================================================================


import logging
from typing import Callable, Dict

import numpy as np
import torch
from torch.fx import symbolic_trace
import logging
from typing import Callable, Dict, Union, Any
from torch.fx.passes.shape_prop import ShapeProp
from torch.fx import Tracer, GraphModule, symbolic_trace

from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.pytorch.reader.graph_builders import edges_builder, nodes_builder
from model_compression_toolkit.core.pytorch.utils import set_model
from sony_custom_layers.pytorch import CustomLayer


def _trace_model(root: Union[torch.nn.Module, Callable[..., Any]]) -> GraphModule:
    """
    Symbolically trace ``root`` and return the resulting ``GraphModule``.

    This replaces ``torch.fx.symbolic_trace`` so that custom layers
    (instances of ``CustomLayer``) are treated as graph leaves instead of
    being traced into.

    :param root: Module or function to be traced and converted into a Graph representation.
    :return: GraphModule: a Module created from the recorded operations from ``root``.
    """

    class MCTTracer(Tracer):
        # Treat any CustomLayer instance as atomic: do not trace into its forward.
        def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
            return isinstance(m, CustomLayer) or super().is_leaf_module(m, module_qualified_name)

    tracer = MCTTracer()
    traced_graph = tracer.trace(root)
    # ``root`` may be an nn.Module (use its class name) or a plain function (use its name).
    if isinstance(root, torch.nn.Module):
        graph_module_name = root.__class__.__name__
    else:
        graph_module_name = root.__name__
    return GraphModule(tracer.root, traced_graph, graph_module_name)


def generate_module_dict(model: torch.nn.Module) -> Dict:
Expand Down Expand Up @@ -87,7 +108,7 @@ def fx_graph_module_generation(pytorch_model: torch.nn.Module,
set_model(pytorch_model)

try:
symbolic_traced = symbolic_trace(pytorch_model)
symbolic_traced = _trace_model(pytorch_model)
except torch.fx.proxy.TraceError as e:
Logger.critical(f'Error parsing model with torch.fx\n'
f'fx error: {e}')
Expand Down
17 changes: 15 additions & 2 deletions model_compression_toolkit/core/pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch
from torch import Tensor
import numpy as np
from typing import Union, Sequence, Optional, List, Tuple
from typing import Union, Optional, List, Tuple, Any

from model_compression_toolkit.core.pytorch.constants import MAX_FLOAT16, MIN_FLOAT16
from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
Expand Down Expand Up @@ -112,4 +112,17 @@ def clip_inf_values_float16(tensor: Tensor) -> Tensor:
# Replace inf values with max float16 value
tensor[inf_mask] = MAX_FLOAT16 * torch.sign(tensor[inf_mask])

return tensor
return tensor


def is_tuple_of_tensors(obj: Any) -> bool:
    """
    Check whether ``obj`` is a tuple whose elements are all torch tensors.

    Note that an empty tuple yields True, since no element violates the condition.

    :param obj: Object to check its type
    :return: True if obj is a tuple of tensors, False otherwise
    """
    return isinstance(obj, tuple) and all(isinstance(elem, torch.Tensor) for elem in obj)
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework import LayerFilterParams
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2fw import \
AttachTpcToFramework
from model_compression_toolkit.verify_packages import FOUND_SONY_CUSTOM_LAYERS

if FOUND_SONY_CUSTOM_LAYERS:
from sony_custom_layers.keras.object_detection.ssd_post_process import SSDPostProcess
from sony_custom_layers.keras.object_detection.ssd_post_process import SSDPostProcess

if version.parse(tf.__version__) >= version.parse("2.13"):
from keras.src.layers import Conv2D, DepthwiseConv2D, Dense, Reshape, ZeroPadding2D, Dropout, \
Expand Down Expand Up @@ -93,6 +91,7 @@ def __init__(self):
OperatorSetNames.TOPK: [tf.nn.top_k],
OperatorSetNames.FAKE_QUANT: [tf.quantization.fake_quant_with_min_max_vars],
OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [tf.image.combined_non_max_suppression],
OperatorSetNames.BOX_DECODE: [], # no such operator in keras
OperatorSetNames.ZERO_PADDING2D: [ZeroPadding2D],
OperatorSetNames.CAST: [tf.cast],
OperatorSetNames.STRIDED_SLICE: [tf.strided_slice],
Expand All @@ -102,15 +101,9 @@ def __init__(self):
OperatorSetNames.LOG_SOFTMAX: [tf.nn.log_softmax],
OperatorSetNames.ADD_BIAS: [tf.nn.bias_add],
OperatorSetNames.L2NORM: [tf.math.l2_normalize],
OperatorSetNames.SSD_POST_PROCESS: [SSDPostProcess]
}

if FOUND_SONY_CUSTOM_LAYERS:
self._opset2layer[OperatorSetNames.SSD_POST_PROCESS] = [SSDPostProcess]
else:
# If Custom layers is not installed then we don't want the user to fail, but just ignore custom layers
# in the initialized framework TPC
self._opset2layer[OperatorSetNames.SSD_POST_PROCESS] = []

self._opset2attr_mapping = {
OperatorSetNames.CONV: {
KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attach2fw import \
AttachTpcToFramework
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.attribute_filter import Eq
from sony_custom_layers.pytorch import MulticlassNMS, MulticlassNMSWithIndices, FasterRCNNBoxDecode


class AttachTpcToPytorch(AttachTpcToFramework):
Expand Down Expand Up @@ -97,7 +98,8 @@ def __init__(self):
OperatorSetNames.L2NORM: [LayerFilterParams(torch.nn.functional.normalize,
Eq('p', 2) | Eq('p', None))],
OperatorSetNames.SSD_POST_PROCESS: [], # no such operator in pytorch
OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [] # no such operator in pytorch
OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [MulticlassNMS, MulticlassNMSWithIndices],
OperatorSetNames.BOX_DECODE: [FasterRCNNBoxDecode]
}

pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),
Expand Down
1 change: 0 additions & 1 deletion model_compression_toolkit/verify_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,3 @@
FOUND_TORCHVISION = importlib.util.find_spec("torchvision") is not None
FOUND_ONNX = importlib.util.find_spec("onnx") is not None
FOUND_ONNXRUNTIME = importlib.util.find_spec("onnxruntime") is not None
FOUND_SONY_CUSTOM_LAYERS = importlib.util.find_spec('sony_custom_layers') is not None
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ matplotlib<3.10.0
scipy
protobuf
mct-quantizers==1.5.2
pydantic<2.0
pydantic<2.0
sony-custom-layers==0.4.0
Loading