Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 96 additions & 18 deletions ingress/Torch-MLIR/generate-mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,60 @@
import argparse
import os
import torch
import torch.nn as nn
from torch_mlir import fx
from torch_mlir.fx import OutputType

from utils import parse_shape_str, load_callable_symbol, generate_fake_tensor


# Parse arguments for selecting which model to load and which MLIR dialect to generate
def parse_args():
parser = argparse.ArgumentParser(description="Generate MLIR for Torch-MLIR models.")
parser.add_argument(
"--model",
"--model-entrypoint",
type=str,
required=True,
help="Path to the Torch model file.",
help="Path to the model entrypoint, e.g. 'torchvision.models:resnet18' or '/path/to/model.py:build_model'.",
)
parser.add_argument(
"--model-state-path",
type=str,
required=False,
help="Path to a state file of the Torch model (usually has .pt or .pth extension).",
)
parser.add_argument(
"--model-args",
type=str,
required=False,
default="[]",
help=""
"Positional arguments to pass to the model's entrypoint "
"(note that this argument will be passed to an 'eval',"
" so the string should contain a valid python code).",
)
parser.add_argument(
"--model-kwargs",
type=str,
required=False,
default="{}",
help=""
"Keyword arguments to pass to the model's entrypoint "
"(note that this argument will be passed to an 'eval',"
" so the string should contain a valid python code).",
)
parser.add_argument(
"--sample-shapes",
type=str,
required=False,
help="Tensor shapes/dtype that the 'forward' method of the model will be called with,"
" e.g. '1,3,224,224,float32'. Must be specified if '--sample-fn' is not given.",
)
parser.add_argument(
"--sample-fn",
type=str,
required=False,
help="Path to a function that generates sample arguments for the model's 'forward' method."
" The function should return a tuple of (args, kwargs). If this is given, '--sample-shapes' is ignored.",
)
parser.add_argument(
"--dialect",
Expand All @@ -23,50 +65,86 @@ def parse_args():
default="linalg",
help="MLIR dialect to generate.",
)
parser.add_argument(
"--out-mlir",
type=str,
required=False,
help="Path to save the generated MLIR module.",
)
return parser.parse_args()

# Functin to load the Torch model
def load_torch_model(model_path):

# Function to load the Torch model
def load_torch_model(model_path):
    """
    Load a serialized Torch object from disk via ``torch.load``.

    Parameters
    ----------
    model_path : str
        Path of the file to load.

    Returns
    -------
    Whatever ``torch.load`` produces for the file.

    Raises
    ------
    FileNotFoundError
        If ``model_path`` does not exist.
    """
    # NOTE(review): depending on the torch version/defaults, torch.load may
    # unpickle arbitrary objects -- only point this at trusted files.
    path_is_present = os.path.exists(model_path)
    if path_is_present:
        loaded = torch.load(model_path)
        return loaded

    raise FileNotFoundError(f"Model file {model_path} does not exist.")

# Function to generate MLIR from the Torch model
# See: https://github.com/MrSidims/PytorchExplorer/blob/main/backend/server.py#L237
def generate_mlir(model, dialect):

def generate_sample_args(shape_str, sample_fn_path) -> tuple[tuple, dict]:
    """
    Generate sample arguments for the model's 'forward' method.
    (Required by torch_mlir.fx.export_and_import)

    Parameters
    ----------
    shape_str : str or None
        Shape/dtype string such as '1,3,224,224,float32'. Only used when
        ``sample_fn_path`` is None.
    sample_fn_path : str or None
        Entry ('module_or_path:attr') of a zero-argument function that returns
        ``(args, kwargs)``. Takes precedence over ``shape_str`` when given.

    Returns
    -------
    tuple[tuple, dict]
        Positional and keyword sample arguments.

    Raises
    ------
    ValueError
        If neither ``sample_fn_path`` nor ``shape_str`` is provided.
    """
    if sample_fn_path is not None:
        # The user-supplied function is responsible for the (args, kwargs) pair.
        return load_callable_symbol(sample_fn_path)()

    # Fail with an actionable message instead of an AttributeError raised from
    # deep inside the shape parser when no shape string was supplied.
    if not shape_str:
        raise ValueError(
            "Either '--sample-shapes' or '--sample-fn' must be specified."
        )

    shape, dtype = parse_shape_str(shape_str)
    return (generate_fake_tensor(shape, dtype),), {}


