
Conversation


@dchigarev dchigarev commented Sep 17, 2025

This PR modifies the torch-mlir/generate-mlir.py -> torch-mlir/py_src/main.py script to make it actually work. The script is now capable of converting a PyTorch nn.Module to an MLIR module.

How PyTorch models are usually distributed

Usually, when someone wants to download a pretrained model, they end up with two files:

  1. Weights (the model's state dictionary, a .pth or .pt file)
  2. A Python script/library that contains the model's "architecture" (the definition of the model class)

The model's usage in a user's script would then look like this:

import torch
from some_models_lib import custom_model

model = custom_model()
state_dict = torch.load("path/to/state_dict.pth")
model.load_state_dict(state_dict)

model.eval()
with torch.no_grad():
    res = model(...)

In order to "load" a PyTorch model into the export script, we need to accept its state and a Python entrypoint that instantiates the model class. Example:

--model-entrypoint torchvision.models:resnet
--model-entrypoint my_module:create_model
//
--model-state-path path/to/a/file.pth
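Resolving such a `module:function` entrypoint string could look roughly like the sketch below. This is a hypothetical reimplementation of the script's `load_callable_symbol` helper, not its actual code; note that file-path entrypoints such as `/path/to/model.py:build_model` would additionally need `importlib.util.spec_from_file_location` handling.

```python
import importlib


def load_callable_symbol(entrypoint: str):
    """Resolve a 'module.path:attr' string into a callable (sketch)."""
    module_name, sep, attr = entrypoint.partition(":")
    if not sep or not attr:
        raise ValueError(f"expected 'module:callable', got {entrypoint!r}")
    # Import the module, then look up the named attribute on it.
    obj = getattr(importlib.import_module(module_name), attr)
    if not callable(obj):
        raise TypeError(f"{entrypoint} does not resolve to a callable")
    return obj
```

With this in place, `load_callable_symbol("torchvision.models:resnet18")()` would instantiate the model, assuming torchvision is installed.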

There's no trivial way to automatically deduce which arguments (tensor shapes) a model expects, so torch-mlir (and hence our scripts) requires a "sample argument", which is basically an empty tensor (or tensors) with the proper shape and dtype. Example:

--model-entrypoint my_module:create_model
--model-state-path path/to/a/file.pth
--sample-shapes 1,3,224,224,float32
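Parsing such a shape spec could be sketched as follows. The helper name is hypothetical; the actual script may parse the value differently:

```python
def parse_sample_shape(spec: str):
    """Split a value such as '1,3,224,224,float32' into (shape, dtype).

    Hypothetical sketch; not the script's actual parser.
    """
    *dims, dtype = spec.split(",")
    if not dims:
        raise ValueError(f"expected at least one dimension in {spec!r}")
    return tuple(int(d) for d in dims), dtype


# The result could then be turned into a sample tensor, e.g.:
#   shape, dtype = parse_sample_shape("1,3,224,224,float32")
#   sample = torch.empty(*shape, dtype=getattr(torch, dtype))
```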

The script may also take positional and keyword arguments to pass to the model's instantiation function, as well as a custom function to generate "sample" model arguments.

The full list of the script's arguments is the following:

  • --model-entrypoint (required) Path to the model entrypoint, e.g. 'torchvision.models:resnet18' or '/path/to/model.py:build_model'
  • --model-state-path (optional, since an entrypoint function may already set up the proper state) Path to a state file of the Torch model (usually with a .pt or .pth extension).
  • --model-args (default: "[]") (optional) Positional arguments to pass to the model's entrypoint
  • --model-kwargs (default: "{}") (optional) Keyword arguments to pass to the model's entrypoint
  • --sample-shapes (optional) Tensor shapes/dtype that the 'forward' method of the model will be called with. Must be specified if '--sample-fn' is not given.
  • --sample-fn (optional) Path to a function that generates sample arguments for the model's 'forward' method.
  • --dialect {"torch", "linalg", "stablehlo", "tosa"}
  • --out-mlir (optional) Path to save the generated MLIR module

The need to accept a lot of Python objects (entrypoint, args, kwargs) as string arguments makes the script fragile and error-prone. There seems to be no other way, though, if we want the export script to be responsible for "loading" and exporting models.
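One way to reduce the risk of evaluating these strings (the current code passes them through `eval`, as noted in the review below) is `ast.literal_eval`, which only accepts Python literals. A sketch, with `parse_literal` being a hypothetical name rather than the script's API:

```python
import ast


def parse_literal(text: str, expected_type: type):
    """Parse a CLI string into a Python object without running arbitrary code.

    ast.literal_eval only accepts literals (lists, dicts, numbers, strings,
    ...), so something like "__import__('os').system(...)" is rejected
    instead of executed.
    """
    try:
        value = ast.literal_eval(text)
    except (ValueError, SyntaxError) as exc:
        raise ValueError(f"not a Python literal: {text!r}") from exc
    if not isinstance(value, expected_type):
        raise ValueError(
            f"expected {expected_type.__name__}, got {type(value).__name__}"
        )
    return value
```

The trade-off is that only literal args/kwargs can be expressed on the command line; anything needing live objects would have to go through the scripting interface instead.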

Is there an alternative?

IREE and Blade went a different way and make the user responsible for instantiating a model in their own script. The user then calls an export function (provided by the iree or blade library) on their instantiated model.

An example of a user's script converting a Torch model to MLIR using IREE (adapted from IREE's docs):

# iree
import iree.turbine.aot as aot
import torch

# Define the `nn.Module` to export.
import torchvision.models as models

model = models.resnet18().eval()

# Export the program using the simple API.
example_arg = torch.randn(1, 3, 224, 224)
export_output = aot.export(model, example_arg)

if save_mlir:
    export_output.save_mlir(path)
else:
    export_output.compile()  # compile to a binary

Lighthouse's export script can also be used this way. A user simply needs to import the generate_mlir function and pass their model to it:

# PYTHONPATH=/home/user/lighthouse/ingress/Torch-MLIR/py_src:$PYTHONPATH
from export_lib import generate_mlir

model = create_model()
sample_arg = create_sample_arg()

mlir_module = generate_mlir(model, sample_arg)

if save_mlir:
    with open(path, "w") as file:
        file.write(str(mlir_module))
else:
    # TBD
    lighthouse_pipeline(mlir_module)

@@ -1,28 +1,54 @@
#!/usr/bin/env bash
Author
example:

./generate-mlir.sh -m torchvision.models:resnet18 -S 1,3,224,224,float32 -o now.mlir

Contributor

@rolfmorel rolfmorel Sep 18, 2025


Could we include this in the repo somewhere. While the repo doesn't really have tests yet, and certainly no CI, having a working example of the cmdline interface to this is helpful (or even necessary).

Contributor

It would be great to also have an in-repo example of how to invoke the conversion from inside a user script.

Author

Added examples/ folder with several use cases

Contributor

@rolfmorel rolfmorel left a comment

Thanks @dchigarev for making progress on PyTorch ingress!

In general it all seems to make sense to me! My comments are on relatively small matters.

Having said that, I would say the majority of the PR is on enabling the cmdline interface, which I expect to also be the most contentious. Personally, I am not a fan of such interfaces and prefer the scripting approach. If other people are in favour though, I am not opposed to the code being included.

Do you happen to have examples of similar cmdline interfaces being used for enabling PyTorch lowerings in other projects?



entrypoint = load_callable_symbol(args.model_entrypoint)

model = entrypoint(*eval(args.model_args), **eval(args.model_kwargs))
Contributor

Yikes - that's quite some eval.

I guess if we are to have a cmdline interface, there's not much to be done about it.

@@ -1,28 +1,54 @@
#!/usr/bin/env bash

- echo "Generating MLIR for model '$MODEL' with dialect '$DIALECT'..."
- python $SCRIPT_DIR/generate-mlir.py --model "$MODEL" --dialect "$DIALECT"
+ echo "Generating MLIR for model entrypoint '$MODEL' with dialect '$DIALECT'..."
+ python "$SCRIPT_DIR/generate-mlir.py" "${args[@]}"
Contributor

@rolfmorel rolfmorel Sep 18, 2025

Can this entire bash script be folded into the Python script?

At this point I do not see the .sh giving much value. I guess it is necessary for entering the virtualenv, otherwise it's just a wrapper, right?

Contributor

A brainwave: could the Python script exec itself after it has set up the right environment variables, i.e. the ones which correspond to entering the virtualenv? Or more hacky: os.system("source .../bin/activate; python " + __file__) in case we detect not being in the venv, e.g. due to imports failing.

Author

I don't like the idea of a Python script deducing whether it was launched in a proper env and then modifying that env. I would say a user should be responsible for setting up a proper env before launching the Python script (they can always call the Bash version of the script that handles venvs for them).

Simplified the generate-mlir.sh script so that it only activates the venv and forwards all the arguments to the Python script.

@dchigarev
Author

Do you happen to have examples of similar cmdline interfaces being used for enabling PyTorch lowerings in other projects?

@rolfmorel thanks for your time and feedback!

No, I haven't seen such a cmdline approach anywhere (I wasn't looking too deep though). On the surface of IREE's and Blade's documentation I could only find the user-script approach. So even if they have a cmdline option, they don't seem to promote it very well.

@banach-space

This is great, thank you so much for working on this 🙏🏻

I have a few high-level suggestions.

Keep this PR simple and restrict it to the required minimum.

The cmdline interface looks complex and is merely a "wrapper" for the script logic. We can't avoid having a script, but we can avoid the cmdline interface. And, with a complex cmdline interface like this, I would wrap it into yet another script. My suggestion - drop the interface for now. This will allow us to focus on the core logic instead.

Consistent filenames and hyphenation.

generate-mlir.py vs py_src vs dummy_mlp_factory.py vs export_bash.py. LLVM seems to prefer - over _. Whichever one we choose, let's use it consistently.

Use docstrings consistently.

Let's use (function + module) docstrings consistently (instead of mixing docstrings and plain Python comments starting with #).

Do we need all the Bash scripts?

There seems to be a fair bit of duplication, e.g. export_bash.sh vs export_py.sh vs export.py. It's not clear to me what all the scripts do and whether we need them. My suggestion: less is more.

Naming.

This PR modifies the torch-mlir/generate-mlir.py -> torch-mlir/py_src/main.py

IIUC, generate-mlir.py was misleading - no MLIR is generated. Instead, the script "exports" MLIR, right? To me, a generator would be something like https://github.com/libxsmm/tpp-mlir/blob/main/tools/mlir-gen/mlir-gen.cpp.

While main.py is an improvement (i.e. not misleading), it's a bit too enigmatic - why not export.py? Or export-mlir-from-pytorch.py? Basically, something descriptive. That said, naming is hard 🤷🏻

Final thoughts.

Really fantastic to see this, just a bit concerned that this PR is trying to achieve too many things in one go. I recommend trimming it - I'd much rather focus on the core part and also make sure that we establish a consistent way of naming, structuring and implementing things.

I've some other, more specific comments inline.

Thanks again for working on this! 🙏🏻



def generate_mlir(model, sample_args, sample_kwargs=None, dialect="linalg"):
# Convert the Torch model to MLIR


Could we use docstrings consistently throughout this project?

@@ -0,0 +1,16 @@
#!/usr/bin/env bash

DOCUMENTME - what is the purpose of this script and how do I use it?

import argparse
from export_lib.export import load_torch_model, generate_sample_args, generate_mlir

# Parse arguments for selecting which model to load and which MLIR dialect to generate


Please use docstring consistently.

@@ -0,0 +1,98 @@
#!/usr/bin/env python3


DOCUMENTME - what is the purpose of this script and how do I use it?


If the main purpose of this file is to export to MLIR, then main -> export? or export-torch-to-mlir? Or some combination?

--out-mlir res.mlir
```

Look into the `examples/` folder for more info.


Please add links to an example implementing #1 and #2.


IIUC, this script won't generate MLIR - it will export to MLIR? Could you rename accordingly?

Also, why do we need a Bash wrapper for the Python script? Why isn't Python enough?

@Groverkss
Member

Building wrapper scripts around torch-mlir is not scalable at all. torch-mlir is a library to build things with, not a tool to build scripts around. The proper way of doing this is shipping fx_importer as part of bindings: #3 (ready for review) and then building export over it and shipping it as part of the Python package. I'm going to send a PR on building an AOT export for torch and onnx around that today to give an idea of how it should be done.
