Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
736006c
build new package of graph_builder
Jun 18, 2025
e7c1f14
rename refine functions and remove transforms configs
Jun 23, 2025
777e99a
create base graph builder to wrap the conversion logic
Jun 24, 2025
0a03d63
use graph builder in ru data facade
Jun 24, 2025
631d3fe
add relu2pot flag in transform functions and use builder in pytorch p…
Jun 24, 2025
c2b6037
add fqc to graph and basic substitutions in transform_torch_graph
Jun 24, 2025
61cebd9
set prior info in nodes of torch graph
Jun 24, 2025
75d5d74
move torch collapsing and bn folding substitutions to new transform_g…
Jun 25, 2025
178eeb9
use graph builder in torch gptq facade
Jun 25, 2025
72c60b4
update relu_bound flag in keras builder
Jun 25, 2025
8a31f5c
pass the builder fw class to the core runner instead of doing it outs…
Jun 25, 2025
227f47f
use set_fqc in graph builder instead of directly fqc
Jun 25, 2025
b26646e
use graph builder in bn_info_collection_test
Jun 25, 2025
8114d98
handle tb writer in base graph builder
Jun 25, 2025
5922898
fix xquant usage of graph_prepare
Jun 25, 2025
fa06fc4
remove todo
Jun 25, 2025
d2ef1f6
change visibility of graph builder methods since they are used in pruning
Jun 25, 2025
b9fa4d3
use float graph builder in pruning torch facade
Jun 25, 2025
11382f3
use float graph builder in pruning keras facade
Jun 25, 2025
6e6ca1c
replace the usage of read model function with graph builder function
Jun 25, 2025
943a844
move torch reader to new package
Jun 25, 2025
5a92331
move torch reader into new module
Jun 30, 2025
f59e6b8
adjust torch pytests
Jun 30, 2025
bc754cf
remove read_model function from graph_prep_runner
Jun 30, 2025
c9ab059
fix keras tests
Jul 1, 2025
ca043c0
adjust keras pytests
Jul 1, 2025
9c95d4e
move import of keras graph builder in ru data facade after tf check
Jul 1, 2025
6c4175c
move import of torch graph builder in ru data facade after torch check
Jul 1, 2025
87a1165
remove imports from graph builder init
Jul 1, 2025
aa24e0a
fix import convert_pytorch_model_to_graph
Jul 1, 2025
17adafe
fix import convert_keras_model_to_graph
Jul 1, 2025
9275ff5
fix imports for envs where torch and tf are not installed
Jul 1, 2025
8d95ceb
remove old read_model_to_graph import in test_assert_to_operation.py
Jul 1, 2025
ca66e05
adapt test ru data after removing graph_prep function with get_finali…
Jul 1, 2025
24e2c3c
merge with main
Jul 1, 2025
bc5a3ac
fix functions tests after merge removal of fw graph builder
Jul 1, 2025
f85850e
fix torch test hessian service
Jul 1, 2025
fff3207
fix missing import in test_weights_activation_split_substitution
Jul 1, 2025
86c5221
improve documentation and add licenses
Jul 2, 2025
289dac1
Merge branch 'main' into graph-prepare-refactor
reuvenperetz Jul 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 0 additions & 79 deletions model_compression_toolkit/core/common/framework_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,22 +104,6 @@ def is_tuple_of_tensors(self, obj: Any) -> bool:
f'framework\'s is_tuple_of_tensors method.') # pragma: no cover


@abstractmethod
def model_reader(self,
model: Any,
representative_data_gen: Callable) -> Graph:
"""
Convert a framework's model into a graph.
Args:
model: Framework's model.
representative_data_gen (Callable): Dataset used for calibration.

Returns:
Graph representing the input model.
"""
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s model_reader method.') # pragma: no cover