def generate_mlir(model, sample_args, sample_kwargs=None, dialect="linalg"):
    """
    Convert a Torch model to an MLIR module in the requested dialect.

    Parameters
    ----------
    model :
        The Torch model to export. It is switched to eval mode before export.
    sample_args : tuple
        Positional sample arguments forwarded to ``fx.export_and_import``.
    sample_kwargs : dict, optional
        Keyword sample arguments forwarded to ``fx.export_and_import``.
        ``None`` is treated as an empty dict.
    dialect : str
        One of 'torch', 'linalg', 'stablehlo' or 'tosa'.

    Returns
    -------
    The module returned by ``fx.export_and_import``.

    Raises
    ------
    ValueError
        If ``dialect`` is not one of the supported names.
    """
    # Map the user-facing dialect name onto the torch_mlir output type.
    if dialect == "torch":
        output_type = OutputType.TORCH
    elif dialect == "linalg":
        output_type = OutputType.LINALG_ON_TENSORS
    elif dialect == "stablehlo":
        output_type = OutputType.STABLEHLO
    elif dialect == "tosa":
        output_type = OutputType.TOSA
    else:
        raise ValueError(f"Unsupported dialect: {dialect}")

    if sample_kwargs is None:
        sample_kwargs = {}

    # Export in inference mode, and do it exactly once with the real samples.
    model.eval()
    module = fx.export_and_import(
        model, *sample_args, output_type=output_type, **sample_kwargs
    )
    return module


# Main function to execute the script
def main():
args = parse_args()

# Load the Torch model
model = load_torch_model(args.model)

entrypoint = load_callable_symbol(args.model_entrypoint)

model = entrypoint(*eval(args.model_args), **eval(args.model_kwargs))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yikes - that's quite some eval.

I guess if we are to have a cmdline interface, there's not much to be done about it.

if args.model_state_path is not None:
state_dict = load_torch_model(args.model_state_path)
model.load_state_dict(state_dict)

sample_args, sample_kwargs = generate_sample_args(
args.sample_shapes, args.sample_fn
)
# Generate MLIR from the model
mlir_module = generate_mlir(model, args.dialect)
mlir_module = generate_mlir(model, sample_args, sample_kwargs, args.dialect)

# Print or save the MLIR module
print(mlir_module)
if args.out_mlir:
with open(args.out_mlir, "w") as f:
f.write(str(mlir_module))
else:
print(mlir_module)


# Entry point for the script
if __name__ == "__main__":
main()
main()
47 changes: 41 additions & 6 deletions ingress/Torch-MLIR/generate-mlir.sh
Original file line number Diff line number Diff line change
@@ -1,28 +1,54 @@
#!/usr/bin/env bash
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

example:

./generate-mlir.sh -m torchvision.models:resnet18 -S 1,3,224,224,float32 -o now.mlir

Copy link
Contributor

@rolfmorel rolfmorel Sep 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we include this in the repo somewhere? While the repo doesn't really have tests yet, and certainly no CI, having a working example of the cmdline interface to this is helpful (or even necessary).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be great to also have an in-repo example of how to invoke the conversion from inside a user script.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added examples/ folder with several use cases


