-
Notifications
You must be signed in to change notification settings - Fork 3
[ingress][torch-mlir][RFC] Initial version of fx-importer script using torch-mlir #4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,54 @@ | ||
#!/usr/bin/env bash

# Command line arguments for the model entrypoint and MLIR dialect to generate.
#   -m  model entrypoint, e.g. torchvision.models:resnet18 (required)
#   -d  target MLIR dialect (default: linalg)
#   -s  path to a model state dict to load
#   -a  positional args forwarded to the model constructor
#   -k  keyword args forwarded to the model constructor
#   -S  sample input shapes, e.g. '1,3,224,224,float32'
#   -f  entrypoint producing sample inputs (alternative to -S)
#   -o  output path for the generated MLIR
while getopts "m:d:s:a:k:S:f:o:" opt; do
  case $opt in
    m) MODEL=$OPTARG ;;
    d) DIALECT=$OPTARG ;;
    s) STATE_PATH=$OPTARG ;;
    a) MODEL_ARGS=$OPTARG ;;
    k) MODEL_KWARGS=$OPTARG ;;
    S) SAMPLE_SHAPES=$OPTARG ;;
    f) SAMPLE_FN=$OPTARG ;;
    o) OUT_MLIR=$OPTARG ;;
    *)
      echo "Usage: $0 [-m model-entrypoint] [-d dialect] [-s state_path] [-a model_args] [-k model_kwargs] [-S sample_shapes] [-f sample_fn] [-o out_mlir]"
      exit 1
      ;;
  esac
done

if [ -z "$MODEL" ]; then
  echo "Model not specified. Please provide a model entrypoint using -m option (e.g. torchvision.models:resnet18)."
  exit 1
fi
if [ -z "$DIALECT" ]; then
  DIALECT="linalg"
fi

# If neither sample shapes nor sample fn provided, the Python will error.
# Give a friendly check here to fail early.
if [ -z "$SAMPLE_SHAPES" ] && [ -z "$SAMPLE_FN" ]; then
  echo "Either -S sample_shapes or -f sample_fn must be provided."
  exit 1
fi

# Enable local virtualenv created by install-virtualenv.sh
if [ ! -d "torch-mlir-venv" ]; then
  echo "Virtual environment not found. Please run install-virtualenv.sh first."
  exit 1
fi
source torch-mlir-venv/bin/activate

# Find script directory
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")

# Build python arg list; optional flags are appended only when set.
args=( "--model-entrypoint" "$MODEL" "--dialect" "$DIALECT" )
[ -n "$STATE_PATH" ]    && args+=( "--model-state-path" "$STATE_PATH" )
[ -n "$MODEL_ARGS" ]    && args+=( "--model-args" "$MODEL_ARGS" )
[ -n "$MODEL_KWARGS" ]  && args+=( "--model-kwargs" "$MODEL_KWARGS" )
[ -n "$SAMPLE_SHAPES" ] && args+=( "--sample-shapes" "$SAMPLE_SHAPES" )
[ -n "$SAMPLE_FN" ]     && args+=( "--sample-fn" "$SAMPLE_FN" )
[ -n "$OUT_MLIR" ]      && args+=( "--out-mlir" "$OUT_MLIR" )

# Use the Python script to generate MLIR.
# Test the command directly instead of checking $? afterwards, which is
# fragile if any statement is ever inserted in between.
echo "Generating MLIR for model entrypoint '$MODEL' with dialect '$DIALECT'..."
if ! python "$SCRIPT_DIR/generate-mlir.py" "${args[@]}"; then
  echo "Failed to generate MLIR for model '$MODEL'."
  exit 1
