
Commit 70348a8

Merge branch 'main' into unit_test_int8
Signed-off-by: andrea-fasoli <[email protected]>
2 parents: b29b371 + 608068d

6 files changed: +24 -24 lines

.github/workflows/labelpr.yaml

Lines changed: 1 addition & 5 deletions
@@ -12,13 +12,9 @@ jobs:
       with:
         github-token: ${{ secrets.GITHUB_TOKEN }}
         script: |
-          const pr_welcome_msg = `Thanks for making a pull request! 😃\nOne of the maintainers will review and advise on the next steps.`;
           // https://github.com/commitizen/conventional-commit-types
           const valid_pr_types = ['feat', 'fix', 'docs', 'style', 'refactor', 'perf', 'test', 'build', 'ci', 'chore', 'revert'];

-          if(context.payload.pull_request.comments === 0) {
-            await github.issues.createComment({ ...context.repo, issue_number: context.payload.number, body: pr_welcome_msg});
-          }

           const title = context.payload.pull_request.title;
           const results = /^(\w+)(\(\w+\))?!?:/.exec(title);
@@ -32,4 +28,4 @@ jobs:
           const labels = context.payload.pull_request.labels;
           const new_labels = labels.filter(label => !valid_pr_types.includes(label.name)); // keep all labels that are not in valid_pr_types
           new_labels.push({name: pr_type});
-          await github.issues.update({ ...context.repo, issue_number: context.payload.number, labels: new_labels });
+          await github.rest.issues.update({ ...context.repo, issue_number: context.payload.number, labels: new_labels });

examples/QAT_INT8/run_qa_no_trainer_qat.py

Lines changed: 3 additions & 3 deletions
@@ -390,7 +390,7 @@ def parse_args():
         "--do_lowering",
         choices=["cutlass", "triton"],
         type=str,
-        default="triton",
+        default=None,
         help="convert QAT model to utilize real INT8 GPU kernel, 'cutlass' or 'triton'",
     )

@@ -1136,7 +1136,7 @@ def speedtest(model, exam_inp, Ntest=100):
             logger.info(
                 f"\n {label} {'with' if comp_mode else 'without'} torch.compile"
             )
-            model_copy = deepcopy(model)
+            model_copy = deepcopy(model).half()

             if label == "int8":
                 qcfg = qconfig_init(recipe="qat_int8", args=args)
@@ -1178,7 +1178,7 @@ def speedtest(model, exam_inp, Ntest=100):

             # Median runtime using fixed input (in msec)
             med_runtime = speedtest(model_copy, exam_inp)
-            metrics = squad_eval(model_copy) if label == "int8" else {"f1": None}
+            metrics = squad_eval(model_copy) # if label == "int8" else {"f1": None}

             summary["precision"].append(label)
             summary["compile mode"].append(comp_mode)
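
For context on the `deepcopy(model).half()` change: each timed variant in the speed test now starts from a fresh FP16 copy of the model rather than the FP32 original. A minimal sketch of the copy-then-cast pattern (the helper name is illustrative, not part of the script):

```python
from copy import deepcopy

import torch


def fresh_fp16_copy(model: torch.nn.Module) -> torch.nn.Module:
    """Independent FP16 copy, so repeated timing/compile runs don't touch the original."""
    return deepcopy(model).half().eval()
```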

fms_mo/aiu_addons/__init__.py

Whitespace-only changes.

fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 3 additions & 2 deletions
@@ -235,6 +235,7 @@ def imatmul_kernel(
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
     ## ------ prepare LSB rounding/truncation masks -------
     round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
+    # msb_mask = 0x00FFFFFF # only needed when simulating truncation on MSB
     ## ---------------------------------------------------------

     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
@@ -326,7 +327,7 @@ def grid(META):
     kernel_config = {
         "BLOCK_SIZE_M": 128,
         "BLOCK_SIZE_K": chunk_size,
-        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_N": 128, # was 32
         "GROUP_SIZE_M": 8,
         "num_warps": 2,
         "num_stages": 5,
@@ -335,7 +336,7 @@ def grid(META):
     kernel_config = {
         "BLOCK_SIZE_M": 128,
         "BLOCK_SIZE_K": chunk_size,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128, # was 64
         "GROUP_SIZE_M": 8,
         "num_warps": 4,
         "num_stages": 4,

fms_mo/modules/linear.py

Lines changed: 14 additions & 11 deletions
@@ -752,7 +752,7 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
         qlin_int.max_acc_bits = kwargs.get("max_acc_bits", 32)
         qlin_int.accminmax = (
             -(1 << (qlin_int.max_acc_bits - 1)),
-            1 << (qlin_int.max_acc_bits - 1) - 1,
+            (1 << (qlin_int.max_acc_bits - 1)) - 1,
         )
         qlin_int.truncate_lsb = kwargs.get("truncate_lsb", 0)
         qlin_int.chunk_size = kwargs.get("chunk_size", 100000)
@@ -871,16 +871,16 @@ def from_torch_iW(cls, nnlin_iW, prec, a_cv, a_cvn, w_cv, zero_shift, **kwargs):

         qlinear_iW.nbits_a = 8 # Only support INT8 for now
         qlinear_iW.nbits_w = 8
-        qlinear_iW.acc_dtype = torch.float16
+        qlinear_iW.acc_dtype = kwargs.get("acc_dtype", torch.float)
         qlinear_iW.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", True)
-        qlinear_iW.use_int_kernel = True
+        qlinear_iW.use_int_kernel = kwargs.get("use_int_kernel", "triton")
         qlinear_iW.weight = nn.Parameter(
             nnlin_iW.weight.to(torch.int8), requires_grad=False
         )
         qlinear_iW.max_acc_bits = kwargs.get("max_acc_bits", 32)
         qlinear_iW.accminmax = (
             -(1 << (qlinear_iW.max_acc_bits - 1)),
-            1 << (qlinear_iW.max_acc_bits - 1) - 1,
+            (1 << (qlinear_iW.max_acc_bits - 1)) - 1,
         )
         qlinear_iW.truncate_lsb = kwargs.get("truncate_lsb", False)
         qlinear_iW.chunk_size = kwargs.get("chunk_size", 100000)
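
The `accminmax` fix in both hunks corrects an operator-precedence bug: in Python, `-` binds tighter than `<<`, so `1 << (bits - 1) - 1` shifts by `bits - 2` instead of producing the intended signed maximum. A standalone illustration (not repo code):

```python
max_acc_bits = 32

old_max = 1 << (max_acc_bits - 1) - 1    # parsed as 1 << ((max_acc_bits - 1) - 1)
new_max = (1 << (max_acc_bits - 1)) - 1  # intended signed 32-bit maximum

print(old_max)  # 1073741824 (2**30, too small)
print(new_max)  # 2147483647 (2**31 - 1, correct)
```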
@@ -1027,11 +1027,11 @@ def iaddmm_int(self, bias, m1, m2):
         else:
             m1 = self.qa_fmo_mo_qfunc(m1)

-        if m1.shape[1] > self.chunk_size:
+        if m1.shape[1] > self.chunk_size and self.use_int_kernel != "triton":
             idx = list(range(0, m1.shape[1], self.chunk_size))
             Nchunk = len(idx)
             idx.append(m1.shape[1])
-            fp16_out = torch.zeros(
+            accumulator = torch.zeros(
                 (m1.shape[0], m2.shape[1]), dtype=torch.float16, device=m1.device
             )
             trun_scale = 1
@@ -1052,11 +1052,11 @@ def iaddmm_int(self, bias, m1, m2):
             # could cast to smaller data type to further simulate HW behavior, for example,
             # if HW truncates 8b from both sides of i32 accumulator, the remaining data can
             # be cast to i16 to be more realistic. pay attention to overflow handling
-            fp16_out += imm_out.to(torch.float16)
+            accumulator += imm_out.to(torch.float16)

         return (
-            fp16_out
-            * (trun_scale * self.input_scale * self.w_scale).to(torch.float16)
+            accumulator
+            * (trun_scale * self.input_scale * self.w_scale) # .to(torch.float16)
             + bias
         ).to(self.acc_dtype)
         # The safest casting, i32 -> f32
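
The branch touched above is the chunked fallback path: when the inner dimension exceeds `chunk_size` (and the Triton kernel is not in use), the INT8 matmul runs one K-chunk at a time and each int32 partial product is folded into an FP16 accumulator before rescaling. A simplified, self-contained sketch of that accumulation pattern (assumed shapes and names; plain torch ops stand in for the custom integer kernel):

```python
import torch


def chunked_int8_matmul(m1: torch.Tensor, m2: torch.Tensor, chunk_size: int = 128) -> torch.Tensor:
    """m1: [M, K] int8, m2: [K, N] int8 (CPU tensors for this sketch)."""
    accumulator = torch.zeros(m1.shape[0], m2.shape[1], dtype=torch.float16)
    for k0 in range(0, m1.shape[1], chunk_size):
        k1 = min(k0 + chunk_size, m1.shape[1])
        # int32 partial product for this K-chunk (stand-in for the INT kernel call)
        imm_out = m1[:, k0:k1].to(torch.int32) @ m2[k0:k1, :].to(torch.int32)
        accumulator += imm_out.to(torch.float16)
    return accumulator
```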
@@ -1145,10 +1145,13 @@ def extra_repr(self) -> str:
         """
         Returns an alternative string representation of the object
         """
-        return (
+        repr_str = (
             f"in={self.in_features}, out={self.out_features}, bias={self.bias is not None}, "
-            f"use_int_kernel={self.use_int_kernel}"
+            f"int_kernel={self.use_int_kernel}"
         )
+        if self.truncate_lsb > 0 or self.max_acc_bits < 32:
+            repr_str += f", acc_bits={self.max_acc_bits}, trun_lsb={self.truncate_lsb}"
+        return repr_str

     def __getstate__(self):
         """

pyproject.toml

Lines changed: 3 additions & 3 deletions
@@ -25,18 +25,18 @@ dependencies = [
     "numpy>=1.26.4,<2.3.0",
     "accelerate>=0.20.3,!=0.34,<1.4",
     "transformers>=4.45,<4.49",
-    "torch>=2.2.0,<2.5",
+    "torch>=2.2.0,<2.5",
+    "triton>=3.0,<3.2",
     "tqdm>=4.66.2,<5.0",
     "datasets>=3.0.0,<4.0",
     "ninja>=1.11.1.1,<2.0",
     "tensorboard",
     "notebook",
-    "torchvision>=0.8",
+    "torchvision>=0.17",
     "evaluate",
     "huggingface_hub",
     "pandas",
     "safetensors",
-    "ninja",
     "ibm-fms>=0.0.8"
 ]