-
Notifications
You must be signed in to change notification settings - Fork 3
[ingress][torch-mlir][RFC] Initial version of fx-importer script using torch-mlir #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
dchigarev
wants to merge
4
commits into
llvm:main
Choose a base branch
from
dchigarev:dchigarev/fx_importer
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
__pycache__ | ||
ingress/Torch-MLIR/examples/**/dumps/*.mlir |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
Using scripts in this directory one can convert a Torch Model to a MLIR module. | ||
|
||
The conversion script is written in python and is basically a wrapper around [`torch-mlir` library](https://github.com/llvm/torch-mlir). One need to setup a python virtual environment with torch and torch-mlir libraries | ||
(`./scripts/install-virtualenv.sh`) to use the script. | ||
|
||
In order to convert a model the script has to recieve: | ||
1. An instance of `torch.nn.Model` with proper state (weights). | ||
2. Sample input arguments to the model (e.g. empty tensor with proper shape and dtype). | ||
|
||
There are two options of how this info can be provided to the converter: | ||
|
||
### 1. Instantiate a model in your own script and use a function from the `py_src/export_lib` (recomended) | ||
|
||
In this scenario a user is responsible for instantiating a model with proper state in their | ||
own python script. Then they should import a `generate_mlir` function from `py_src.export_lib` | ||
and call it in order to get a MLIR module: | ||
|
||
```python | ||
model : nn.Model = get_model() | ||
sample_args = (get_sample_tensor(),) | ||
|
||
# PYTHONPATH=$(pwd)/py_src/ | ||
from export_lib import generate_mlir | ||
|
||
mlir_module = generate_mlir(model, sample_args, dialect="linalg") | ||
print(mlir_module) | ||
``` | ||
|
||
### 2. Use `py_src/main.py` or `scripts/generate-mlir.sh` and pass Torch Model parameters via CLI | ||
|
||
In this scenario the `py_src/main.py` script is fully responsible for instantiating a torch model | ||
and converting it to MLIR. User has to pass a proper python entrypoint for model's factory, | ||
its parameters if needed (`--model-args & --model-kwargs`), and sample model arguments (either | ||
as `--sample-shapes` or as an entrypoint to a function returning args and kwargs `--sample-fn`). | ||
|
||
``` | ||
# note that 'my_module' has to be in $PYTHONPATH | ||
python py_src/main.py --model-entrypoint my_module:my_factory \ | ||
--module-state-path path/to/state.pth \ | ||
--sample-shapes '1,2,324,float32' \ | ||
--out-mlir res.mlir | ||
|
||
# note that 'my_module' has to be in $PYTHONPATH | ||
./scripts/generate-mlir.sh --model-entrypoint torchvision.models:resnet18 \ | ||
--sample-fn my_module:generate_resnet18_sample_args \ | ||
--out-mlir res.mlir | ||
``` | ||
|
||
Look into `examples/` folder for more info. | ||
Binary file not shown.
23 changes: 23 additions & 0 deletions
23
ingress/Torch-MLIR/examples/dummy_mlp_cli/dummy_mlp_factory.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
import os | ||
|
||
class DummyMLP(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.net = nn.Sequential( | ||
nn.Linear(10, 32), | ||
nn.ReLU(), | ||
nn.Linear(32, 2) | ||
) | ||
|
||
def forward(self, x): | ||
return self.net(x) | ||
|
||
def make_dummy_mlp(): | ||
return DummyMLP() | ||
|
||
if __name__ == "__main__": | ||
script_dir = os.path.dirname(os.path.abspath(__file__)) | ||
torch.save(make_dummy_mlp().state_dict(), os.path.join(script_dir, "dummy_mlp.pth")) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
#!/usr/bin/env bash | ||
|
||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" | ||
ROOT_DIR=$SCRIPT_DIR/../../scripts/ | ||
|
||
PYTHONPATH=$PYTHONPATH:$SCRIPT_DIR $ROOT_DIR/generate-mlir.sh --model-entrypoint dummy_mlp_factory:make_dummy_mlp \ | ||
--model-state-path $SCRIPT_DIR/dummy_mlp.pth \ | ||
--sample-shapes "1,10,float32" \ | ||
--dialect linalg \ | ||
--out-mlir $SCRIPT_DIR/dummy_mlp_sh.mlir |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
#!/usr/bin/env bash | ||
|
||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" | ||
ROOT_DIR=$SCRIPT_DIR/../../py_src/ | ||
|
||
PYTHONPATH=$PYTHONPATH:$ROOT_DIR:$SCRIPT_DIR python $ROOT_DIR/main.py --model-entrypoint dummy_mlp_factory:make_dummy_mlp \ | ||
--model-state-path $SCRIPT_DIR/dummy_mlp.pth \ | ||
--sample-shapes "1,10,float32" \ | ||
--dialect linalg \ | ||
--out-mlir $SCRIPT_DIR/dummy_mlp.mlir |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
from export_lib.export import generate_mlir | ||
|
||
class DummyMLP(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.net = nn.Sequential( | ||
nn.Linear(10, 32), | ||
nn.ReLU(), | ||
nn.Linear(32, 2) | ||
) | ||
|
||
def forward(self, x): | ||
return self.net(x) | ||
|
||
def main(): | ||
model = DummyMLP() | ||
dummy_input = torch.randn(1, 10) | ||
mlir_mod = generate_mlir(model, (dummy_input,), {}) | ||
print(mlir_mod) | ||
|
||
if __name__ == "__main__": | ||
main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" | ||
ROOT_DIR=$SCRIPT_DIR/../../py_src/ | ||
|
||
PYTHONPATH=$ROOT_DIR python $SCRIPT_DIR/export.py |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
#!/usr/bin/env bash | ||
|
||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" | ||
ROOT_DIR=$SCRIPT_DIR/../../scripts/ | ||
|
||
$ROOT_DIR/generate-mlir.sh --model-entrypoint torchvision.models:resnet18 \ | ||
--sample-shapes "1,3,224,224,float32" \ | ||
--dialect linalg \ | ||
--out-mlir $SCRIPT_DIR/resnet_18_sh.mlir |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
#!/usr/bin/env bash | ||
|
||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" | ||
ROOT_DIR=$SCRIPT_DIR/../../py_src/ | ||
|
||
python $ROOT_DIR/main.py --model-entrypoint torchvision.models:resnet18 \ | ||
--sample-shapes "1,3,224,224,float32" \ | ||
--dialect linalg \ | ||
--out-mlir $SCRIPT_DIR/resnet_18.mlir |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Examples package initialization | ||
""" | ||
Example scripts for Torch-MLIR usage. | ||
""" | ||
|
||
from .export import load_torch_model, generate_sample_args, generate_mlir | ||
|
||
__all__ = ['load_torch_model', 'generate_sample_args', 'generate_mlir'] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import os | ||
import torch | ||
from torch_mlir import fx | ||
from torch_mlir.fx import OutputType | ||
|
||
from .utils import parse_shape_str, load_callable_symbol, generate_fake_tensor | ||
|
||
def load_torch_model(entrypoint_path, model_state_path=None, *args, **kwargs): | ||
entrypoint = load_callable_symbol(entrypoint_path) | ||
model = entrypoint(*args, **kwargs) | ||
if model_state_path is not None: | ||
state_dict = load_model_state(model_state_path) | ||
model.load_state_dict(state_dict) | ||
return model | ||
|
||
# Function to load the Torch model | ||
def load_model_state(model_path): | ||
if not os.path.exists(model_path): | ||
raise FileNotFoundError(f"Model file {model_path} does not exist.") | ||
|
||
model = torch.load(model_path) | ||
return model | ||
|
||
|
||
def generate_sample_args(shape_str, sample_fn_path) -> tuple[tuple, dict]: | ||
""" | ||
Generate sample arguments for the model's 'forward' method. | ||
(Required by torch_mlir.fx.export_and_import) | ||
""" | ||
if sample_fn_path is None: | ||
shape, dtype = parse_shape_str(shape_str) | ||
return (generate_fake_tensor(shape, dtype),), {} | ||
|
||
return load_callable_symbol(sample_fn_path)() | ||
|
||
|
||
def generate_mlir(model, sample_args, sample_kwargs=None, dialect="linalg"): | ||
# Convert the Torch model to MLIR | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we use docstrings consistently throughout this project? |
||
output_type = None | ||
if dialect == "torch": | ||
output_type = OutputType.TORCH | ||
elif dialect == "linalg": | ||
output_type = OutputType.LINALG_ON_TENSORS | ||
elif dialect == "stablehlo": | ||
output_type = OutputType.STABLEHLO | ||
elif dialect == "tosa": | ||
output_type = OutputType.TOSA | ||
else: | ||
raise ValueError(f"Unsupported dialect: {dialect}") | ||
|
||
if sample_kwargs is None: | ||
sample_kwargs = {} | ||
|
||
model.eval() | ||
module = fx.export_and_import( | ||
model, *sample_args, output_type=output_type, **sample_kwargs | ||
) | ||
return module |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add links to an example implementing #1 and #2.