Skip to content

Commit 8e82aba

Browse files
committed
Add Ruff for lint and code formatting
Replaces previous shell script which called formatting using `precommit" command. Ruff ia a tool which manages different formatters and also flake8 linting. Signed-off-by: Martin Hickey <[email protected]>
1 parent 6b702d1 commit 8e82aba

38 files changed

+333
-352
lines changed

.github/workflows/lint.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ jobs:
4040
fail-fast: false
4141
matrix:
4242
lint:
43-
- name: "fmt"
43+
- name: "ruff"
4444
commands: |
45-
tox -e fmt
45+
tox -e ruff
4646
- name: "pylint"
4747
commands: |
4848
echo "::add-matcher::.github/workflows/matchers/pylint.json"

.isort.cfg

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
[settings]
2-
profile=black
3-
from_first=true
42
import_heading_future=Future
53
import_heading_stdlib=Standard
64
import_heading_thirdparty=Third Party
75
import_heading_firstparty=First Party
86
import_heading_localfolder=Local
9-
known_firstparty=
10-
known_localfolder=fms_mo,tests
117
extend_skip=fms_mo/_version.py

.pre-commit-config.yaml

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
repos:
2-
- repo: https://github.com/psf/black
3-
rev: 22.3.0
4-
hooks:
5-
- id: black
6-
exclude: imports
72
- repo: https://github.com/PyCQA/isort
83
rev: 5.11.5
94
hooks:
105
- id: isort
116
exclude: imports
7+
- repo: https://github.com/astral-sh/ruff-pre-commit
8+
# Ruff version.
9+
rev: v0.5.0
10+
hooks:
11+
# Run the linter (most fixers are disabled for now).
12+
- id: ruff
13+
# Run the formatter.
14+
- id: ruff-format
15+

examples/PTQ_INT8/run_qa_no_trainer_ptq.py

Lines changed: 20 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,43 +19,35 @@
1919
# You can also adapt this script on your own question answering task. Pointers for this are left as comments.
2020

2121
# Standard
22-
from pathlib import Path
2322
import argparse
2423
import json
2524
import logging
2625
import math
2726
import os
2827
import random
2928
import time
29+
from pathlib import Path
3030

3131
# Third Party
32+
import datasets
33+
import evaluate
34+
import numpy as np
35+
import torch
36+
import transformers
3237
from accelerate import Accelerator
3338
from accelerate.logging import get_logger
3439
from accelerate.utils import set_seed
3540
from datasets import load_dataset
3641
from huggingface_hub import HfApi
3742
from torch.utils.data import DataLoader
3843
from tqdm.auto import tqdm
39-
from transformers import (
40-
CONFIG_MAPPING,
41-
MODEL_MAPPING,
42-
AutoConfig,
43-
AutoModelForQuestionAnswering,
44-
AutoTokenizer,
45-
DataCollatorWithPadding,
46-
EvalPrediction,
47-
SchedulerType,
48-
default_data_collator,
49-
get_scheduler,
50-
)
44+
from transformers import (CONFIG_MAPPING, MODEL_MAPPING, AutoConfig,
45+
AutoModelForQuestionAnswering, AutoTokenizer,
46+
DataCollatorWithPadding, EvalPrediction,
47+
SchedulerType, default_data_collator, get_scheduler)
5148
from transformers.utils import check_min_version, send_example_telemetry
5249
from transformers.utils.versions import require_version
5350
from utils_qa import postprocess_qa_predictions
54-
import datasets
55-
import evaluate
56-
import numpy as np
57-
import torch
58-
import transformers
5951

6052
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
6153
check_min_version("4.39.0.dev0")
@@ -1122,11 +1114,10 @@ def squad_eval(model, keep_model_in_eval_mode=True):
11221114
return eval_metric
11231115

11241116
# ---- [fms_mo] the following code are added for qat/ptq ----
1125-
# Local
1117+
# First Party
11261118
from fms_mo import qconfig_init, qmodel_prep
11271119

11281120
if args.do_qat:
1129-
11301121
# create a config dict, if same item exists in both recipe and args, args has the priority.
11311122
qcfg = qconfig_init(recipe="qat_int8", args=args)
11321123

@@ -1141,8 +1132,7 @@ def squad_eval(model, keep_model_in_eval_mode=True):
11411132
qmodel_prep(model, exam_inp, qcfg, optimizer, use_dynamo=True)
11421133

11431134
if args.do_ptq:
1144-
1145-
# Local
1135+
# First Party
11461136
from fms_mo.quant.ptq import calib_PTQ_lm
11471137

11481138
# create a config dict, if same item exists in both recipe and args, args has the priority.
@@ -1177,10 +1167,10 @@ def squad_eval(model, keep_model_in_eval_mode=True):
11771167
from copy import deepcopy
11781168

11791169
# Third Party
1180-
from torch.ao.quantization.utils import _parent_name
11811170
import pandas as pd
1171+
from torch.ao.quantization.utils import _parent_name
11821172

1183-
# Local
1173+
# First Party
11841174
from fms_mo.modules.linear import QLinear, QLinearINT8Deploy
11851175

11861176
def speedtest(model, exam_inp, Ntest=100):
@@ -1216,17 +1206,16 @@ def speedtest(model, exam_inp, Ntest=100):
12161206
("int8", "ind"),
12171207
("int8", "cugr"),
12181208
]:
1219-
12201209
logger.info(
12211210
f"\n {label} {'with' if comp_mode else 'without'} torch.compile"
12221211
)
12231212
model_copy = deepcopy(model)
12241213

12251214
if label == "int8":
12261215
qcfg = qconfig_init(recipe="ptq_int8", args=args)
1227-
qcfg[
1228-
"qmodel_calibration"
1229-
] = 0 # no need to run calibration or trained scales will be lost.
1216+
qcfg["qmodel_calibration"] = (
1217+
0 # no need to run calibration or trained scales will be lost.
1218+
)
12301219
qmodel_prep(
12311220
model_copy,
12321221
exam_inp,
@@ -1479,9 +1468,9 @@ def speedtest(model, exam_inp, Ntest=100):
14791468
"step": completed_steps,
14801469
}
14811470
if args.do_predict:
1482-
log[
1483-
"squad_v2_predict" if args.version_2_with_negative else "squad_predict"
1484-
] = predict_metric
1471+
log["squad_v2_predict" if args.version_2_with_negative else "squad_predict"] = (
1472+
predict_metric
1473+
)
14851474

14861475
accelerator.log(log, step=completed_steps)
14871476

examples/PTQ_INT8/utils_qa.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,17 @@
1515
"""
1616
Post-processing utilities for question answering.
1717
"""
18+
1819
# Standard
19-
from typing import Optional, Tuple
2020
import collections
2121
import json
2222
import logging
2323
import os
24+
from typing import Optional, Tuple
2425

2526
# Third Party
26-
from tqdm.auto import tqdm
2727
import numpy as np
28+
from tqdm.auto import tqdm
2829

2930
logger = logging.getLogger(__name__)
3031

examples/QAT_INT8/run_qa_no_trainer_qat.py

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,42 +19,34 @@
1919
# You can also adapt this script on your own question answering task. Pointers for this are left as comments.
2020

2121
# Standard
22-
from pathlib import Path
2322
import argparse
2423
import json
2524
import logging
2625
import math
2726
import os
2827
import random
28+
from pathlib import Path
2929

3030
# Third Party
31+
import datasets
32+
import evaluate
33+
import numpy as np
34+
import torch
35+
import transformers
3136
from accelerate import Accelerator
3237
from accelerate.logging import get_logger
3338
from accelerate.utils import set_seed
3439
from datasets import load_dataset
3540
from huggingface_hub import HfApi
3641
from torch.utils.data import DataLoader
3742
from tqdm.auto import tqdm
38-
from transformers import (
39-
CONFIG_MAPPING,
40-
MODEL_MAPPING,
41-
AutoConfig,
42-
AutoModelForQuestionAnswering,
43-
AutoTokenizer,
44-
DataCollatorWithPadding,
45-
EvalPrediction,
46-
SchedulerType,
47-
default_data_collator,
48-
get_scheduler,
49-
)
43+
from transformers import (CONFIG_MAPPING, MODEL_MAPPING, AutoConfig,
44+
AutoModelForQuestionAnswering, AutoTokenizer,
45+
DataCollatorWithPadding, EvalPrediction,
46+
SchedulerType, default_data_collator, get_scheduler)
5047
from transformers.utils import check_min_version, send_example_telemetry
5148
from transformers.utils.versions import require_version
5249
from utils_qa import postprocess_qa_predictions
53-
import datasets
54-
import evaluate
55-
import numpy as np
56-
import torch
57-
import transformers
5850

5951
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
6052
check_min_version("4.39.0.dev0")
@@ -1068,11 +1060,10 @@ def squad_eval(model, keep_model_in_eval_mode=True):
10681060
return eval_metric
10691061

10701062
# ---- [fms_mo] the following code are added for qat/ptq ----
1071-
# Local
1063+
# First Party
10721064
from fms_mo import qconfig_init, qmodel_prep
10731065

10741066
if args.do_qat:
1075-
10761067
# create a config dict, if same item exists in both recipe and args, args has the priority.
10771068
qcfg = qconfig_init(recipe="qat_int8", args=args)
10781069

@@ -1089,14 +1080,14 @@ def squad_eval(model, keep_model_in_eval_mode=True):
10891080
# ---- [fms_mo] the following code are performing speed tests ----
10901081
elif args.do_lowering:
10911082
# Standard
1092-
from copy import deepcopy
10931083
import time
1084+
from copy import deepcopy
10941085

10951086
# Third Party
1096-
from torch.ao.quantization.utils import _parent_name
10971087
import pandas as pd
1088+
from torch.ao.quantization.utils import _parent_name
10981089

1099-
# Local
1090+
# First Party
11001091
from fms_mo.modules.linear import QLinear, QLinearINT8Deploy
11011092

11021093
def speedtest(model, exam_inp, Ntest=100):
@@ -1132,17 +1123,16 @@ def speedtest(model, exam_inp, Ntest=100):
11321123
("int8", "ind"),
11331124
("int8", "cugr"),
11341125
]:
1135-
11361126
logger.info(
11371127
f"\n {label} {'with' if comp_mode else 'without'} torch.compile"
11381128
)
11391129
model_copy = deepcopy(model)
11401130

11411131
if label == "int8":
11421132
qcfg = qconfig_init(recipe="qat_int8", args=args)
1143-
qcfg[
1144-
"qmodel_calibration"
1145-
] = 0 # no need to run calibration or trained scales will be lost.
1133+
qcfg["qmodel_calibration"] = (
1134+
0 # no need to run calibration or trained scales will be lost.
1135+
)
11461136
qmodel_prep(
11471137
model_copy,
11481138
exam_inp,
@@ -1395,9 +1385,9 @@ def speedtest(model, exam_inp, Ntest=100):
13951385
"step": completed_steps,
13961386
}
13971387
if args.do_predict:
1398-
log[
1399-
"squad_v2_predict" if args.version_2_with_negative else "squad_predict"
1400-
] = predict_metric
1388+
log["squad_v2_predict" if args.version_2_with_negative else "squad_predict"] = (
1389+
predict_metric
1390+
)
14011391

14021392
accelerator.log(log, step=completed_steps)
14031393

examples/QAT_INT8/utils_qa.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,17 @@
1515
"""
1616
Post-processing utilities for question answering.
1717
"""
18+
1819
# Standard
19-
from typing import Optional, Tuple
2020
import collections
2121
import json
2222
import logging
2323
import os
24+
from typing import Optional, Tuple
2425

2526
# Third Party
26-
from tqdm.auto import tqdm
2727
import numpy as np
28+
from tqdm.auto import tqdm
2829

2930
logger = logging.getLogger(__name__)
3031

fms_mo/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,13 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
"""FMS Model Optimizer init. Import most commonly used functions and classes here.
15-
"""
14+
"""FMS Model Optimizer init. Import most commonly used functions and classes here."""
1615

1716
# Standard
18-
from importlib.metadata import PackageNotFoundError, version
1917
import logging
18+
from importlib.metadata import PackageNotFoundError, version
2019

21-
# Local
20+
# First Party
2221
from fms_mo.prep import qmodel_prep
2322
from fms_mo.utils.qconfig_utils import qconfig_init
2423

fms_mo/calib.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,20 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
"""Main user interfacing functions, such as qmodel_prep()
15-
16-
"""
14+
"""Main user interfacing functions, such as qmodel_prep()"""
1715

1816
# Standard
19-
from copy import deepcopy
20-
from typing import Callable, Tuple, Union
2117
import logging
2218
import sys
19+
from copy import deepcopy
20+
from typing import Callable, Tuple, Union
2321

2422
# Third Party
23+
import torch
2524
from torch import nn
2625
from transformers.tokenization_utils_base import BatchEncoding
27-
import torch
2826

29-
# Local
27+
# First Party
3028
from fms_mo.modules import QBmm, QConv2d, QConvTranspose2d, QLinear
3129
from fms_mo.utils.utils import prepare_data_4_fwd, prepare_inputs
3230

@@ -186,8 +184,9 @@ def __call__(self, module, inputs):
186184
for act, name in [(x, "input"), (hid[layer], "hidden")]:
187185
nelem = act.nelement()
188186
if self.a_init_method == "percentile":
189-
lower_k, upper_k = int(self.per[0] * nelem), int(
190-
self.per[1] * nelem
187+
lower_k, upper_k = (
188+
int(self.per[0] * nelem),
189+
int(self.per[1] * nelem),
191190
)
192191
lower_per_cur = (
193192
act.reshape(1, -1).kthvalue(lower_k).values.data[0]

0 commit comments

Comments
 (0)