33from pathlib import Path
44from typing import Iterable , Mapping
55
6- from lighthouse .ingress .torch .utils import load_and_run_callable , maybe_load_and_run_callable
6+ from lighthouse .ingress .torch .utils import (
7+ load_and_run_callable ,
8+ maybe_load_and_run_callable ,
9+ )
710
811try :
912 import torch
2528
2629from mlir import ir
2730
31+
2832def import_from_model (
2933 model : nn .Module ,
3034 sample_args : Iterable ,
@@ -49,10 +53,10 @@ def import_from_model(
4953 ir_context (ir.Context, optional): An optional MLIR context to use for parsing
5054 the module. If not provided, the module is returned as a string.
5155 **kwargs: Additional keyword arguments passed to the ``torch_mlir.fx.export_and_import`` function.
52-
56+
5357 Returns:
5458 str | ir.Module: The imported MLIR module as a string or an ir.Module if `ir_context` is provided.
55-
59+
5660 Examples:
5761 >>> import torch
5862 >>> import torch.nn as nn
@@ -61,17 +65,22 @@ def import_from_model(
6165 ... def __init__(self):
6266 ... super().__init__()
6367 ... self.fc = nn.Linear(10, 5)
68+ ...
6469 ... def forward(self, x):
6570 ... return self.fc(x)
6671 >>> model = SimpleModel()
6772 >>> sample_input = (torch.randn(1, 10),)
6873 >>> #
6974 >>> # option 1: get MLIR module as a string
70- >>> mlir_module : str = import_from_model(model, sample_input, dialect="linalg-on-tensors")
71- >>> print(mlir_module) # prints the MLIR module in linalg-on-tensors dialect
75+ >>> mlir_module: str = import_from_model(
76+ ... model, sample_input, dialect="linalg-on-tensors"
77+ ... )
78+ >>> print(mlir_module) # prints the MLIR module in linalg-on-tensors dialect
7279 >>> # option 2: get MLIR module as an ir.Module
7380 >>> ir_context = ir.Context()
74- >>> mlir_module_ir : ir.Module = import_from_model(model, sample_input, dialect="tosa", ir_context=ir_context)
81+ >>> mlir_module_ir: ir.Module = import_from_model(
82+ ... model, sample_input, dialect="tosa", ir_context=ir_context
83+ ... )
7584 """
7685 if dialect == "linalg" :
7786 raise ValueError (
@@ -134,45 +143,48 @@ def import_from_file(
134143 ir_context (ir.Context, optional): An optional MLIR context to use for parsing
135144 the module. If not provided, the module is returned as a string.
136145 **kwargs: Additional keyword arguments passed to the ``torch_mlir.fx.export_and_import`` function.
137-
146+
138147 Returns:
139148 str | ir.Module: The imported MLIR module as a string or an ir.Module if `ir_context` is provided.
140-
149+
141150 Examples:
142151 Given a file `path/to/model_file.py` with the following content:
143152 ```python
144153 import torch
145154 import torch.nn as nn
146155
156+
147157 class MyModel(nn.Module):
148158 def __init__(self):
149159 super().__init__()
150160 self.fc = nn.Linear(10, 5)
161+
151162 def forward(self, x):
152163 return self.fc(x)
153164
165+
154166 def get_inputs():
155167 return (torch.randn(1, 10),)
156168 ```
157169
158170 The import script would look like:
159171 >>> from lighthouse.ingress.torch_import import import_from_file
160172 >>> # option 1: get MLIR module as a string
161- >>> mlir_module : str = import_from_file(
173+ >>> mlir_module: str = import_from_file(
162174 ... "path/to/model_file.py",
163175 ... model_class_name="MyModel",
164176 ... init_args_fn_name=None,
165- ... dialect="linalg-on-tensors"
177+ ... dialect="linalg-on-tensors",
166178 ... )
167- >>> print(mlir_module) # prints the MLIR module in linalg-on-tensors dialect
179+ >>> print(mlir_module) # prints the MLIR module in linalg-on-tensors dialect
168180 >>> # option 2: get MLIR module as an ir.Module
169181 >>> ir_context = ir.Context()
170- >>> mlir_module_ir : ir.Module = import_from_file(
182+ >>> mlir_module_ir: ir.Module = import_from_file(
171183 ... "path/to/model_file.py",
172184 ... model_class_name="MyModel",
173185 ... init_args_fn_name=None,
174186 ... dialect="linalg-on-tensors",
175- ... ir_context=ir_context
187+ ... ir_context=ir_context,
176188 ... )
177189 """
178190 if isinstance (filepath , str ):
@@ -191,24 +203,24 @@ def get_inputs():
191203 module ,
192204 init_args_fn_name ,
193205 default = tuple (),
194- error_msg = f"Init args function '{ init_args_fn_name } ' not found in { filepath } "
206+ error_msg = f"Init args function '{ init_args_fn_name } ' not found in { filepath } " ,
195207 )
196208 model_init_kwargs = maybe_load_and_run_callable (
197209 module ,
198210 init_kwargs_fn_name ,
199211 default = {},
200- error_msg = f"Init kwargs function '{ init_kwargs_fn_name } ' not found in { filepath } "
212+ error_msg = f"Init kwargs function '{ init_kwargs_fn_name } ' not found in { filepath } " ,
201213 )
202214 sample_args = load_and_run_callable (
203215 module ,
204216 sample_args_fn_name ,
205- f"Sample args function '{ sample_args_fn_name } ' not found in { filepath } "
217+ f"Sample args function '{ sample_args_fn_name } ' not found in { filepath } " ,
206218 )
207219 sample_kwargs = maybe_load_and_run_callable (
208220 module ,
209221 sample_kwargs_fn_name ,
210222 default = {},
211- error_msg = f"Sample kwargs function '{ sample_kwargs_fn_name } ' not found in { filepath } "
223+ error_msg = f"Sample kwargs function '{ sample_kwargs_fn_name } ' not found in { filepath } " ,
212224 )
213225
214226 nn_model : nn .Module = model (* model_init_args , ** model_init_kwargs )
0 commit comments