Commit f7f8298

Add int8 Woq for CPU

Update int4 weight dim; add CPU profiling.

1 parent 222ec25 commit f7f8298

3 files changed: +49 −5 lines changed

mixtral-moe/generate.py

Lines changed: 11 additions & 3 deletions
@@ -12,6 +12,7 @@
 import torch
 import torch._dynamo.config
 import torch._inductor.config
+torch._inductor.config.cpp.enable_kernel_profile = True
 
 def device_sync(device):
     if "cuda" in device:
@@ -132,7 +133,7 @@ def encode_tokens(tokenizer, string, bos=True, device='cuda'):
     tokens = tokenizer.encode(string)
     if bos:
         tokens = [tokenizer.bos_id()] + tokens
-    return torch.tensor(tokens, dtype=torch.int, device=device)
+    return torch.tensor(tokens, dtype=torch.int, device=args.device)
 
 def _load_model(checkpoint_path, device, precision, use_tp):
     with torch.device('meta'):
@@ -248,8 +249,13 @@ def callback(x):
         if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
             prof = contextlib.nullcontext()
         else:
-            torch.profiler._utils._init_for_cuda_graphs()
-            prof = torch.profiler.profile()
+            if device == 'cuda':
+                torch.profiler._utils._init_for_cuda_graphs()
+                prof = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], use_cuda=True)
+                profile_sort = 'self_cuda_time_total'
+            elif device == 'cpu':
+                prof = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU])
+                profile_sort = 'self_cpu_time_total'
         with prof:
             y = generate(
                 model,
@@ -263,6 +269,8 @@ def callback(x):
         if i == -1:
             print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
             continue
+        if hasattr(prof, "key_averages"):
+            print(prof.key_averages().table(sort_by=profile_sort, row_limit=-1))
         if hasattr(prof, "export_chrome_trace"):
             if use_tp:
                 prof.export_chrome_trace(f"{profile}_rank_{rank}.json")
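For readers unfamiliar with the profiler API used in the new CPU branch, here is a minimal, self-contained sketch of the same pattern. The tiny model and input are placeholders, not the Mixtral generation loop this script actually profiles:

import torch

# Placeholder model and input; generate.py profiles the real generation loop instead.
model = torch.nn.Linear(1024, 1024)
x = torch.randn(8, 1024)

# CPU-only profiling, mirroring the `device == 'cpu'` branch added above.
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU]) as prof:
    with torch.no_grad():
        model(x)

# Same reporting call as the new key_averages block: sort operators by self CPU time.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1))

Setting torch._inductor.config.cpp.enable_kernel_profile = True, as the first hunk does, additionally lets the C++ kernels generated by torch.compile show up as named entries in this table.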

mixtral-moe/quantize.py

Lines changed: 20 additions & 1 deletion
@@ -98,6 +98,20 @@ def convert_for_runtime(self):
         return self.mod
 
 
+# TODO: This is a workaround to speedup int8 woq performance. Will remove this when
+# https://github.com/pytorch/pytorch/pull/120985 is in PyTorch stable release.
+def linear_forward_int8(x, weight_int8pack, scales, out_features):
+    if x.is_cuda:
+        return F.linear(x, weight_int8pack.to(dtype=x.dtype)) * scales
+
+    origin_x_size = x.size()
+    x = x.reshape(-1, origin_x_size[-1])
+    c = torch.ops.aten._weight_int8pack_mm(x, weight_int8pack, scales)
+    new_shape = origin_x_size[:-1] + (out_features,)
+    c = c.reshape(new_shape)
+    return c
+
+
 class WeightOnlyBit8Linear(torch.nn.Module):
     __constants__ = ['in_features', 'out_features']
     in_features: int
@@ -115,7 +129,12 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,
         self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
+        # return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
+        # TODO: This is a workaround to speedup int8 woq performance. Will remove this when
+        # https://github.com/pytorch/pytorch/pull/120985 is in PyTorch stable release.
+        return linear_forward_int8(
+            input,
+            self.weight, self.scales, self.out_features)
 
 
 class ConditionalFeedForwardBit8(nn.Module):
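The two branches of linear_forward_int8 are intended to compute the same thing: a matmul against the int8 weight followed by a per-output-channel rescale. A rough standalone check of that equivalence, assuming a PyTorch build where the aten._weight_int8pack_mm CPU kernel is available and accepts these (arbitrary) shapes:

import torch
import torch.nn.functional as F

# Arbitrary illustrative shapes.
batch, in_features, out_features = 4, 64, 32

x = torch.randn(batch, in_features, dtype=torch.bfloat16)
weight_int8 = torch.randint(-127, 127, (out_features, in_features), dtype=torch.int8)
scales = torch.rand(out_features, dtype=torch.bfloat16)

# Reference path (the x.is_cuda branch): dequantize to the activation dtype, then rescale.
ref = F.linear(x, weight_int8.to(dtype=x.dtype)) * scales

# Fused CPU path used by the workaround (assumes the op exists in this build).
fused = torch.ops.aten._weight_int8pack_mm(x, weight_int8, scales)

# The two results should agree up to bf16 rounding; print the max difference.
print((ref.float() - fused.float()).abs().max())

The fused op avoids materializing a full dequantized copy of the weight on CPU, which is the point of the workaround until pytorch/pytorch#120985 lands in a stable release.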

quantize.py

Lines changed: 18 additions & 1 deletion
@@ -335,6 +335,18 @@ def convert_for_runtime(self):
         replace_linear_weight_only_int8_per_channel(self.mod)
         return self.mod
 
+# TODO: This is a workaround to speedup int8 woq performance. Will remove this when
+# https://github.com/pytorch/pytorch/pull/120985 is in PyTorch stable release.
+def linear_forward_int8(x, weight_int8pack, scales, out_features):
+    if x.is_cuda:
+        return F.linear(x, weight_int8pack.to(dtype=x.dtype)) * scales
+
+    origin_x_size = x.size()
+    x = x.reshape(-1, origin_x_size[-1])
+    c = torch.ops.aten._weight_int8pack_mm(x, weight_int8pack, scales)
+    new_shape = origin_x_size[:-1] + (out_features,)
+    c = c.reshape(new_shape)
+    return c
 
 class WeightOnlyInt8Linear(torch.nn.Module):
     __constants__ = ['in_features', 'out_features']
@@ -352,7 +364,12 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,
         self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
+        # return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
+        # TODO: This is a workaround to speedup int8 woq performance. Will remove this when
+        # https://github.com/pytorch/pytorch/pull/120985 is in PyTorch stable release.
+        return linear_forward_int8(
+            input,
+            self.weight, self.scales, self.out_features)
 
 ##### weight only int4 per channel groupwise quantized code ######
 
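For context, the weight and scales buffers that WeightOnlyInt8Linear.forward consumes come from per-channel weight-only quantization of the original float weight. A hedged sketch of that step, illustrative only and not necessarily the exact quantization routine this file uses:

import torch

def quantize_weight_per_channel_int8(weight: torch.Tensor):
    # weight: [out_features, in_features] float tensor.
    # Symmetric quantization with one scale per output channel.
    max_abs = weight.abs().amax(dim=1, keepdim=True)              # [out_features, 1]
    scales = (max_abs / 127.0).clamp(min=1e-8)
    weight_int8 = torch.round(weight / scales).clamp(-127, 127).to(torch.int8)
    return weight_int8, scales.squeeze(1).to(torch.bfloat16)      # scales: [out_features]

w = torch.randn(32, 64)
w_int8, scales = quantize_weight_per_channel_int8(w)
# Reconstruction error of the dequantized weight should be small relative to its magnitude.
print((w - w_int8.float() * scales.float().unsqueeze(1)).abs().max())

These are the shapes the forward path above expects: an int8 weight of [out_features, in_features] and a per-output-channel scales vector of [out_features].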
