
Commit 9301123

Merge pull request #58 from chichun-charlie-liu/triton-kernel
Add FP/INT triton kernels and unit tests, also update QAT example
2 parents 38d4b36 + 59dfc8b commit 9301123

File tree

9 files changed (+730 −35 lines)


.spellcheck-en-custom.txt

Lines changed: 2 additions & 0 deletions
@@ -38,6 +38,7 @@ Inductor
 inferenced
 inferencing
 isort
+JIT
 Jupyter
 Kubernetes
 KV
@@ -105,6 +106,7 @@ Tokenized
 tokenizer
 Tokenizer
 toml
+triton
 Unquantized
 vals
 venv

examples/QAT_INT8/README.md

Lines changed: 9 additions & 4 deletions
@@ -87,16 +87,16 @@ python run_qa_no_trainer_qat.py \
 --max_seq_length 384 \
 --doc_stride 128 \
 --attn_impl eager \
---do_lowering
+--do_lowering <cutlass or triton>
 ```

-This script uses an "external kernel" instead of the `torch.matmul` kernel to perform real `INT8` matmuls. This kernel is written for Nvidia's CUDA/CUTLASS library and is compiled once just ahead of the run. The compiled artifacts are usually stored in `~/.cache/torch_extensions/`. Remove this folder if a fresh recompile of the kernel is needed.
+This script uses an "external kernel" instead of the `torch.matmul` kernel to perform real `INT8` matmuls. We have two options for the INT kernel: one is written using Nvidia's CUDA/CUTLASS library and one is written in Triton. Both are compiled once just ahead of the run (i.e., just-in-time, or JIT, compilation). The compiled artifacts are usually stored in `~/.cache/torch_extensions/`. Remove this folder if a fresh recompile of the kernel is needed.

 Checkout [Example Test Results](#example-test-results) to compare against your results.

 ## Example Test Results

-For comparison purposes, here are some of the results we found during testing when tested with `PyTorch 2.3.1`:
+For comparison purposes, here are some of the results from an A100. CUTLASS results were obtained with `PyTorch 2.3.1`, while Triton results were obtained with `PyTorch 2.4.1`:

 > [!NOTE]
 > Accuracy could vary ~ +-0.2 from run to run.
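As a rough illustration of the JIT compilation step described above, an external CUDA/CUTLASS kernel is typically built on first use via `torch.utils.cpp_extension.load`. The extension name and source files below are hypothetical placeholders, not the repo's actual build code:

```python
# Minimal sketch of JIT-compiling an external CUDA/CUTLASS kernel on first run.
# The extension name and source files are hypothetical; fms_mo's real build code may differ.
from torch.utils.cpp_extension import load

int8_ext = load(
    name="int8_matmul_ext",                                # hypothetical extension name
    sources=["int8_matmul.cpp", "int8_matmul_kernel.cu"],  # hypothetical source files
    extra_cuda_cflags=["-O3"],
    verbose=True,  # prints compiler output during the one-time build
)
# The compiled artifacts are cached under ~/.cache/torch_extensions/;
# deleting that folder forces a fresh recompile on the next run.
```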
@@ -106,16 +106,21 @@ For comparison purposes, here are some of the results we found during testing when tested with `PyTorch 2.3.1`:
 |fp16|128|eager |88.21 (as fine-tuned) |126.38|
 | |128|Inductor | |71.59|
 | |128|CUDAGRAPH | |71.13|
-|INT8|128|eager |88.33|329.45 <sup>1</sup>|
+|INT8 CUTLASS|128|eager |88.33|329.45 <sup>1</sup>|
 | |128|Inductor |88.42|67.87 <sup>2</sup>|
 | |128|CUDAGRAPH |-- |-- <sup>3</sup>|
+|INT8 triton|128|eager |88.10|358.51|
+| |128|Inductor |88.13|99.91 <sup>4</sup>|
+| |128|CUDAGRAPH |88.13|100.21 <sup>4</sup>|

 <sup>1</sup> `INT8` matmuls are ~2x faster than `FP16` matmuls. However, `INT8` models will have additional overhead compared to `FP16` models. For example, converting FP tensors to INT before INT matmul.

 <sup>2</sup> Each of these additional quantization operations is relatively 'cheap', but the overhead of launching each job is not negligible. Using `torch.compile` can fuse the Ops and reduce the total number of jobs being launched.

 <sup>3</sup> `CUDAGRAPH` is the most effective way to minimize job launching overheads and can achieve ~2X end-to-end speed-up in this case. However, there seem to be bugs associated with this option at the moment. Further investigation is still on-going.

+<sup>4</sup> Unlike our CUTLASS `INT8` kernel, which is ~2x faster than `FP16` matmul, our Triton `INT8` kernel is not as optimized and performs only comparably to `FP16` on mid-to-large tensor sizes.
+
 ## Code Walk-through

 In this section, we will deep dive into what happens during the example steps.
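Footnotes 2 and 3 above refer to op fusion via `torch.compile` and to CUDA graphs. A minimal sketch of how the Inductor and CUDAGRAPH rows are typically produced, assuming `model` is the lowered INT8 model (the example script's exact compile settings may differ):

```python
# Minimal sketch, assuming `model` is the lowered INT8 model benchmarked in the table above.
import torch

model_inductor = torch.compile(model, backend="inductor")       # fuses the small quantize/dequantize ops
model_cudagraph = torch.compile(model, mode="reduce-overhead")  # uses CUDA graphs to cut kernel-launch overhead
```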

examples/QAT_INT8/run_qa_no_trainer_qat.py

Lines changed: 11 additions & 4 deletions
@@ -388,8 +388,10 @@ def parse_args():
     )
     parser.add_argument(
         "--do_lowering",
-        action="store_true",
-        help="convert QAT model to utilize real INT8 GPU kernel",
+        choices=["cutlass", "triton"],
+        type=str,
+        default="triton",
+        help="convert QAT model to utilize real INT8 GPU kernel, 'cutlass' or 'triton'",
     )

     args = parser.parse_args()
@@ -1086,7 +1088,7 @@ def squad_eval(model, keep_model_in_eval_mode=True):
 qmodel_prep(model, exam_inp, qcfg, optimizer, use_dynamo=True)

 # ---- [fms_mo] the following code are performing speed tests ----
-elif args.do_lowering:
+elif args.do_lowering in ["cutlass", "triton"]:
 # Standard
 from copy import deepcopy
 import time
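The hunk headers reference `speedtest(model, exam_inp, Ntest=100)`. A hedged sketch of what a latency test with that signature usually looks like, using CUDA events; it assumes `exam_inp` is a dict of input tensors and is not the script's actual implementation:

```python
import torch

def speedtest(model, exam_inp, Ntest=100):
    """Average forward latency in ms; illustrative only, not the repo's implementation."""
    model.eval()
    with torch.no_grad():
        for _ in range(10):                      # warm-up (also triggers JIT/compile work)
            model(**exam_inp)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(Ntest):
            model(**exam_inp)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / Ntest
```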
@@ -1158,7 +1160,11 @@ def speedtest(model, exam_inp, Ntest=100):
 parent_mod = model_copy.get_submodule(parent_name)
 qmod = getattr(parent_mod, module_name)
 setattr(
-    parent_mod, module_name, QLinearINT8Deploy.from_fms_mo(qmod)
+    parent_mod,
+    module_name,
+    QLinearINT8Deploy.from_fms_mo(
+        qmod, use_int_kernel=args.do_lowering
+    ),
 )

 if comp_mode is not False:
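The hunk above swaps each QAT linear module for a `QLinearINT8Deploy` via `get_submodule`/`setattr`. A hedged reconstruction of the enclosing loop: only the `setattr` call and the `from_fms_mo(..., use_int_kernel=...)` signature are taken from the diff, while `QatLinear` is a stand-in name for the repo's QAT linear class and the loop structure is an assumed, typical pattern:

```python
# Collect swap targets first so the setattr calls do not mutate the iterator.
targets = [
    (name, mod)
    for name, mod in model_copy.named_modules()
    if isinstance(mod, QatLinear)  # stand-in for the actual fms_mo QAT linear class
]
for full_name, qmod in targets:
    parent_name, _, module_name = full_name.rpartition(".")
    parent_mod = model_copy.get_submodule(parent_name)  # "" returns model_copy itself
    setattr(
        parent_mod,
        module_name,
        QLinearINT8Deploy.from_fms_mo(qmod, use_int_kernel=args.do_lowering),
    )
```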
@@ -1385,6 +1391,7 @@ def speedtest(model, exam_inp, Ntest=100):
 )
 logger.info(f"Predict metrics: {predict_metric}")

+log = {}
 if args.with_tracking:
     log = {
         "squad_v2" if args.version_2_with_negative else "squad": eval_metric,
