
Commit 12d2882

Change act_scale -> input_scale
1 parent 009dc55 commit 12d2882

File tree

README.md
auto_fp8/modeling.py
auto_fp8/quantize.py
examples/quantize.py

4 files changed: +26 -26 lines

README.md

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ Each quantized layer in the state_dict will have:
 If the config has `"activation_scheme": "static"`:
 ```
 model.layers.0.mlp.down_proj.weight < F8_E4M3
-model.layers.0.mlp.down_proj.act_scale < F32
+model.layers.0.mlp.down_proj.input_scale < F32
 model.layers.0.mlp.down_proj.weight_scale < F32
 ```
 If config has `"activation_scheme": "dynamic"`:
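After this change, inspecting a statically quantized checkpoint should show an `input_scale` entry (rather than `act_scale`) next to each layer's `weight` and `weight_scale`. A minimal sketch of such an inspection, assuming the model was saved as a single `model.safetensors` shard (the filename and the use of the safetensors API here are illustrative, not part of this commit):

```python
# Sketch: list the per-layer scale tensors in a saved FP8 checkpoint.
from safetensors import safe_open

with safe_open("model.safetensors", framework="pt") as f:
    for name in f.keys():
        if name.endswith(("input_scale", "weight_scale")):
            t = f.get_tensor(name)
            # With "activation_scheme": "static", every quantized linear layer
            # stores a float32 input_scale next to its float32 weight_scale.
            print(f"{name}: dtype={t.dtype}, value={t.item():.6f}")
```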

auto_fp8/modeling.py

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ def _prepare_calibration_data(calibration_tokens):
 
         # import copy
         # for layer in self.model.model.layers:
-        #     layer.self_attn.kv_scale = copy.deepcopy(layer.self_attn.k_proj.act_scale)
+        #     layer.self_attn.kv_scale = copy.deepcopy(layer.self_attn.k_proj.input_scale)
 
     def save_quantized(self, save_dir):
         save_quantized_model(

auto_fp8/quantize.py

Lines changed: 12 additions & 12 deletions
@@ -104,18 +104,18 @@ def __init__(
         super().__init__()
         self.weight = torch.nn.Parameter(qweight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
-        self.act_scale = None
+        self.input_scale = None
         self.bias = bias
 
     def forward(self, x):
-        qinput, x_act_scale = per_tensor_quantize(x)
-        if self.act_scale is None:
-            self.act_scale = torch.nn.Parameter(x_act_scale)
-        elif x_act_scale > self.act_scale:
-            self.act_scale = torch.nn.Parameter(x_act_scale)
+        qinput, x_input_scale = per_tensor_quantize(x)
+        if self.input_scale is None:
+            self.input_scale = torch.nn.Parameter(x_input_scale)
+        elif x_input_scale > self.input_scale:
+            self.input_scale = torch.nn.Parameter(x_input_scale)
         output = fp8_gemm(
             A=qinput,
-            A_scale=self.act_scale,
+            A_scale=self.input_scale,
             B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
@@ -130,12 +130,12 @@ def __init__(
         qweight: torch.Tensor,
         weight_scale: torch.Tensor,
         bias: torch.Tensor,
-        act_scale: float = 1.0,
+        input_scale: float = 1.0,
     ):
         super().__init__()
         self.weight = torch.nn.Parameter(qweight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
-        self.act_scale = torch.nn.Parameter(act_scale, requires_grad=False)
+        self.input_scale = torch.nn.Parameter(input_scale, requires_grad=False)
         self.bias = bias
 
     def per_tensor_quantize(
@@ -146,10 +146,10 @@ def per_tensor_quantize(
         return qweight.to(torch.float8_e4m3fn)
 
     def forward(self, x):
-        qinput = self.per_tensor_quantize(x, inv_scale=self.act_scale)
+        qinput = self.per_tensor_quantize(x, inv_scale=self.input_scale)
         output = fp8_gemm(
             A=qinput,
-            A_scale=self.act_scale,
+            A_scale=self.input_scale,
             B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
@@ -247,7 +247,7 @@ def quantize_activations(
             quantizer.weight,
             quantizer.weight_scale,
             quantizer.bias,
-            quantizer.act_scale,
+            quantizer.input_scale,
         )
         replace_module(model, name, static_proj)
         del quantizer
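For context, the dynamic path above relies on a module-level `per_tensor_quantize(x)` helper that returns the quantized activation together with its scale; the helper's body lies outside the hunks shown here. A minimal sketch of what such an absmax-based per-tensor FP8 quantizer typically looks like (the body is an assumption; only the returned `(qinput, scale)` pair is taken from the call site above):

```python
import torch

def per_tensor_quantize_sketch(x: torch.Tensor):
    """Illustrative absmax per-tensor FP8 quantization; not the library's exact helper."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Map the tensor's absolute maximum onto the representable FP8 E4M3 range.
    scale = x.abs().max().clamp(min=1e-12) / finfo.max
    # Quantize, clamp, and cast; the float32 scale becomes the module's input_scale
    # (the dynamic forward above keeps the running maximum of these scales).
    qx = (x / scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
    return qx, scale.float()
```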

examples/quantize.py

Lines changed: 12 additions & 12 deletions
@@ -85,23 +85,23 @@ def __init__(self, qweight, weight_scale, bias):
         super().__init__()
         self.weight = torch.nn.Parameter(qweight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
-        self.act_scale = None
+        self.input_scale = None
         self.bias = bias
 
     def forward(self, x):
         # Dynamically quantize
-        qinput, x_act_scale = per_tensor_quantize(x)
+        qinput, x_input_scale = per_tensor_quantize(x)
 
         # Update scale if needed.
-        if self.act_scale is None:
-            self.act_scale = torch.nn.Parameter(x_act_scale)
-        elif x_act_scale > self.act_scale:
-            self.act_scale = torch.nn.Parameter(x_act_scale)
+        if self.input_scale is None:
+            self.input_scale = torch.nn.Parameter(x_input_scale)
+        elif x_input_scale > self.input_scale:
+            self.input_scale = torch.nn.Parameter(x_input_scale)
 
         # Pass quantized to next layer so it has realistic data.
         output = fp8_gemm(
             A=qinput,
-            A_scale=self.act_scale,
+            A_scale=self.input_scale,
             B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
@@ -111,11 +111,11 @@ def forward(self, x):
 
 
 class FP8StaticLinear(torch.nn.Module):
-    def __init__(self, qweight, weight_scale, bias, act_scale=0.0):
+    def __init__(self, qweight, weight_scale, bias, input_scale=0.0):
         super().__init__()
         self.weight = torch.nn.Parameter(qweight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
-        self.act_scale = torch.nn.Parameter(act_scale, requires_grad=False)
+        self.input_scale = torch.nn.Parameter(input_scale, requires_grad=False)
         self.bias = bias
 
     def per_tensor_quantize(
@@ -129,10 +129,10 @@ def per_tensor_quantize(
         return qweight.to(torch.float8_e4m3fn)
 
     def forward(self, x):
-        qinput = self.per_tensor_quantize(x, inv_scale=self.act_scale)
+        qinput = self.per_tensor_quantize(x, inv_scale=self.input_scale)
         output = fp8_gemm(
             A=qinput,
-            A_scale=self.act_scale,
+            A_scale=self.input_scale,
             B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
@@ -216,7 +216,7 @@ def quantize_activations(model, calibration_tokens):
             quantizer.weight,
             quantizer.weight_scale,
             quantizer.bias,
-            quantizer.act_scale,
+            quantizer.input_scale,
         )
         replace_module(model, name, static_proj)
         del quantizer
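The static classes in both files quantize with a fixed scale rather than recomputing one per batch; only the final `return qweight.to(torch.float8_e4m3fn)` line of their `per_tensor_quantize` method appears in these hunks. A hedged sketch of that fixed-scale path, assuming the stored `input_scale` is the dequantization scale that `fp8_gemm` applies as `A_scale`, so quantization divides by it here:

```python
import torch

def static_per_tensor_quantize_sketch(x: torch.Tensor, input_scale: torch.Tensor) -> torch.Tensor:
    """Illustrative fixed-scale FP8 quantization for the static linear path."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    # input_scale was captured during calibration; dividing by it brings the
    # activations into the FP8 E4M3 range that fp8_gemm later rescales by A_scale.
    qx = (x / input_scale).clamp(min=finfo.min, max=finfo.max)
    return qx.to(torch.float8_e4m3fn)
```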
