
Graph prepare refactor#1490

Closed
reuvenperetz wants to merge 40 commits into SonySemiconductorSolutions:main from reuvenperetz:graph-prepare-refactor

Conversation

@reuvenperetz
Contributor

Pull Request Description:

Extract the graph-building parts of the framework model into an outer package called "graph_builder".
BaseGraphBuilder defines API methods for converting the model to a graph and for transforming it with the fundamental substitutions.
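The description above can be sketched roughly as follows; this is an illustrative outline only, and the method names (`convert_model_to_graph`, `transform_graph`, `build`) are assumptions, not the PR's actual signatures:

```python
from abc import ABC, abstractmethod
from typing import Any


class BaseGraphBuilder(ABC):
    """Sketch of the API described above (method names are assumptions)."""

    @abstractmethod
    def convert_model_to_graph(self, model: Any) -> Any:
        """Read the framework model into the internal Graph representation."""

    @abstractmethod
    def transform_graph(self, graph: Any) -> Any:
        """Apply the fundamental substitutions to the graph."""

    def build(self, model: Any) -> Any:
        # Convenience flow: convert the model, then transform the graph.
        return self.transform_graph(self.convert_model_to_graph(model))
```

Framework-specific builders (e.g. a Keras or PyTorch subclass) would then implement the two abstract methods.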

Checklist before requesting a review:

  • I set the appropriate labels on the pull request.
  • I have added/updated the release note draft (if necessary).
  • I have updated the documentation to reflect my changes (if necessary).
  • All functions and files are well documented.
  • All functions and classes have type hints.
  • There is a license in all files.
  • The function and variable names are informative.
  • I have checked for code duplications.
  • I have added new unit tests (if necessary).

reuvenp added 30 commits June 18, 2025 12:18
@@ -0,0 +1,14 @@
# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
Contributor

I think that graph_preparation is a better name for the package

if tensorboard_writer is not None:
    tensorboard_writer.add_graph(graph, 'initial_graph')

if fqc:
Contributor

I don't think that fqc should be an argument to this module, since it is quantization-related.
If we must still store it in the graph, then it is supposed to be the "quantization_preparation" responsibility.

Contributor Author

I started without it, but the relu_to_pot substitution requires it, which is why I added it eventually.
I'll do a different PR to remove this requirement first.

model: Any,
representative_dataset: Callable = None,
fqc: FrameworkQuantizationCapabilities = None,
tensorboard_writer: TensorboardWriter = None,
Contributor

Several things regarding the arguments:

  1. Move tensorboard_writer to be the last argument.
  2. Consider gathering the three "structure-altering" flags in a dedicated config?
  3. Do we need the default values to avoid "breaking" the API/tests? If not, remove the defaults.
  4. Change the type hint of arguments that accept None to Optional[].
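Points 1 and 4 together would give a signature along these lines; this is a sketch only, with the fqc and tensorboard_writer types stood in by `Any` since it lives outside the MCT codebase:

```python
from typing import Any, Callable, Optional

# Illustrative signature only: Optional[] marks every argument that accepts
# None, and tensorboard_writer has been moved to the last position.
def convert_model_to_graph(model: Any,
                           representative_dataset: Optional[Callable] = None,
                           fqc: Optional[Any] = None,
                           tensorboard_writer: Optional[Any] = None) -> Any:
    ...
```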

Contributor Author

About 2: That's how I started, but I eventually thought it would cause unnecessary maintenance. I can bring it back.

Contributor

So it depends on what the unnecessary maintenance is.
Normally, I wouldn't suggest having a config argument that is immediately flattened in the class init, but in this case they are all boolean substitution flags, so it feels right to tie them together in a "which substitutions to run" config variable.
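Such a "which substitutions to run" config could be a small dataclass like the one below. Only `linear_collapsing` appears in the diff excerpts here; the other two flag names are assumptions for illustration:

```python
from dataclasses import dataclass


@dataclass
class SubstitutionsConfig:
    """Hypothetical config gathering the structure-altering boolean flags."""
    linear_collapsing: bool = True            # appears in the diff
    residual_collapsing: bool = True          # assumed flag name
    relu_bound_to_power_of_2: bool = False    # assumed flag name
```

The graph builder would take one `SubstitutionsConfig` argument instead of three separate booleans, and could keep the config object instead of flattening it in `__init__`.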

import torch


def convert_pytorch_model_to_graph(model: torch.nn.Module,
Contributor

Why split it into a separate file rather than include the implementation as part of PytorchGraphBuilder's convert_model_to_graph method? (same for Keras)

from model_compression_toolkit.core.common.substitutions.apply_substitutions import substitute


def transform_pytorch_graph(graph: Graph,
Contributor

Same question: why not as part of the PyTorch builder class?

target_resource_utilization=target_resource_utilization,
tb_w=tb_w)
tb_w=tb_w,
fw_graph_builder=KerasGraphBuilder())
Contributor

Please add a TODO either here (in each facade) or in the core_runner noting that this should eventually be initialized and run outside the runner.

Contributor Author

I was trying to avoid duplicating code. Do you see any benefit in moving it outside rather than passing the framework class builder?

Contributor

Eventually, the runner is going to die. The point of separating each logical block from the core into a new module is exactly this - to not have one backbone that you have to run in order to get through anything in MCT.
So the endgame is to have each facade call the 4-5 modules that it needs: construct the graph, attach quantization info, run some optimization, and run any enhancement that is part of the facade's algorithmic framework.
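The end state described above could look roughly like this; every function name below is an assumption, stubbed purely to show the shape of a facade that composes modules directly instead of going through one core runner:

```python
# Hypothetical facade flow: each stage is an independent module call.
def build_graph(model):
    return {"model": model, "stages": ["graph"]}

def attach_quantization_info(graph):
    graph["stages"].append("quantization")
    return graph

def optimize(graph):
    graph["stages"].append("optimization")
    return graph

def ptq_facade(model):
    # construct graph -> put quantization info -> run optimization
    return optimize(attach_quantization_info(build_graph(model)))
```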

Contributor

To summarize my long answer - yes, I see benefits in moving it outside of the runner.

from model_compression_toolkit.graph_builder.keras.convert_keras_model_to_graph import convert_keras_model_to_graph


class TestGraphReading(unittest.TestCase):
Contributor

As part of the PR where you implement tests for the new module (GraphBuilder), also migrate/remove this test and any other test that is meant to verify graph building and transformation.

representative_dataset: Callable = None,
fqc: FrameworkQuantizationCapabilities = None,
tensorboard_writer: TensorboardWriter = None,
linear_collapsing: bool = True,
Contributor

What about the substitutions in get_substitutions_post_statistics_collection? Aren't they graph-preparation related? If so, we need to think about whether they belong in this module or somewhere else.

Contributor Author

They depend on the quantization params (this is why today they are used only after loading the fqc and computing the quantization params), so I do not think they should be part of it...

Contributor
@ofirgo Jul 3, 2025

Let's talk about this offline, I want us to take another look at this and figure out where they should land.

From a quick check this is what I think can work:

  1. softmax shift - not quantization related, it just needs to happen after statistics collection, so maybe put it in this module but not as part of the "graph_builder" flow (as a separate API that can be called).
  2. concat threshold - this needs to happen after parameter selection, so it is like any other "quantization enhancement" substitution that we have (bc, snc...).
  3. input scaling - first, it is not available in PyTorch; maybe we can erase it when limiting Keras support. Even if not, it is like the concat threshold - a quantization enhancement.
