
Commit 985886f

add fp/int triton kernels and unit tests
Signed-off-by: cliu-us <[email protected]>
1 parent e086d57 commit 985886f

File tree

7 files changed: +705 -22 lines

.gitignore

Lines changed: 1 addition & 0 deletions

````diff
@@ -38,3 +38,4 @@ fms_mo.log
 data_train/
 data_test/
 act_scales/
+examples/
````

examples/QAT_INT8/README.md

Lines changed: 9 additions & 4 deletions

````diff
@@ -87,16 +87,16 @@ python run_qa_no_trainer_qat.py \
   --max_seq_length 384 \
   --doc_stride 128 \
   --attn_impl eager \
-  --do_lowering
+  --do_lowering <cutlass or triton>
 ```

-This script uses an "external kernel" instead of the `torch.matmul` kernel to perform real `INT8` matmuls. This kernel is written for Nvidia's CUDA/CUTLASS library and is compiled once just ahead of the run. The compiled artifacts are usually stored in `~/.cache/torch_extensions/`. Remove this folder if a fresh recompile of the kernel is needed.
+This script uses an "external kernel" instead of the `torch.matmul` kernel to perform real `INT8` matmuls. Two INT kernels are available: one written with Nvidia's CUDA/CUTLASS library and one written in Triton. Both are compiled once just ahead of the run (i.e., just-in-time, JIT, compilation). The compiled artifacts are usually stored in `~/.cache/torch_extensions/`. Remove this folder if a fresh recompile of the kernels is needed.

 Checkout [Example Test Results](#example-test-results) to compare against your results.

 ## Example Test Results

-For comparison purposes, here are some of the results we found during testing when tested with `PyTorch 2.3.1`:
+For comparison purposes, here are some of the results from an A100. CUTLASS results were obtained with `PyTorch 2.3.1`, while Triton results were obtained with `PyTorch 2.4.1`:

 > [!NOTE]
 > Accuracy could vary ~ +-0.2 from run to run.
````
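As the paragraph added above notes, both external kernels are JIT-compiled and cached under `~/.cache/torch_extensions/`. A minimal sketch (not part of this commit) for forcing a fresh recompile, assuming the default cache location; the standard `TORCH_EXTENSIONS_DIR` environment variable, if set, overrides it:

```python
# Clear the torch extension build cache so the next run JIT-recompiles the
# external CUTLASS/Triton kernels. The default path follows the README above.
import os
import shutil

cache_dir = os.environ.get(
    "TORCH_EXTENSIONS_DIR", os.path.expanduser("~/.cache/torch_extensions")
)
if os.path.isdir(cache_dir):
    shutil.rmtree(cache_dir)
```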
````diff
@@ -106,16 +106,21 @@ For comparison purposes, here are some of the results we found during testing wh
 |fp16|128|eager |88.21 (as fine-tuned) |126.38|
 | |128|Inductor | |71.59|
 | |128|CUDAGRAPH | |71.13|
-|INT8|128|eager |88.33|329.45 <sup>1</sup>|
+|INT8 CUTLASS|128|eager |88.33|329.45 <sup>1</sup>|
 | |128|Inductor |88.42|67.87 <sup>2</sup>|
 | |128|CUDAGRAPH |-- |-- <sup>3</sup>|
+|INT8 Triton|128|eager |88.10|358.51|
+| |128|Inductor |88.13|99.91 <sup>4</sup>|
+| |128|CUDAGRAPH |88.13|100.21 <sup>4</sup>|

 <sup>1</sup> `INT8` matmuls are ~2x faster than `FP16` matmuls. However, `INT8` models will have additional overhead compared to `FP16` models. For example, converting FP tensors to INT before INT matmul.

 <sup>2</sup> Each of these additional quantization operations is relatively 'cheap', but the overhead of launching each job is not negligible. Using `torch.compile` can fuse the Ops and reduce the total number of jobs being launched.

 <sup>3</sup> `CUDAGRAPH` is the most effective way to minimize job launching overheads and can achieve ~2X end-to-end speed-up in this case. However, there seem to be bugs associated with this option at the moment. Further investigation is still on-going.

+<sup>4</sup> Unlike our CUTLASS `INT8` kernel, which is ~2x faster than `FP16` matmul, our Triton `INT8` kernel is not as optimized and performs only comparably to `FP16` on mid-to-large tensor sizes.
+
 ## Code Walk-through

 In this section, we will deep dive into what happens during the example steps.
````
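Footnotes 1 and 2 above describe the extra quantize/dequantize ops an `INT8` model pays around every matmul in eager mode. A minimal illustrative sketch of that pattern, using PyTorch's private `torch._int_mm` as a stand-in for the commit's CUTLASS/Triton kernels (an assumption for illustration only):

```python
# Sketch of the per-matmul overhead from footnote 1: quantize inputs, run the
# INT8 matmul, dequantize the INT32 accumulator. In eager mode each step is a
# separate kernel launch; torch.compile can fuse them (footnote 2).
import torch

x = torch.randn(128, 512, device="cuda", dtype=torch.float16)
w = torch.randn(512, 256, device="cuda", dtype=torch.float16)

# Per-tensor symmetric quantization to INT8 (the "extra" ops).
sx = x.abs().max() / 127.0
sw = w.abs().max() / 127.0
xq = torch.clamp((x / sx).round(), -128, 127).to(torch.int8)
wq = torch.clamp((w / sw).round(), -128, 127).to(torch.int8)

acc = torch._int_mm(xq, wq)  # INT8 x INT8 -> INT32 (private PyTorch API)

# Dequantize in FP32: the INT32 accumulator can exceed FP16's max value.
y = acc.float() * (sx.float() * sw.float())
```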

examples/QAT_INT8/run_qa_no_trainer_qat.py

Lines changed: 10 additions & 4 deletions

````diff
@@ -388,8 +388,9 @@ def parse_args():
     )
     parser.add_argument(
         "--do_lowering",
-        action="store_true",
-        help="convert QAT model to utilize real INT8 GPU kernel",
+        type=str,
+        default=None,
+        help="convert QAT model to utilize real INT8 GPU kernel, 'cutlass' or 'triton'",
     )

     args = parser.parse_args()
@@ -1086,7 +1087,7 @@ def squad_eval(model, keep_model_in_eval_mode=True):
     qmodel_prep(model, exam_inp, qcfg, optimizer, use_dynamo=True)

     # ---- [fms_mo] the following code are performing speed tests ----
-    elif args.do_lowering:
+    elif args.do_lowering in ["cutlass", "triton"]:
         # Standard
         from copy import deepcopy
         import time
@@ -1158,7 +1159,11 @@ def speedtest(model, exam_inp, Ntest=100):
             parent_mod = model_copy.get_submodule(parent_name)
             qmod = getattr(parent_mod, module_name)
             setattr(
-                parent_mod, module_name, QLinearINT8Deploy.from_fms_mo(qmod)
+                parent_mod,
+                module_name,
+                QLinearINT8Deploy.from_fms_mo(
+                    qmod, useINTkernel=args.do_lowering
+                ),
             )

         if comp_mode is not False:
````
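The hunk above replaces each quantized Linear on its parent module via `get_submodule` and `setattr`. A hedged, generic sketch of that swap pattern; the predicate and factory below are illustrative stand-ins, while the real script passes `QLinearINT8Deploy.from_fms_mo(qmod, useINTkernel=args.do_lowering)`:

```python
# Generic module-swap helper: walk the model, find matching submodules, and
# rebind each one on its parent. Mirrors the setattr pattern in the diff.
import torch
from torch import nn

def swap_modules(model: nn.Module, should_swap, make_replacement):
    """Replace every submodule for which should_swap(mod) is True."""
    for name, mod in list(model.named_modules()):
        if name and should_swap(mod):
            parent_name, _, child_name = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, child_name, make_replacement(mod))

# Hypothetical usage, mirroring the script:
# swap_modules(model, lambda m: isinstance(m, nn.Linear),
#              lambda m: QLinearINT8Deploy.from_fms_mo(m, useINTkernel="triton"))
```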
````diff
@@ -1385,6 +1390,7 @@ def speedtest(model, exam_inp, Ntest=100):
         )
         logger.info(f"Predict metrics: {predict_metric}")

+    log = {}
     if args.with_tracking:
         log = {
             "squad_v2" if args.version_2_with_negative else "squad": eval_metric,
````
