Skip to content

Commit ba8aa78

Browse files
authored
[examples][mlir] Fix shared library path handling (#11)
Improves handling around the list of shared libraries paths to avoid execution engine errors caused by path resolution. Extra shared libraries can now be provided using `--shared-libs` script flag instead of environment variable. Example's documentation wording is also improved as the example kernel does not actively require any runtime utilities. However, the section is kept to showcase an example library setup as external runtime utilities are often needed for more complex kernels.
1 parent 2c9c0e0 commit ba8aa78

File tree

1 file changed

+31
-16
lines changed

1 file changed

+31
-16
lines changed

python/examples/mlir/compile_and_run.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
import os
2+
import argparse
33

44
from mlir import ir
55
from mlir.dialects import transform
@@ -65,7 +65,8 @@ def create_schedule(ctx: ir.Context) -> ir.Module:
6565
func = structured.MatchOp.match_op_names(
6666
named_seq.bodyTarget, ["func.func"]
6767
)
68-
# Use C interface wrappers - required to make function executable after jitting.
68+
# Use C interface wrappers - required to make function executable
69+
# after jitting.
6970
func = transform.apply_registered_pass(
7071
anytype, func, "llvm-request-c-wrappers"
7172
)
@@ -126,7 +127,7 @@ def create_pass_pipeline(ctx: ir.Context) -> PassManager:
126127

127128

128129
# The example's entry point.
129-
def main():
130+
def main(args):
130131
### Baseline computation ###
131132
# Create inputs.
132133
a = torch.randn(16, 32, dtype=torch.float32)
@@ -149,26 +150,23 @@ def main():
149150
pm.run(kernel.operation)
150151

151152
### Compilation ###
152-
# External shared libraries, containing MLIR runner utilities, are generally
153-
# required to execute the compiled module.
154-
# In this case, MLIR runner utils libraries are expected:
155-
# - libmlir_runner_utils.so
156-
# - libmlir_c_runner_utils.so
153+
# Parse additional libraries if present.
157154
#
158-
# Get paths to MLIR runner shared libraries through an environment variable.
155+
# External shared libraries, runtime utilities, might be needed to execute
156+
# the compiled module.
159157
# The execution engine requires full paths to the libraries.
160-
# For example, the env variable can be set as:
161-
# LIGHTHOUSE_SHARED_LIBS=$PATH_TO_LLVM/build/lib/lib1.so:$PATH_TO_LLVM/build/lib/lib2.so
162-
mlir_libs = os.environ.get("LIGHTHOUSE_SHARED_LIBS", default="").split(":")
158+
mlir_libs = []
159+
if args.shared_libs:
160+
mlir_libs += args.shared_libs.split(",")
163161

164162
# JIT the kernel.
165163
eng = ExecutionEngine(kernel, opt_level=2, shared_libs=mlir_libs)
166164

167165
# Initialize the JIT engine.
168166
#
169-
# The deferred initialization executes global constructors that might have been
170-
# created by the module during engine creation (for example, when `gpu.module`
171-
# is present) or registered afterwards.
167+
# The deferred initialization executes global constructors that might
168+
# have been created by the module during engine creation (for example,
169+
# when `gpu.module` is present) or registered afterwards.
172170
#
173171
# Initialization is not strictly necessary in this case.
174172
# However, it is a good practice to perform it regardless.
@@ -194,4 +192,21 @@ def main():
194192

195193

196194
if __name__ == "__main__":
197-
main()
195+
parser = argparse.ArgumentParser()
196+
197+
# External shared libraries, runtime utilities, might be needed to
198+
# execute the compiled module.
199+
# For example, MLIR runner utils libraries such as:
200+
# - libmlir_runner_utils.so
201+
# - libmlir_c_runner_utils.so
202+
#
203+
# Full paths to the libraries should be provided.
204+
# For example:
205+
# --shared-libs=$LLVM_BUILD/lib/lib1.so,$LLVM_BUILD/lib/lib2.so
206+
parser.add_argument(
207+
"--shared-libs",
208+
type=str,
209+
help="Comma-separated list of libraries to link dynamically",
210+
)
211+
args = parser.parse_args()
212+
main(args)

0 commit comments

Comments
 (0)