11import torch
2- import os
2+ import argparse
33
44from mlir import ir
55from mlir .dialects import transform
@@ -65,7 +65,8 @@ def create_schedule(ctx: ir.Context) -> ir.Module:
6565 func = structured .MatchOp .match_op_names (
6666 named_seq .bodyTarget , ["func.func" ]
6767 )
68- # Use C interface wrappers - required to make function executable after jitting.
68+ # Use C interface wrappers - required to make function executable
69+ # after jitting.
6970 func = transform .apply_registered_pass (
7071 anytype , func , "llvm-request-c-wrappers"
7172 )
@@ -126,7 +127,7 @@ def create_pass_pipeline(ctx: ir.Context) -> PassManager:
126127
127128
128129# The example's entry point.
129- def main ():
130+ def main (args ):
130131 ### Baseline computation ###
131132 # Create inputs.
132133 a = torch .randn (16 , 32 , dtype = torch .float32 )
@@ -149,26 +150,23 @@ def main():
149150 pm .run (kernel .operation )
150151
151152 ### Compilation ###
152- # External shared libraries, containing MLIR runner utilities, are generally
153- # required to execute the compiled module.
154- # In this case, MLIR runner utils libraries are expected:
155- # - libmlir_runner_utils.so
156- # - libmlir_c_runner_utils.so
153+ # Parse additional libraries if present.
157154 #
158- # Get paths to MLIR runner shared libraries through an environment variable.
155+ # External shared libraries, runtime utilities, might be needed to execute
156+ # the compiled module.
159157 # The execution engine requires full paths to the libraries.
160- # For example, the env variable can be set as:
161- # LIGHTHOUSE_SHARED_LIBS=$PATH_TO_LLVM/build/lib/lib1.so:$PATH_TO_LLVM/build/lib/lib2.so
162- mlir_libs = os . environ . get ( "LIGHTHOUSE_SHARED_LIBS" , default = "" ). split (": " )
158+ mlir_libs = []
159+ if args . shared_libs :
160+ mlir_libs += args . shared_libs . split (", " )
163161
164162 # JIT the kernel.
165163 eng = ExecutionEngine (kernel , opt_level = 2 , shared_libs = mlir_libs )
166164
167165 # Initialize the JIT engine.
168166 #
169- # The deferred initialization executes global constructors that might have been
170- # created by the module during engine creation (for example, when `gpu.module`
171- # is present) or registered afterwards.
167+ # The deferred initialization executes global constructors that might
168+ # have been created by the module during engine creation (for example,
169+ # when `gpu.module` is present) or registered afterwards.
172170 #
173171 # Initialization is not strictly necessary in this case.
174172 # However, it is a good practice to perform it regardless.
@@ -194,4 +192,21 @@ def main():
194192
195193
196194if __name__ == "__main__" :
197- main ()
195+ parser = argparse .ArgumentParser ()
196+
197+ # External shared libraries, runtime utilities, might be needed to
198+ # execute the compiled module.
199+ # For example, MLIR runner utils libraries such as:
200+ # - libmlir_runner_utils.so
201+ # - libmlir_c_runner_utils.so
202+ #
203+ # Full paths to the libraries should be provided.
204+ # For example:
205+ # --shared-libs=$LLVM_BUILD/lib/lib1.so,$LLVM_BUILD/lib/lib2.so
206+ parser .add_argument (
207+ "--shared-libs" ,
208+ type = str ,
209+ help = "Comma-separated list of libraries to link dynamically" ,
210+ )
211+ args = parser .parse_args ()
212+ main (args )
0 commit comments