# Command line argument for model to load and MLIR dialect to generate
while getopts "m:d:" opt; do
while getopts "m:d:s:a:k:S:f:o:" opt; do
case $opt in
m)
MODEL=$OPTARG
;;
d)
DIALECT=$OPTARG
;;
s)
STATE_PATH=$OPTARG
;;
a)
MODEL_ARGS=$OPTARG
;;
k)
MODEL_KWARGS=$OPTARG
;;
S)
SAMPLE_SHAPES=$OPTARG
;;
f)
SAMPLE_FN=$OPTARG
;;
o)
OUT_MLIR=$OPTARG
;;
*)
echo "Usage: $0 [-m model] [-d dialect]"
echo "Usage: $0 [-m model-entrypoint] [-d dialect] [-s state_path] [-a model_args] [-k model_kwargs] [-S sample_shapes] [-f sample_fn] [-o out_mlir]"
exit 1
;;
esac
done

if [ -z "$MODEL" ]; then
echo "Model not specified. Please provide a model using -m option."
echo "Model not specified. Please provide a model entrypoint using -m option (e.g. torchvision.models:resnet18)."
exit 1
fi
if [ -z "$DIALECT" ]; then
DIALECT="linalg"
fi

# If neither sample shapes nor sample fn provided, the Python will error.
# Give a friendly check here to fail early.
if [ -z "$SAMPLE_SHAPES" ] && [ -z "$SAMPLE_FN" ]; then
echo "Either -S sample_shapes or -f sample_fn must be provided."
exit 1
fi

# Enable local virtualenv created by install-virtualenv.sh
if [ ! -d "torch-mlir-venv" ]; then
echo "Virtual environment not found. Please run install-virtualenv.sh first."
Expand All @@ -33,10 +59,19 @@ source torch-mlir-venv/bin/activate
# Find script directory
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")

# Build python arg list
args=( "--model-entrypoint" "$MODEL" "--dialect" "$DIALECT" )
[ -n "$STATE_PATH" ] && args+=( "--model-state-path" "$STATE_PATH" )
[ -n "$MODEL_ARGS" ] && args+=( "--model-args" "$MODEL_ARGS" )
[ -n "$MODEL_KWARGS" ] && args+=( "--model-kwargs" "$MODEL_KWARGS" )
[ -n "$SAMPLE_SHAPES" ]&& args+=( "--sample-shapes" "$SAMPLE_SHAPES" )
[ -n "$SAMPLE_FN" ] && args+=( "--sample-fn" "$SAMPLE_FN" )
[ -n "$OUT_MLIR" ] && args+=( "--out-mlir" "$OUT_MLIR" )

# Use the Python script to generate MLIR
echo "Generating MLIR for model '$MODEL' with dialect '$DIALECT'..."
python $SCRIPT_DIR/generate-mlir.py --model "$MODEL" --dialect "$DIALECT"
echo "Generating MLIR for model entrypoint '$MODEL' with dialect '$DIALECT'..."
python "$SCRIPT_DIR/generate-mlir.py" "${args[@]}"
Copy link
Contributor

@rolfmorel rolfmorel Sep 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this entire bash script be folded into the Python script?

At this point I do not see the .sh giving much value. I guess it is necessary for entering the virtualenv, otherwise it's just a wrapper, right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A brainwave: could the python script exec itself after it has set up the right environment variables, i.e. the ones which correspond to entering the virtualenv? Or more hacky: os.system("source .../bin/activate; python "+__file__.__path__) in case we detect not being in the venv, e.g. due to imports failing.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like the idea of a Python script deducing whether it was launched in a proper env and then modifying that env. I would say the user should be responsible for setting up a proper env before launching the Python script (they can always call the bash version of the script, which handles venvs for them).

Simplified the generate-mlir.sh script so that it only activates venv and forwards all the arguments to the python script.

if [ $? -ne 0 ]; then
echo "Failed to generate MLIR for model '$MODEL'."
exit 1
fi
fi
80 changes: 80 additions & 0 deletions ingress/Torch-MLIR/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import importlib
import importlib.util
import sys
import os

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

from typing import Callable


