From 225a4f41a7371f244262c094b19fea3012ec38de Mon Sep 17 00:00:00 2001 From: dchigarev Date: Wed, 17 Sep 2025 12:16:10 +0000 Subject: [PATCH 1/4] [mlir][ingress][RFC] Initial version of fx-importer script using torch-mlir Signed-off-by: dchigarev --- ingress/Torch-MLIR/generate-mlir.py | 111 +++++++++++++++++++++++----- ingress/Torch-MLIR/generate-mlir.sh | 47 ++++++++++-- ingress/Torch-MLIR/utils.py | 80 ++++++++++++++++++++ 3 files changed, 214 insertions(+), 24 deletions(-) create mode 100644 ingress/Torch-MLIR/utils.py diff --git a/ingress/Torch-MLIR/generate-mlir.py b/ingress/Torch-MLIR/generate-mlir.py index 888e6dd..d84a165 100644 --- a/ingress/Torch-MLIR/generate-mlir.py +++ b/ingress/Torch-MLIR/generate-mlir.py @@ -3,18 +3,60 @@ import argparse import os import torch -import torch.nn as nn from torch_mlir import fx from torch_mlir.fx import OutputType +from utils import parse_shape_str, load_callable_symbol, generate_fake_tensor + + # Parse arguments for selecting which model to load and which MLIR dialect to generate def parse_args(): parser = argparse.ArgumentParser(description="Generate MLIR for Torch-MLIR models.") parser.add_argument( - "--model", + "--model-entrypoint", type=str, required=True, - help="Path to the Torch model file.", + help="Path to the model entrypoint, e.g. 
'torchvision.models:resnet18' or '/path/to/model.py:build_model'.", + ) + parser.add_argument( + "--model-state-path", + type=str, + required=False, + help="Path to the state file of the Torch model (usually has .pt or .pth extension).", + ) + parser.add_argument( + "--model-args", + type=str, + required=False, + default="[]", + help="" + "Positional arguments to pass to the model-entry " + "(note that this argument will be passed to an 'eval'," + " so the string should contain a valid python code).", + ) + parser.add_argument( + "--model-kwargs", + type=str, + required=False, + default="{}", + help="" + "Keyword arguments to pass to the model-entry " + "(note that this argument will be passed to an 'eval'," + " so the string should contain a valid python code).", + ) + parser.add_argument( + "--sample-shapes", + type=str, + required=False, + help="Tensor shapes/dtype that the 'forward' method of the model will be called with," + " e.g. '1,3,224,224,float32'. Must be specified if '--sample-fn' is not given.", + ) + parser.add_argument( + "--sample-fn", + type=str, + required=False, + help="Path to a function that generates sample arguments for the model's 'forward' method." + " The function should return a tuple of (args, kwargs). 
If this is given, '--sample-shapes' is ignored.", ) parser.add_argument( "--dialect", @@ -23,27 +65,43 @@ def parse_args(): default="linalg", help="MLIR dialect to generate.", ) + parser.add_argument( + "--out-mlir", + type=str, + required=False, + help="Path to save the generated MLIR module.", + ) return parser.parse_args() -# Functin to load the Torch model -def load_torch_model(model_path): +# Function to load the Torch model +def load_torch_model(model_path): if not os.path.exists(model_path): raise FileNotFoundError(f"Model file {model_path} does not exist.") - + model = torch.load(model_path) return model -# Function to generate MLIR from the Torch model -# See: https://github.com/MrSidims/PytorchExplorer/blob/main/backend/server.py#L237 -def generate_mlir(model, dialect): +def generate_sample_args(shape_str, sample_fn_path) -> tuple[tuple, dict]: + """ + Generate sample arguments for the model's 'forward' method. + (Required by torch_mlir.fx.export_and_import) + """ + if sample_fn_path is None: + shape, dtype = parse_shape_str(shape_str) + return (generate_fake_tensor(shape, dtype),), {} + + return load_callable_symbol(sample_fn_path)() + + +def generate_mlir(model, dialect, sample_args, sample_kwargs): # Convert the Torch model to MLIR output_type = None if dialect == "torch": output_type = OutputType.TORCH elif dialect == "linalg": - output_type = OutputType.LINALG + output_type = OutputType.LINALG_ON_TENSORS elif dialect == "stablehlo": output_type = OutputType.STABLEHLO elif dialect == "tosa": @@ -51,22 +109,39 @@ def generate_mlir(model, dialect): else: raise ValueError(f"Unsupported dialect: {dialect}") - module = fx.export_and_import(model, "", output_type=output_type) + model.eval() + module = fx.export_and_import( + model, *sample_args, output_type=output_type, **sample_kwargs + ) return module + # Main function to execute the script def main(): args = parse_args() - + # Load the Torch model - model = load_torch_model(args.model) - + entrypoint = 
load_callable_symbol(args.model_entrypoint) + + model = entrypoint(*eval(args.model_args), **eval(args.model_kwargs)) + if args.model_state_path is not None: + state_dict = load_torch_model(args.model_state_path) + model.load_state_dict(state_dict) + + sample_args, sample_kwargs = generate_sample_args( + args.sample_shapes, args.sample_fn + ) # Generate MLIR from the model - mlir_module = generate_mlir(model, args.dialect) - + mlir_module = generate_mlir(model, args.dialect, sample_args, sample_kwargs) + # Print or save the MLIR module - print(mlir_module) + if args.out_mlir: + with open(args.out_mlir, "w") as f: + f.write(str(mlir_module)) + else: + print(mlir_module) + # Entry point for the script if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/ingress/Torch-MLIR/generate-mlir.sh b/ingress/Torch-MLIR/generate-mlir.sh index 0a079c6..1decc65 100755 --- a/ingress/Torch-MLIR/generate-mlir.sh +++ b/ingress/Torch-MLIR/generate-mlir.sh @@ -1,7 +1,7 @@ #!/usr/bin/env bash # Command line argument for model to load and MLIR dialect to generate -while getopts "m:d:" opt; do +while getopts "m:d:s:a:k:S:f:o:" opt; do case $opt in m) MODEL=$OPTARG @@ -9,20 +9,46 @@ while getopts "m:d:" opt; do d) DIALECT=$OPTARG ;; + s) + STATE_PATH=$OPTARG + ;; + a) + MODEL_ARGS=$OPTARG + ;; + k) + MODEL_KWARGS=$OPTARG + ;; + S) + SAMPLE_SHAPES=$OPTARG + ;; + f) + SAMPLE_FN=$OPTARG + ;; + o) + OUT_MLIR=$OPTARG + ;; *) - echo "Usage: $0 [-m model] [-d dialect]" + echo "Usage: $0 [-m model-entrypoint] [-d dialect] [-s state_path] [-a model_args] [-k model_kwargs] [-S sample_shapes] [-f sample_fn] [-o out_mlir]" exit 1 ;; esac done + if [ -z "$MODEL" ]; then - echo "Model not specified. Please provide a model using -m option." + echo "Model not specified. Please provide a model entrypoint using -m option (e.g. torchvision.models:resnet18)." 
exit 1 fi if [ -z "$DIALECT" ]; then DIALECT="linalg" fi +# If neither sample shapes nor sample fn provided, the Python will error. +# Give a friendly check here to fail early. +if [ -z "$SAMPLE_SHAPES" ] && [ -z "$SAMPLE_FN" ]; then + echo "Either -S sample_shapes or -f sample_fn must be provided." + exit 1 +fi + # Enable local virtualenv created by install-virtualenv.sh if [ ! -d "torch-mlir-venv" ]; then echo "Virtual environment not found. Please run install-virtualenv.sh first." @@ -33,10 +59,19 @@ source torch-mlir-venv/bin/activate # Find script directory SCRIPT_DIR=$(dirname "$(readlink -f "$0")") +# Build python arg list +args=( "--model-entrypoint" "$MODEL" "--dialect" "$DIALECT" ) +[ -n "$STATE_PATH" ] && args+=( "--model-state-path" "$STATE_PATH" ) +[ -n "$MODEL_ARGS" ] && args+=( "--model-args" "$MODEL_ARGS" ) +[ -n "$MODEL_KWARGS" ] && args+=( "--model-kwargs" "$MODEL_KWARGS" ) +[ -n "$SAMPLE_SHAPES" ]&& args+=( "--sample-shapes" "$SAMPLE_SHAPES" ) +[ -n "$SAMPLE_FN" ] && args+=( "--sample-fn" "$SAMPLE_FN" ) +[ -n "$OUT_MLIR" ] && args+=( "--out-mlir" "$OUT_MLIR" ) + # Use the Python script to generate MLIR -echo "Generating MLIR for model '$MODEL' with dialect '$DIALECT'..." -python $SCRIPT_DIR/generate-mlir.py --model "$MODEL" --dialect "$DIALECT" +echo "Generating MLIR for model entrypoint '$MODEL' with dialect '$DIALECT'..." +python "$SCRIPT_DIR/generate-mlir.py" "${args[@]}" if [ $? -ne 0 ]; then echo "Failed to generate MLIR for model '$MODEL'." exit 1 -fi +fi \ No newline at end of file diff --git a/ingress/Torch-MLIR/utils.py b/ingress/Torch-MLIR/utils.py new file mode 100644 index 0000000..a574c06 --- /dev/null +++ b/ingress/Torch-MLIR/utils.py @@ -0,0 +1,80 @@ +import importlib +import importlib.util +import sys +import os + +import torch +from torch._subclasses.fake_tensor import FakeTensorMode + +from typing import Callable + + +def load_callable_symbol(entry: str) -> Callable: + """ + Load a callable python symbol from a module or a file. 
+ + Parameters + ---------- + entry : str + A string specifying the module or file and the attribute path, + in the format 'module_or_path:attr', e.g. + 'torchvision.models:resnet18' or '/path/to/model.py:build_model'. + + Returns + ------- + Callable + """ + if ":" not in entry: + raise ValueError("Entry must be like 'module_or_path:attr'") + + left, right = entry.split(":", 1) + attr_path = right.split(".") + + if os.path.exists(left) and left.endswith(".py"): + mod_dir = os.path.abspath(os.path.dirname(left)) + mod_name = os.path.splitext(os.path.basename(left))[0] + sys_path_was = list(sys.path) + try: + if mod_dir not in sys.path: + sys.path.insert(0, mod_dir) + spec = importlib.util.spec_from_file_location(mod_name, left) + if spec is None or spec.loader is None: + raise ImportError(f"Cannot load spec from {left}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + finally: + sys.path = sys_path_was + else: + module = importlib.import_module(left) + + obj = module + for name in attr_path: + obj = getattr(obj, name) + + return obj + + +def parse_shape_str(shape: str) -> tuple[tuple[int], torch.dtype]: + """ + Parse a shape string into a shape tuple and a torch dtype. + + Parameters + ---------- + shape : str + A string representing the shape and dtype, e.g. '1,3,224,224,float32'. + """ + components = shape.split(",") + shapes = components[:-1] + dtype = components[-1] + tdtype = getattr(torch, dtype) + if tdtype is None: + raise ValueError(f"Unsupported dtype: {dtype}") + if any(dim == "?" 
for dim in shapes): + raise ValueError(f"Dynamic shapes are not supported yet: {shape}") + return (tuple(int(dim) for dim in shapes if dim), tdtype) + + +def generate_fake_tensor(shape: tuple[int], dtype: torch.dtype) -> torch.Tensor: + """Generate a fake tensor (has no actual buffer) with the given shape and dtype.""" + with FakeTensorMode(): + return torch.empty(shape, dtype=dtype) From 20ca817943d82cbc2c0befc086c61a543f032103 Mon Sep 17 00:00:00 2001 From: dchigarev Date: Wed, 17 Sep 2025 14:07:32 +0000 Subject: [PATCH 2/4] fix typos Signed-off-by: dchigarev --- ingress/Torch-MLIR/generate-mlir.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/ingress/Torch-MLIR/generate-mlir.py b/ingress/Torch-MLIR/generate-mlir.py index d84a165..9467076 100644 --- a/ingress/Torch-MLIR/generate-mlir.py +++ b/ingress/Torch-MLIR/generate-mlir.py @@ -22,7 +22,7 @@ def parse_args(): "--model-state-path", type=str, required=False, - help="Path to the state file of the Torch model (usually has .pt or .pth extension).", + help="Path to a state file of the Torch model (usually has .pt or .pth extension).", ) parser.add_argument( "--model-args", @@ -30,7 +30,7 @@ def parse_args(): required=False, default="[]", help="" - "Positional arguments to pass to the model-entry " + "Positional arguments to pass to the model's entrypoint " "(note that this argument will be passed to an 'eval'," " so the string should contain a valid python code).", ) @@ -40,7 +40,7 @@ def parse_args(): required=False, default="{}", help="" - "Keyword arguments to pass to the model-entry " + "Keyword arguments to pass to the model's entrypoint " "(note that this argument will be passed to an 'eval'," " so the string should contain a valid python code).", ) @@ -95,7 +95,7 @@ def generate_sample_args(shape_str, sample_fn_path) -> tuple[tuple, dict]: return load_callable_symbol(sample_fn_path)() -def generate_mlir(model, dialect, sample_args, sample_kwargs): +def generate_mlir(model, 
sample_args, sample_kwargs=None, dialect="linalg"): # Convert the Torch model to MLIR output_type = None if dialect == "torch": @@ -109,6 +109,9 @@ def generate_mlir(model, dialect, sample_args, sample_kwargs): else: raise ValueError(f"Unsupported dialect: {dialect}") + if sample_kwargs is None: + sample_kwargs = {} + model.eval() module = fx.export_and_import( model, *sample_args, output_type=output_type, **sample_kwargs @@ -132,7 +135,7 @@ def main(): args.sample_shapes, args.sample_fn ) # Generate MLIR from the model - mlir_module = generate_mlir(model, args.dialect, sample_args, sample_kwargs) + mlir_module = generate_mlir(model, sample_args, sample_kwargs, args.dialect) # Print or save the MLIR module if args.out_mlir: From 499db16d8aa21d366b70670ede1c6dae28745097 Mon Sep 17 00:00:00 2001 From: dchigarev Date: Thu, 25 Sep 2025 09:48:58 +0000 Subject: [PATCH 3/4] Restructure ingress/torch-mlir Signed-off-by: dchigarev --- .gitignore | 2 + ingress/Torch-MLIR/README.md | 49 +++++++++++ .../examples/dummy_mlp_cli/dummy_mlp.pth | Bin 0 -> 4209 bytes .../dummy_mlp_cli/dummy_mlp_factory.py | 23 ++++++ .../examples/dummy_mlp_cli/export_bash.sh | 10 +++ .../examples/dummy_mlp_cli/export_py.sh | 10 +++ .../examples/dummy_mlp_python/export.py | 25 ++++++ .../examples/dummy_mlp_python/run.sh | 4 + .../examples/resnet_18_cli/export_bash.sh | 9 ++ .../examples/resnet_18_cli/export_py.sh | 9 ++ ingress/Torch-MLIR/generate-mlir.sh | 77 ------------------ ingress/Torch-MLIR/py_src/__init__.py | 0 .../Torch-MLIR/py_src/export_lib/__init__.py | 8 ++ .../Torch-MLIR/py_src/export_lib/export.py | 58 +++++++++++++ .../{ => py_src/export_lib}/utils.py | 0 .../{generate-mlir.py => py_src/main.py} | 66 ++------------- ingress/Torch-MLIR/scripts/generate-mlir.sh | 16 ++++ .../{ => scripts}/install-virtualenv.sh | 3 +- 18 files changed, 232 insertions(+), 137 deletions(-) create mode 100644 .gitignore create mode 100644 ingress/Torch-MLIR/README.md create mode 100644 
ingress/Torch-MLIR/examples/dummy_mlp_cli/dummy_mlp.pth create mode 100644 ingress/Torch-MLIR/examples/dummy_mlp_cli/dummy_mlp_factory.py create mode 100755 ingress/Torch-MLIR/examples/dummy_mlp_cli/export_bash.sh create mode 100755 ingress/Torch-MLIR/examples/dummy_mlp_cli/export_py.sh create mode 100644 ingress/Torch-MLIR/examples/dummy_mlp_python/export.py create mode 100755 ingress/Torch-MLIR/examples/dummy_mlp_python/run.sh create mode 100755 ingress/Torch-MLIR/examples/resnet_18_cli/export_bash.sh create mode 100755 ingress/Torch-MLIR/examples/resnet_18_cli/export_py.sh delete mode 100755 ingress/Torch-MLIR/generate-mlir.sh create mode 100644 ingress/Torch-MLIR/py_src/__init__.py create mode 100644 ingress/Torch-MLIR/py_src/export_lib/__init__.py create mode 100644 ingress/Torch-MLIR/py_src/export_lib/export.py rename ingress/Torch-MLIR/{ => py_src/export_lib}/utils.py (100%) rename ingress/Torch-MLIR/{generate-mlir.py => py_src/main.py} (61%) create mode 100755 ingress/Torch-MLIR/scripts/generate-mlir.sh rename ingress/Torch-MLIR/{ => scripts}/install-virtualenv.sh (94%) diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..fb2f9f0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__ +ingress/Torch-MLIR/examples/**/dumps/*.mlir diff --git a/ingress/Torch-MLIR/README.md b/ingress/Torch-MLIR/README.md new file mode 100644 index 0000000..2b35616 --- /dev/null +++ b/ingress/Torch-MLIR/README.md @@ -0,0 +1,49 @@ +Using scripts in this directory one can convert a Torch Model to an MLIR module. + +The conversion script is written in Python and is basically a wrapper around [`torch-mlir` library](https://github.com/llvm/torch-mlir). One needs to set up a Python virtual environment with torch and torch-mlir libraries +(`./scripts/install-virtualenv.sh`) to use the script. + +In order to convert a model the script has to receive: +1. An instance of `torch.nn.Module` with proper state (weights). +2. Sample input arguments to the model (e.g. 
empty tensor with proper shape and dtype). + +There are two ways this information can be provided to the converter: + +### 1. Instantiate a model in your own script and use a function from `py_src/export_lib` (recommended) + +In this scenario the user is responsible for instantiating a model with proper state in their +own Python script. Then they should import the `generate_mlir` function from `py_src.export_lib` +and call it in order to get an MLIR module: + +```python +model: nn.Module = get_model() +sample_args = (get_sample_tensor(),) + +# PYTHONPATH=$(pwd)/py_src/ +from export_lib import generate_mlir + +mlir_module = generate_mlir(model, sample_args, dialect="linalg") +print(mlir_module) +``` + +### 2. Use `py_src/main.py` or `scripts/generate-mlir.sh` and pass Torch Model parameters via CLI + +In this scenario the `py_src/main.py` script is fully responsible for instantiating a torch model +and converting it to MLIR. The user has to pass a Python entrypoint for the model's factory, +its parameters if needed (`--model-args` and `--model-kwargs`), and sample model arguments (either +as `--sample-shapes` or, via `--sample-fn`, as an entrypoint to a function returning args and kwargs). + +``` +# note that 'my_module' has to be in $PYTHONPATH +python py_src/main.py --model-entrypoint my_module:my_factory \ + --model-state-path path/to/state.pth \ + --sample-shapes '1,2,324,float32' \ + --out-mlir res.mlir + +# note that 'my_module' has to be in $PYTHONPATH +./scripts/generate-mlir.sh --model-entrypoint torchvision.models:resnet18 \ + --sample-fn my_module:generate_resnet18_sample_args \ + --out-mlir res.mlir +``` + +See the `examples/` folder for more info. 
diff --git a/ingress/Torch-MLIR/examples/dummy_mlp_cli/dummy_mlp.pth b/ingress/Torch-MLIR/examples/dummy_mlp_cli/dummy_mlp.pth new file mode 100644 index 0000000000000000000000000000000000000000..28ccd45f0c2b90cea7e86098773ab419177f2d5f GIT binary patch literal 4209 zcmbtX2~<=^7Hw9sK?O`u3Vz;&JSS#F-chWXSBxD`oucOaBz{ZPnuSRmYaY*~>f)39R2iH`n=V%A z)C)2++5{D+PDxMIsg}sZtc90XOfMdn&c7MAP|aF)LBtAeY8)5cwJ2t-V$Aq=B&>Ce z1^4vhW;G;S`_`en2JrfX7GlFdp@*JP|Lo0WI2Wu!0gi&XZq@vKP9u`%bfiJy@#%`0b-)F`j~ z5>s`_ah&R@Iz{Lxlk_smwBP&?8uk%ieg!ys-{YAt`z=BLpohlp`MQHG=+kF-dt%l>YYC(gZv^rb! z_J#R!4fZ0&NQ00s{t`uK(C-#-$32`JPbu`D1wm>hhfLad06G3C!&^j3}>7dL%2(|;ZDRXX7a`rbY}Bpe8jB7k47)X zvuE8V#-?^yvG(V4pGt z2cE73tu-+JJUkZ7=10T5TO(lQ$8q3xXdTWEnS?_liXdQtn7JIQ!+)=;G7O*lA(QEK zh8!%5hltX(ukk$tz%d|G{MTrW`WM`CT{?)(7w0Xrjt2Cm1;(7*zxF z$jM{}u&%YoYdeBrw@)gbSeOeLWyhJ_CpY7z`h4(Bm=6mZ;>mhTN7ylK6{x@SHtf|O zB=XW?2wGhQ)|@{S*WM?j)t}sHn+WRa)M4!>=SM`CJ7pW0}&eHaDrgorWK@eyslLCJfy?)ca0^Q1Jz6OYbpBj@+6gbMvv zRNtl`jkj3fjwL`#h6;<*Z;&RtUhu(3vT;q#wNNo|6V?v)fWPhCPg?JI;DUnB$GKd* z!kp{K#ZMw6WZ$oDFupN~sLQ4>KD)y~)0hie-X9C2fAoR!Bn1}b_D0{DOe`QFQ1Fc( zt{+%{!}1%+qaryt4Qe10rsQDOyVZEEIM>jEdU$j10opiq1qLpQ1*f!0I9E9x8nKf8 z)K-BNbGMRNS?}P^3@s$|vw^Cw`kAWOdquB)uE<-KV*CawrBXXWG23IaC1+lWP=Kl#Qf6sL-| zz=7q(^y4o&3;~%EGTyHoinghtz_1QH_bnh;%9NwB}Y4G@V6CP|Pq^6}D3qP%e z@);%A&IT|w&50<_uEcQh23*>*2psRlf-Y1Gt-k9Zck?1Rvid_B*8YIL?pH-GEmP8V z_cEsZVHmhP$;353@29cRv9P#rDSXb&G+4ORka|W!!V3wMJ(+=eT@~o-a-n0!P`t87 z1tZQELUdvgJg8cQYePIys{u^>b~m}EyT}|~vWrN3tzrF@wKU|@TC}vcB5(3{$3*W) zJjM8+SaOmIH}Bba@696N`^nG$b@PVyym_UHQ25<9Z`T(f@AM{Ydh-dnGjt?XUz=kX z*m0KD?Z0c7J#PrrU)>064=2FX#R||5sf6jD_`#MP@b% z{!dzX;7ShfzA#_Of7u`R|NE+y^^eS?5MTY%*j)J>`j5B$upw#({nzz6`f$`XI{Eq` zYCYroakn2xF>%lj)Yse+ckP(VOgZj`M?9Tiwud!dvSexW)!{I&?_g}dX^F#QUH;Z2$w2I8_`?JzUZkv?2uj!}#qy_tD}v3zoZSk1Xc@&@=r<^V~sFU=$e zPRr0U_&kj*I8HQf)znInL5ChFA+t)36Y0+F^pfqk`h>Z9V)OHLL$yrBFRiIJ~qOPRGJBbyE`xA+SyU~!(_bJ zhxNo8m%z>@*cjf tuple[tuple, dict]: + """ + Generate sample arguments for the model's 'forward' method. 
+ (Required by torch_mlir.fx.export_and_import) + """ + if sample_fn_path is None: + shape, dtype = parse_shape_str(shape_str) + return (generate_fake_tensor(shape, dtype),), {} + + return load_callable_symbol(sample_fn_path)() + + +def generate_mlir(model, sample_args, sample_kwargs=None, dialect="linalg"): + # Convert the Torch model to MLIR + output_type = None + if dialect == "torch": + output_type = OutputType.TORCH + elif dialect == "linalg": + output_type = OutputType.LINALG_ON_TENSORS + elif dialect == "stablehlo": + output_type = OutputType.STABLEHLO + elif dialect == "tosa": + output_type = OutputType.TOSA + else: + raise ValueError(f"Unsupported dialect: {dialect}") + + if sample_kwargs is None: + sample_kwargs = {} + + model.eval() + module = fx.export_and_import( + model, *sample_args, output_type=output_type, **sample_kwargs + ) + return module diff --git a/ingress/Torch-MLIR/utils.py b/ingress/Torch-MLIR/py_src/export_lib/utils.py similarity index 100% rename from ingress/Torch-MLIR/utils.py rename to ingress/Torch-MLIR/py_src/export_lib/utils.py diff --git a/ingress/Torch-MLIR/generate-mlir.py b/ingress/Torch-MLIR/py_src/main.py similarity index 61% rename from ingress/Torch-MLIR/generate-mlir.py rename to ingress/Torch-MLIR/py_src/main.py index 9467076..feddf4f 100644 --- a/ingress/Torch-MLIR/generate-mlir.py +++ b/ingress/Torch-MLIR/py_src/main.py @@ -1,13 +1,7 @@ #!/usr/bin/env python3 import argparse -import os -import torch -from torch_mlir import fx -from torch_mlir.fx import OutputType - -from utils import parse_shape_str, load_callable_symbol, generate_fake_tensor - +from export_lib.export import load_torch_model, generate_sample_args, generate_mlir # Parse arguments for selecting which model to load and which MLIR dialect to generate def parse_args(): @@ -74,63 +68,17 @@ def parse_args(): return parser.parse_args() -# Function to load the Torch model -def load_torch_model(model_path): - if not os.path.exists(model_path): - raise 
FileNotFoundError(f"Model file {model_path} does not exist.") - - model = torch.load(model_path) - return model - - -def generate_sample_args(shape_str, sample_fn_path) -> tuple[tuple, dict]: - """ - Generate sample arguments for the model's 'forward' method. - (Required by torch_mlir.fx.export_and_import) - """ - if sample_fn_path is None: - shape, dtype = parse_shape_str(shape_str) - return (generate_fake_tensor(shape, dtype),), {} - - return load_callable_symbol(sample_fn_path)() - - -def generate_mlir(model, sample_args, sample_kwargs=None, dialect="linalg"): - # Convert the Torch model to MLIR - output_type = None - if dialect == "torch": - output_type = OutputType.TORCH - elif dialect == "linalg": - output_type = OutputType.LINALG_ON_TENSORS - elif dialect == "stablehlo": - output_type = OutputType.STABLEHLO - elif dialect == "tosa": - output_type = OutputType.TOSA - else: - raise ValueError(f"Unsupported dialect: {dialect}") - - if sample_kwargs is None: - sample_kwargs = {} - - model.eval() - module = fx.export_and_import( - model, *sample_args, output_type=output_type, **sample_kwargs - ) - return module - - # Main function to execute the script def main(): args = parse_args() # Load the Torch model - entrypoint = load_callable_symbol(args.model_entrypoint) - - model = entrypoint(*eval(args.model_args), **eval(args.model_kwargs)) - if args.model_state_path is not None: - state_dict = load_torch_model(args.model_state_path) - model.load_state_dict(state_dict) - + model = load_torch_model( + args.model_entrypoint, + args.model_state_path, + *eval(args.model_args), + **eval(args.model_kwargs) + ) sample_args, sample_kwargs = generate_sample_args( args.sample_shapes, args.sample_fn ) diff --git a/ingress/Torch-MLIR/scripts/generate-mlir.sh b/ingress/Torch-MLIR/scripts/generate-mlir.sh new file mode 100755 index 0000000..080cc05 --- /dev/null +++ b/ingress/Torch-MLIR/scripts/generate-mlir.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash + +# Find script directory 
+SCRIPT_DIR=$(dirname "$(readlink -f "$0")") +PY_SCRIPT_DIR=$SCRIPT_DIR/../py_src/ +VENV_DIR="$SCRIPT_DIR/../torch-mlir-venv" + +# Enable local virtualenv created by install-virtualenv.sh +if [ ! -d "$VENV_DIR" ]; then + echo "Virtual environment not found. Please run install-virtualenv.sh first." + exit 1 +fi +source $VENV_DIR/bin/activate + +# Use the Python script to generate MLIR +PYTHONPATH=$PYTHONPATH:$PY_SCRIPT_DIR python "$PY_SCRIPT_DIR/main.py" "$@" diff --git a/ingress/Torch-MLIR/install-virtualenv.sh b/ingress/Torch-MLIR/scripts/install-virtualenv.sh similarity index 94% rename from ingress/Torch-MLIR/install-virtualenv.sh rename to ingress/Torch-MLIR/scripts/install-virtualenv.sh index f48019c..43b9124 100755 --- a/ingress/Torch-MLIR/install-virtualenv.sh +++ b/ingress/Torch-MLIR/scripts/install-virtualenv.sh @@ -9,6 +9,7 @@ else DEVICE_TYPE=$(lspci | grep VGA) fi +SCRIPT_DIR=$(dirname "$(readlink -f "$0")") # Install torch-mlir inside a virtual environment echo "First ensure uv is installed" @@ -16,7 +17,7 @@ echo "First ensure uv is installed" python -m pip install uv --upgrade echo "Preparing the virtual environment" -python -m uv venv torch-mlir-venv --python 3.12 +python -m uv venv $SCRIPT_DIR/../torch-mlir-venv --python 3.12 #echo "Preparing the virtual environment" #python3 -m venv torch-mlir-venv From a08bcc5fc6fb00ee8560bcc2d052f10bf44678fd Mon Sep 17 00:00:00 2001 From: dchigarev Date: Thu, 25 Sep 2025 09:51:11 +0000 Subject: [PATCH 4/4] address review suggestions Signed-off-by: dchigarev --- ingress/Torch-MLIR/py_src/export_lib/utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/ingress/Torch-MLIR/py_src/export_lib/utils.py b/ingress/Torch-MLIR/py_src/export_lib/utils.py index a574c06..0534e60 100644 --- a/ingress/Torch-MLIR/py_src/export_lib/utils.py +++ b/ingress/Torch-MLIR/py_src/export_lib/utils.py @@ -63,15 +63,13 @@ def parse_shape_str(shape: str) -> tuple[tuple[int], torch.dtype]: shape : str A string 
representing the shape and dtype, e.g. '1,3,224,224,float32'. """ - components = shape.split(",") - shapes = components[:-1] - dtype = components[-1] + *shapes, dtype = shape.split(",") tdtype = getattr(torch, dtype) if tdtype is None: raise ValueError(f"Unsupported dtype: {dtype}") - if any(dim == "?" for dim in shapes): + if "?" in shapes: raise ValueError(f"Dynamic shapes are not supported yet: {shape}") - return (tuple(int(dim) for dim in shapes if dim), tdtype) + return (tuple(map(int, shapes)), tdtype) def generate_fake_tensor(shape: tuple[int], dtype: torch.dtype) -> torch.Tensor: