Skip to content

Commit 7e92c48

Browse files
authored
[CI] Linter and formatter (#24)
Adds 'ruff' as Python linter and code formatter, and 'pre-commit' for easy rules application both locally and in CI. New lint CI workflow is also added. The formatting tools are added to the 'dev' dependency group. They are installed by default with `uv sync` and can be run locally with the following cmd: pre-commit run -a
1 parent 9717077 commit 7e92c48

File tree

11 files changed

+116
-36
lines changed

11 files changed

+116
-36
lines changed

.github/workflows/lint.yml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
name: Lint
2+
3+
on:
4+
push:
5+
pull_request:
6+
7+
jobs:
8+
Lint:
9+
runs-on: ubuntu-latest
10+
steps:
11+
- uses: actions/checkout@v5
12+
13+
- name: Install uv
14+
uses: astral-sh/setup-uv@v6
15+
16+
- name: Install the project and its dependencies
17+
run: |
18+
uv sync
19+
20+
- name: Run pre-commit
21+
run: |-
22+
uv run pre-commit run --all-files

.pre-commit-config.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
repos:
2+
- repo: https://github.com/astral-sh/ruff-pre-commit
3+
# Ruff version.
4+
rev: v0.14.5
5+
hooks:
6+
# Run the linter.
7+
- id: ruff-check
8+
args: [ --fix ]
9+
# Run the formatter.
10+
- id: ruff-format

ingress/mlir-gen/mlir_gen/main.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,9 @@ def weights(
136136
assert k_as_num_inputs % block.k == 0, "invalid tile size for K dim"
137137
assert n_as_num_outputs % block.n == 0, "invalid tile size for N dim"
138138
if block.vnni:
139-
assert (
140-
block.n % block.vnni == 0
141-
), "incompatible tile sizes for N and VNNI dims"
139+
assert block.n % block.vnni == 0, (
140+
"incompatible tile sizes for N and VNNI dims"
141+
)
142142
shape = (
143143
n_as_num_outputs // block.n,
144144
k_as_num_inputs // block.k,

pyproject.toml

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ dependencies = [
88

99
[dependency-groups]
1010
dev = [
11-
"lit==18.1.8" # Tool to configure, discover and run tests
11+
"lit==18.1.8", # Tool to configure, discover and run tests
12+
"ruff==0.14.5", # Python linter and formatter
13+
"pre-commit", # Tool to manage and apply pre-commit hooks
1214
]
1315

1416
[project.optional-dependencies]
@@ -82,3 +84,33 @@ include = ["lighthouse*"]
8284

8385
[tool.setuptools.dynamic]
8486
version = {attr = "lighthouse.__version__"}
87+
88+
[tool.ruff]
89+
src = ["lighthouse"]
90+
target-version = "py310"
91+
line-length = 88
92+
93+
[tool.ruff.format]
94+
docstring-code-format = true
95+
quote-style = "double"
96+
97+
# List of rules:
98+
# https://docs.astral.sh/ruff/rules/
99+
[tool.ruff.lint]
100+
select = [
101+
"D419", # empty-docstring
102+
"E", # Error
103+
"F", # Pyflakes
104+
"PERF", # Perflint
105+
"RUF022", # __all__ is not sorted
106+
"RUF030", # print() call in assert
107+
"RUF034", # useless if-else
108+
"RUF047", # empty else
109+
"RUF200", # invalid pyproject.toml
110+
"W", # Warning
111+
]
112+
ignore = [
113+
"E501", # line-too-long
114+
"PERF203", # try-except-in-loop
115+
"PERF401", # manual-list-comprehension
116+
]

python/examples/ingress/torch/MLPModel/model.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,11 @@
33
import torch
44
import torch.nn as nn
55

6-
import os
76

87
class MLPModel(nn.Module):
98
def __init__(self):
109
super().__init__()
11-
self.net = nn.Sequential(
12-
nn.Linear(10, 32),
13-
nn.ReLU(),
14-
nn.Linear(32, 2)
15-
)
10+
self.net = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
1611

1712
def forward(self, x):
1813
return self.net(x)

python/examples/ingress/torch/mlp_from_file.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,16 @@
3434
# - Loads the MLPModel class and instantiates it with arguments obtained from 'get_init_inputs()'
3535
# - Calls get_sample_inputs() to get sample input tensors for shape inference
3636
# - Converts PyTorch model to linalg-on-tensors dialect operations using torch_mlir
37+
# fmt: off
3738
mlir_module_ir: ir.Module = import_from_file(
3839
model_path, # Path to the Python file containing the model
3940
model_class_name="MLPModel", # Name of the PyTorch nn.Module class to convert
4041
init_args_fn_name="get_init_inputs", # Function that returns args for model.__init__()
4142
sample_args_fn_name="get_sample_inputs", # Function that returns sample inputs to pass to 'model(...)'
4243
dialect="linalg-on-tensors", # Target MLIR dialect (linalg ops on tensor types)
43-
ir_context=ir_context # MLIR context for the conversion
44+
ir_context=ir_context, # MLIR context for the conversion
4445
)
46+
# fmt: on
4547

4648
# The PyTorch model is now converted to MLIR at this point. You can now convert
4749
# the MLIR module to a text form (e.g. 'str(mlir_module_ir)') and save it to a file.

python/examples/ingress/torch/mlp_from_model.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,7 @@
3131
ir_context = ir.Context()
3232
# Step 2: Convert the PyTorch model to MLIR
3333
mlir_module_ir: ir.Module = import_from_model(
34-
model,
35-
sample_args=(sample_input,),
36-
ir_context=ir_context
34+
model, sample_args=(sample_input,), ir_context=ir_context
3735
)
3836

3937
# The PyTorch model is now converted to MLIR at this point. You can now convert
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
11
"""Provides functions to convert PyTorch models to MLIR."""
22

33
from .importer import import_from_file, import_from_model
4+
5+
__all__ = [
6+
"import_from_file",
7+
"import_from_model",
8+
]

python/lighthouse/ingress/torch/importer.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
from pathlib import Path
44
from typing import Iterable, Mapping
55

6-
from lighthouse.ingress.torch.utils import load_and_run_callable, maybe_load_and_run_callable
6+
from lighthouse.ingress.torch.utils import (
7+
load_and_run_callable,
8+
maybe_load_and_run_callable,
9+
)
710

811
try:
912
import torch
@@ -25,6 +28,7 @@
2528

2629
from mlir import ir
2730

31+
2832
def import_from_model(
2933
model: nn.Module,
3034
sample_args: Iterable,
@@ -49,10 +53,10 @@ def import_from_model(
4953
ir_context (ir.Context, optional): An optional MLIR context to use for parsing
5054
the module. If not provided, the module is returned as a string.
5155
**kwargs: Additional keyword arguments passed to the ``torch_mlir.fx.export_and_import`` function.
52-
56+
5357
Returns:
5458
str | ir.Module: The imported MLIR module as a string or an ir.Module if `ir_context` is provided.
55-
59+
5660
Examples:
5761
>>> import torch
5862
>>> import torch.nn as nn
@@ -61,17 +65,22 @@ def import_from_model(
6165
... def __init__(self):
6266
... super().__init__()
6367
... self.fc = nn.Linear(10, 5)
68+
...
6469
... def forward(self, x):
6570
... return self.fc(x)
6671
>>> model = SimpleModel()
6772
>>> sample_input = (torch.randn(1, 10),)
6873
>>> #
6974
>>> # option 1: get MLIR module as a string
70-
>>> mlir_module : str = import_from_model(model, sample_input, dialect="linalg-on-tensors")
71-
>>> print(mlir_module) # prints the MLIR module in linalg-on-tensors dialect
75+
>>> mlir_module: str = import_from_model(
76+
... model, sample_input, dialect="linalg-on-tensors"
77+
... )
78+
>>> print(mlir_module) # prints the MLIR module in linalg-on-tensors dialect
7279
>>> # option 2: get MLIR module as an ir.Module
7380
>>> ir_context = ir.Context()
74-
>>> mlir_module_ir : ir.Module = import_from_model(model, sample_input, dialect="tosa", ir_context=ir_context)
81+
>>> mlir_module_ir: ir.Module = import_from_model(
82+
... model, sample_input, dialect="tosa", ir_context=ir_context
83+
... )
7584
"""
7685
if dialect == "linalg":
7786
raise ValueError(
@@ -134,45 +143,48 @@ def import_from_file(
134143
ir_context (ir.Context, optional): An optional MLIR context to use for parsing
135144
the module. If not provided, the module is returned as a string.
136145
**kwargs: Additional keyword arguments passed to the ``torch_mlir.fx.export_and_import`` function.
137-
146+
138147
Returns:
139148
str | ir.Module: The imported MLIR module as a string or an ir.Module if `ir_context` is provided.
140-
149+
141150
Examples:
142151
Given a file `path/to/model_file.py` with the following content:
143152
```python
144153
import torch
145154
import torch.nn as nn
146155
156+
147157
class MyModel(nn.Module):
148158
def __init__(self):
149159
super().__init__()
150160
self.fc = nn.Linear(10, 5)
161+
151162
def forward(self, x):
152163
return self.fc(x)
153164
165+
154166
def get_inputs():
155167
return (torch.randn(1, 10),)
156168
```
157169
158170
The import script would look like:
159171
>>> from lighthouse.ingress.torch_import import import_from_file
160172
>>> # option 1: get MLIR module as a string
161-
>>> mlir_module : str = import_from_file(
173+
>>> mlir_module: str = import_from_file(
162174
... "path/to/model_file.py",
163175
... model_class_name="MyModel",
164176
... init_args_fn_name=None,
165-
... dialect="linalg-on-tensors"
177+
... dialect="linalg-on-tensors",
166178
... )
167-
>>> print(mlir_module) # prints the MLIR module in linalg-on-tensors dialect
179+
>>> print(mlir_module) # prints the MLIR module in linalg-on-tensors dialect
168180
>>> # option 2: get MLIR module as an ir.Module
169181
>>> ir_context = ir.Context()
170-
>>> mlir_module_ir : ir.Module = import_from_file(
182+
>>> mlir_module_ir: ir.Module = import_from_file(
171183
... "path/to/model_file.py",
172184
... model_class_name="MyModel",
173185
... init_args_fn_name=None,
174186
... dialect="linalg-on-tensors",
175-
... ir_context=ir_context
187+
... ir_context=ir_context,
176188
... )
177189
"""
178190
if isinstance(filepath, str):
@@ -191,24 +203,24 @@ def get_inputs():
191203
module,
192204
init_args_fn_name,
193205
default=tuple(),
194-
error_msg=f"Init args function '{init_args_fn_name}' not found in {filepath}"
206+
error_msg=f"Init args function '{init_args_fn_name}' not found in {filepath}",
195207
)
196208
model_init_kwargs = maybe_load_and_run_callable(
197209
module,
198210
init_kwargs_fn_name,
199211
default={},
200-
error_msg=f"Init kwargs function '{init_kwargs_fn_name}' not found in {filepath}"
212+
error_msg=f"Init kwargs function '{init_kwargs_fn_name}' not found in {filepath}",
201213
)
202214
sample_args = load_and_run_callable(
203215
module,
204216
sample_args_fn_name,
205-
f"Sample args function '{sample_args_fn_name}' not found in {filepath}"
217+
f"Sample args function '{sample_args_fn_name}' not found in {filepath}",
206218
)
207219
sample_kwargs = maybe_load_and_run_callable(
208220
module,
209221
sample_kwargs_fn_name,
210222
default={},
211-
error_msg=f"Sample kwargs function '{sample_kwargs_fn_name}' not found in {filepath}"
223+
error_msg=f"Sample kwargs function '{sample_kwargs_fn_name}' not found in {filepath}",
212224
)
213225

214226
nn_model: nn.Module = model(*model_init_args, **model_init_kwargs)

python/lighthouse/ingress/torch/utils.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,4 @@ def maybe_load_and_run_callable(
4343
"""
4444
if symbol_name is None:
4545
return default
46-
return load_and_run_callable(
47-
module,
48-
symbol_name,
49-
error_msg=error_msg
50-
)
46+
return load_and_run_callable(module, symbol_name, error_msg=error_msg)

0 commit comments

Comments
 (0)