fi
dchigarev marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import importlib | ||
import importlib.util | ||
import sys | ||
import os | ||
|
||
import torch | ||
from torch._subclasses.fake_tensor import FakeTensorMode | ||
|
||
from typing import Callable | ||
|
||
|
||
def load_callable_symbol(entry: str) -> Callable:
    """
    Load a callable python symbol from a module or a file.

    Parameters
    ----------
    entry : str
        A string specifying the module or file and the attribute path,
        in the format 'module_or_path:attr', e.g.
        'torchvision.models:resnet18' or '/path/to/model.py:build_model'.

    Returns
    -------
    Callable
        The resolved attribute.

    Raises
    ------
    ValueError
        If ``entry`` is not in 'module_or_path:attr' form.
    ImportError
        If the module or file cannot be imported.
    AttributeError
        If the attribute path cannot be resolved on the module.
    """
    if ":" not in entry:
        raise ValueError("Entry must be like 'module_or_path:attr'")

    left, right = entry.split(":", 1)
    attr_path = right.split(".")

    if os.path.exists(left) and left.endswith(".py"):
        mod_dir = os.path.abspath(os.path.dirname(left))
        mod_name = os.path.splitext(os.path.basename(left))[0]
        sys_path_was = list(sys.path)
        try:
            # Make modules next to the loaded file importable while it runs.
            if mod_dir not in sys.path:
                sys.path.insert(0, mod_dir)
            spec = importlib.util.spec_from_file_location(mod_name, left)
            if spec is None or spec.loader is None:
                raise ImportError(f"Cannot load spec from {left}")
            module = importlib.util.module_from_spec(spec)
            # Register before exec_module so code executed in the file
            # (dataclasses, pickling, recursive imports) can find its own
            # module in sys.modules — standard importlib recipe.
            sys.modules[mod_name] = module
            spec.loader.exec_module(module)
        finally:
            # Restore sys.path regardless of import success.
            sys.path = sys_path_was
    else:
        module = importlib.import_module(left)

    # Walk the dotted attribute path, e.g. 'models.resnet18'.
    obj = module
    for name in attr_path:
        obj = getattr(obj, name)

    return obj
|
||
|
||
def parse_shape_str(shape: str) -> tuple[tuple[int, ...], torch.dtype]:
    """
    Parse a shape string into a shape tuple and a torch dtype.

    Parameters
    ----------
    shape : str
        A string representing the shape and dtype, e.g. '1,3,224,224,float32'.
        The last comma-separated component is the dtype name; the rest are
        integer dimensions. Empty components (e.g. from stray commas) are
        skipped.

    Returns
    -------
    tuple[tuple[int, ...], torch.dtype]
        The parsed dimensions and the resolved torch dtype.

    Raises
    ------
    ValueError
        If the dtype name is not a torch dtype, or a dynamic ('?')
        dimension is present.
    """
    components = shape.split(",")

    shapes = components[:-1]
    dtype = components[-1]
    # getattr with a default: a plain getattr raises AttributeError for an
    # unknown name, so the original 'is None' check could never trigger.
    # The isinstance check also rejects torch attributes that exist but are
    # not dtypes (e.g. 'abs').
    tdtype = getattr(torch, dtype, None)
    if not isinstance(tdtype, torch.dtype):
        raise ValueError(f"Unsupported dtype: {dtype}")
    if any(dim == "?" for dim in shapes):
        raise ValueError(f"Dynamic shapes are not supported yet: {shape}")
    return (tuple(int(dim) for dim in shapes if dim), tdtype)
dchigarev marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
|
||
def generate_fake_tensor(shape: tuple[int], dtype: torch.dtype) -> torch.Tensor:
    """Create a buffer-less ("fake") tensor of the requested shape and dtype.

    The tensor is allocated under ``FakeTensorMode``, so it carries metadata
    (shape/dtype) only, with no real storage behind it.
    """
    fake_mode = FakeTensorMode()
    with fake_mode:
        placeholder = torch.empty(shape, dtype=dtype)
    return placeholder
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yikes - that's quite some
eval
.I guess if we are to have a cmdline interface, there's not much to be done about it.