def load_callable_symbol(entry: str) -> Callable:
    """
    Load a callable python symbol from a module or a file.

    Parameters
    ----------
    entry : str
        A string specifying the module or file and the attribute path,
        in the format 'module_or_path:attr', e.g.
        'torchvision.models:resnet18' or '/path/to/model.py:build_model'.
        The attribute part may be dotted ('Class.method').

    Returns
    -------
    Callable
        The object found by walking the attribute path on the loaded module.

    Raises
    ------
    ValueError
        If ``entry`` has no ':' separator or an empty module/attribute part.
    ImportError
        If the module or file cannot be imported.
    AttributeError
        If the attribute path does not exist on the module.
    TypeError
        If the resolved object is not callable.
    """
    # Split on the LAST colon: an attribute path never contains ':', whereas a
    # file path may (for example a Windows drive letter).
    left, sep, right = entry.rpartition(":")
    if not sep or not left or not right:
        raise ValueError("Entry must be like 'module_or_path:attr'")

    attr_path = right.split(".")

    if os.path.exists(left) and left.endswith(".py"):
        mod_dir = os.path.abspath(os.path.dirname(left))
        mod_name = os.path.splitext(os.path.basename(left))[0]
        saved_sys_path = list(sys.path)
        try:
            # Make sibling modules of the target file importable while it runs.
            if mod_dir not in sys.path:
                sys.path.insert(0, mod_dir)
            spec = importlib.util.spec_from_file_location(mod_name, left)
            if spec is None or spec.loader is None:
                raise ImportError(f"Cannot load spec from {left}")
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
        finally:
            # Restore in place so that other references to the sys.path list
            # object keep pointing at the live search path.
            sys.path[:] = saved_sys_path
    else:
        module = importlib.import_module(left)

    obj = module
    for name in attr_path:
        obj = getattr(obj, name)

    if not callable(obj):
        raise TypeError(f"Symbol '{entry}' is not callable")

    return obj


def parse_shape_str(shape: str) -> tuple[tuple[int, ...], torch.dtype]:
    """
    Parse a shape string into a shape tuple and a torch dtype.

    Parameters
    ----------
    shape : str
        A string representing the shape and dtype, e.g. '1,3,224,224,float32'.
        The last comma-separated component is the dtype name; the preceding
        ones are the dimensions. Empty dimension components are skipped.

    Returns
    -------
    tuple[tuple[int, ...], torch.dtype]
        The dimensions as integers and the resolved torch dtype.

    Raises
    ------
    ValueError
        If the dtype name does not resolve to a ``torch.dtype``, if a
        dimension is '?' (dynamic shapes are not supported), or if a
        dimension is not an integer.
    """
    *dims, dtype = (part.strip() for part in shape.split(","))

    # A plain getattr() raises AttributeError for unknown names and happily
    # returns non-dtype attributes (e.g. torch.nn), so look the name up with a
    # default and verify the result really is a dtype.
    tdtype = getattr(torch, dtype, None)
    if not isinstance(tdtype, torch.dtype):
        raise ValueError(f"Unsupported dtype: {dtype}")

    if any(dim == "?" for dim in dims):
        raise ValueError(f"Dynamic shapes are not supported yet: {shape}")

    return tuple(int(dim) for dim in dims if dim), tdtype


def generate_fake_tensor(shape: tuple[int], dtype: torch.dtype) -> torch.Tensor:
    """
    Create a tensor under ``FakeTensorMode`` with the requested shape and dtype.

    Parameters
    ----------
    shape : tuple[int]
        Dimensions passed to ``torch.empty``.
    dtype : torch.dtype
        Element type passed to ``torch.empty``.

    Returns
    -------
    torch.Tensor
        The tensor produced by ``torch.empty`` while the fake mode is active.
    """
    fake_mode = FakeTensorMode()
    with fake_mode:
        placeholder = torch.empty(shape, dtype=dtype)
    return placeholder