@abstractmethod
def model_builder(self,
graph: Graph,
Expand Down Expand Up @@ -214,45 +198,6 @@ def get_substitutions_channel_equalization(self,
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_channel_equalization method.') # pragma: no cover

@abstractmethod
def get_substitutions_prepare_graph(self) -> List[common.BaseSubstitution]:
"""

Returns: A list of the framework substitutions used to prepare the graph.

"""
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_prepare_graph method.') # pragma: no cover

@abstractmethod
def get_substitutions_pre_statistics_collection(self, quant_config: QuantizationConfig) -> \
List[common.BaseSubstitution]:
"""

Args:
quant_config: Quantization configuration.

Returns: A list of the framework substitutions used before we collect statistics.

"""
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_pre_statistics_collection method.') # pragma: no cover

@abstractmethod
def get_linear_collapsing_substitution(self) -> common.BaseSubstitution:
"""
Returns: linear collapsing substitution
"""
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_linear_collapsing_substitution method.') # pragma: no cover

@abstractmethod
def get_op2d_add_const_collapsing_substitution(self) -> common.BaseSubstitution:
"""
Returns: conv2d add const collapsing substitution
"""
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_op2d_add_const_collapsing_substitution method.') # pragma: no cover

@abstractmethod
def get_substitutions_statistics_correction(self, quant_config: QuantizationConfig) -> \
Expand All @@ -269,14 +214,6 @@ def get_substitutions_statistics_correction(self, quant_config: QuantizationConf
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_substitutions_statistics_correction method.') # pragma: no cover

@abstractmethod
def get_residual_collapsing_substitution(self) -> List[common.BaseSubstitution]:
"""
Returns: A list of the framework substitutions used for residual collapsing
"""
raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_residual_collapsing_substitution method.') # pragma: no cover


@abstractmethod
def get_substitutions_post_statistics_collection(self, quant_config: QuantizationConfig) -> List[
Expand Down Expand Up @@ -319,22 +256,6 @@ def get_substitutions_after_second_moment_correction(self, quant_config: Quantiz
f'framework\'s get_substitutions_after_second_moment_correction '
f'method.') # pragma: no cover

def get_node_prior_info(self, node: BaseNode,
graph: Graph) -> NodePriorInfo:
"""
Get a NodePriorInfo object for a node.

Args:
node: Node to get its prior info.
graph: Graph to check the next node type.

Returns:
NodePriorInfo with information about the node.
"""

raise NotImplementedError(f'{self.__class__.__name__} has to implement the '
f'framework\'s get_node_prior_info method.') # pragma: no cover

def count_node_for_mixed_precision_interest_points(self, node: BaseNode) -> bool:
"""
Returns whether a given node is considered as a potential interest point for mp metric computation purposes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,18 @@
# limitations under the License.
# ==============================================================================
import copy
from typing import Callable, Any

from model_compression_toolkit.core.graph_prep_runner import get_finalized_graph

from model_compression_toolkit.core import ResourceUtilization, CoreConfig, QuantizationErrorMethod
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization_calculator import \
ResourceUtilizationCalculator, BitwidthMode, TargetInclusionCriterion
from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner
from model_compression_toolkit.target_platform_capabilities import FrameworkQuantizationCapabilities


def compute_resource_utilization_data(in_model: Any,
representative_data_gen: Callable,
def compute_resource_utilization_data(graph: Graph,
core_config: CoreConfig,
fqc: FrameworkQuantizationCapabilities,
fw_impl: FrameworkImplementation) -> ResourceUtilization:
Expand All @@ -33,8 +33,7 @@ def compute_resource_utilization_data(in_model: Any,
This can serve as a basis for defining target Resource Utilization for mixed precision search.

Args:
in_model: Model to build graph from (the model that is intended to be quantized).
representative_data_gen: Dataset used for calibration.
graph: Graph that represents the model to compute its Resource Utilization data.
core_config: CoreConfig containing parameters of how the model should be quantized.
fqc: FrameworkQuantizationCapabilities object that models the inference target platform and
the attached framework operator's information.
Expand All @@ -50,14 +49,13 @@ def compute_resource_utilization_data(in_model: Any,
if core_config.quantization_config.weights_error_method == QuantizationErrorMethod.HMSE:
core_config.quantization_config.weights_error_method = QuantizationErrorMethod.MSE

transformed_graph = graph_preparation_runner(in_model,
representative_data_gen=representative_data_gen,
quantization_config=core_config.quantization_config,
fw_impl=fw_impl,
fqc=fqc,
bit_width_config=core_config.bit_width_config,
mixed_precision_enable=False,
running_gptq=False)
graph = get_finalized_graph(graph,
fqc=fqc,
quant_config=core_config.quantization_config,
bit_width_config=core_config.bit_width_config,
fw_impl=fw_impl,
mixed_precision_enable=False,
running_gptq=False)

ru_calculator = ResourceUtilizationCalculator(transformed_graph, fw_impl)
ru_calculator = ResourceUtilizationCalculator(graph, fw_impl)
return ru_calculator.compute_resource_utilization(TargetInclusionCriterion.AnyQuantizedNonFused, BitwidthMode.QDefaultSP)
113 changes: 4 additions & 109 deletions model_compression_toolkit/core/graph_prep_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# ==============================================================================


from typing import Callable, Any

from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.graph.base_graph import Graph
from model_compression_toolkit.core.common.quantization.bit_width_config import BitWidthConfig
Expand All @@ -25,69 +23,13 @@
from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig
from model_compression_toolkit.core.common.quantization.set_node_quantization_config import set_manual_bitwidth_config
from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute
from model_compression_toolkit.core.common.substitutions.linear_collapsing_substitution import \
linear_collapsing_substitute
from model_compression_toolkit.core.common.visualization.tensorboard_writer import TensorboardWriter
from model_compression_toolkit.quantization_preparation.load_fqc import load_fqc_configuration
from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.framework_quantization_capabilities import \
FrameworkQuantizationCapabilities
from model_compression_toolkit.logger import Logger


def graph_preparation_runner(in_model: Any,
representative_data_gen: Callable,
quantization_config: QuantizationConfig,
fw_impl: FrameworkImplementation,
fqc: FrameworkQuantizationCapabilities,
bit_width_config: BitWidthConfig = None,
tb_w: TensorboardWriter = None,
mixed_precision_enable: bool = False,
running_gptq: bool = False) -> Graph:
"""
Runs all required preparations in order to build a quantization graph from the given model,
quantization configuration and target platform specifications.
This runner includes the following steps:
- Reading and building a graph from the given model.
- Setting quantization config to each relevant node in the graph.
- Apply all necessary substitutions to finalize the graph for quantization.

Args:
in_model (Any): Model to quantize.
representative_data_gen (Callable): Dataset used for calibration.
quantization_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be quantized.
fw_impl (FrameworkImplementation): FrameworkImplementation object with a specific framework methods implementation.
fqc (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities object that models the inference target platform and
the attached framework operator's information.
bit_width_config (BitWidthConfig): Config for bit-width selection. Defaults to None.
tb_w (TensorboardWriter): TensorboardWriter object for logging.
mixed_precision_enable (bool): is mixed precision enabled.
running_gptq (bool): Whether or not a GPTQ optimization is planned to run after the PTQ process.

Returns:
An internal graph representation of the input model.
"""

graph = read_model_to_graph(in_model,
representative_data_gen,
fqc,
fw_impl)

if tb_w is not None:
tb_w.add_graph(graph, 'initial_graph')

transformed_graph = get_finalized_graph(graph,
fqc,
quantization_config,
bit_width_config,
tb_w,
fw_impl,
mixed_precision_enable=mixed_precision_enable,
running_gptq=running_gptq)

return transformed_graph


def get_finalized_graph(initial_graph: Graph,
def get_finalized_graph(graph: Graph,
fqc: FrameworkQuantizationCapabilities,
quant_config: QuantizationConfig = DEFAULTCONFIG,
bit_width_config: BitWidthConfig = None,
Expand All @@ -100,7 +42,7 @@ def get_finalized_graph(initial_graph: Graph,
process. All future graph substitutions and operations that change the graph should be added to this method.

Args:
initial_graph (Graph): Graph to apply the changes to.
graph (Graph): Graph to apply the changes to.
fqc (FrameworkQuantizationCapabilities): FrameworkQuantizationCapabilities object that describes the desired inference target platform (includes fusing patterns MCT should handle).
quant_config (QuantizationConfig): QuantizationConfig containing parameters of how the model should be
quantized.
Expand All @@ -120,34 +62,9 @@ def get_finalized_graph(initial_graph: Graph,
"Note: This method may significantly increase runtime during the parameter search process.")

######################################
# Graph substitution (prepare graph)
# Add quantization configurations
######################################
graph = substitute(initial_graph, fw_impl.get_substitutions_prepare_graph())

if tb_w is not None:
tb_w.add_graph(graph, 'after_graph_preparation')

#########################################
# Set prior info to nodes
##########################################
for node in graph.nodes:
node.prior_info = fw_impl.get_node_prior_info(node=node,
graph=graph)

##################################################
# Graph substitution (pre statistics collection)
##################################################
transformed_graph = substitute(graph, fw_impl.get_substitutions_pre_statistics_collection(quant_config))
if quant_config.linear_collapsing:
transformed_graph = linear_collapsing_substitute(transformed_graph, fw_impl.get_linear_collapsing_substitution())
transformed_graph = linear_collapsing_substitute(transformed_graph, fw_impl.get_op2d_add_const_collapsing_substitution())
if quant_config.residual_collapsing:
transformed_graph = substitute(transformed_graph, fw_impl.get_residual_collapsing_substitution())

if tb_w is not None:
tb_w.add_graph(transformed_graph, 'pre_statistics_collection_substitutions')

transformed_graph = load_fqc_configuration(transformed_graph, fqc)
transformed_graph = load_fqc_configuration(graph, fqc)

# filter candidates per manual config
if bit_width_config:
Expand Down Expand Up @@ -188,25 +105,3 @@ def update(qc):
return transformed_graph


def read_model_to_graph(in_model: Any,
representative_data_gen: Callable,
fqc: FrameworkQuantizationCapabilities,
fw_impl: FrameworkImplementation = None) -> Graph:

"""
Read a model into a graph object.

Args:
in_model: Model to optimize and prepare for quantization.
representative_data_gen: Dataset used for calibration.
fqc: FrameworkQuantizationCapabilities object that models the inference target platform and
the attached framework operator's information.
fw_impl: FrameworkImplementation object with a specific framework methods implementation.

Returns:
Graph object that represents the model.
"""
graph = fw_impl.model_reader(in_model,
representative_data_gen)
graph.set_fqc(fqc)
return graph
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import tensorflow as tf
from keras.models import Model
from model_compression_toolkit.core.common.graph.base_graph import OutTensor
from packaging import version

from model_compression_toolkit.core.common.back2framework.base_model_builder import BaseModelBuilder
Expand All @@ -24,12 +25,10 @@

if version.parse(tf.__version__) >= version.parse("2.13"):
from keras import Input
from keras.src.layers.core import TFOpLambda
from keras.src.engine.base_layer import TensorFlowOpLayer, Layer
from keras.src.engine.base_layer import Layer
else:
from keras import Input # pragma: no cover
from keras.layers.core import TFOpLambda # pragma: no cover
from keras.engine.base_layer import TensorFlowOpLayer, Layer # pragma: no cover
from keras.engine.base_layer import Layer # pragma: no cover

from typing import Any, Dict, List, Tuple, Callable
from tensorflow.python.util.object_identity import Reference as TFReference
Expand All @@ -38,7 +37,6 @@
from model_compression_toolkit.core.common import BaseNode
from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
from model_compression_toolkit.core.keras.back2framework.instance_builder import OperationHandler
from model_compression_toolkit.core.keras.reader.connectivity_handler import OutTensor
from mct_quantizers import KerasQuantizationWrapper

# In tf2.3 fake quant node is implemented as TensorFlowOpLayer, while in tf2.4 as TFOpLambda.
Expand Down
Loading