Skip to content

Commit e95d5bc

Browse files
authored
Fix (brevitas_examples/quantizers): correct stats for dynamic quants (#1445)
1 parent e60d648 commit e95d5bc

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

src/brevitas_examples/common/generative/quantizers.py

Lines changed: 13 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling
1212
from brevitas.core.stats import AbsMinMax
1313
from brevitas.core.stats import NegativeMinOrZero
14+
from brevitas.core.stats import StatsOp
1415
from brevitas.core.stats.stats_op import HalfQuadraticOptimizerZeroPoint
1516
from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE
1617
from brevitas.core.zero_point import RuntimeDynamicGroupZeroPoint
@@ -72,7 +73,7 @@ class Int8DynamicActPerTensorFloat(DynamicActProxyMixin, Int8ActPerTensorFloat):
7273
"""
7374
scaling_impl = RuntimeDynamicStatsScaling
7475
scaling_stats_input_view_shape_impl = OverTensorView
75-
scaling_stats_op = 'min_max'
76+
scaling_stats_op = StatsOp.MAX
7677
dynamic_scaling_broadcastable_fn = lambda x, shape: x.view(SCALAR_SHAPE)
7778

7879

@@ -82,7 +83,7 @@ class Fp8e4m3FNUZDynamicActPerTensorFloat(Fp8e4m3FNUZActPerTensorFloat):
8283
"""
8384
scaling_impl = RuntimeDynamicStatsScaling
8485
scaling_stats_input_view_shape_impl = OverTensorView
85-
scaling_stats_op = 'min_max'
86+
scaling_stats_op = StatsOp.MAX
8687
dynamic_scaling_broadcastable_fn = lambda x, shape: x.view(SCALAR_SHAPE)
8788

8889

@@ -92,7 +93,7 @@ class Int8DynamicActPerRowFloat(DynamicActProxyMixin, Int8ActPerTensorFloat):
9293
"""
9394
scaling_impl = RuntimeDynamicStatsScaling
9495
scaling_stats_input_view_shape_impl = OverOutputFeaturesView
95-
scaling_stats_op = 'min_max'
96+
scaling_stats_op = StatsOp.MAX
9697
scaling_per_output_channel = True
9798

9899

@@ -107,7 +108,7 @@ class Int8DynamicActPerGroupFloat(DynamicActProxyMixin, Int8ActPerTensorFloat):
107108
"""
108109
proxy_class = GroupwiseActQuantProxyFromInjector
109110
scaling_impl = RuntimeDynamicGroupStatsScaling
110-
scaling_stats_op = 'min_max'
111+
scaling_stats_op = StatsOp.MAX
111112
scaling_per_output_type = ScalingPerOutputType.GROUP
112113

113114

@@ -117,7 +118,7 @@ class ShiftedUint8DynamicActPerGroupFloat(DynamicActProxyMixin, ShiftedUint8ActP
117118
"""
118119
proxy_class = GroupwiseActQuantProxyFromInjector
119120
scaling_impl = RuntimeDynamicGroupStatsScaling
120-
scaling_stats_op = 'min_max'
121+
scaling_stats_op = StatsOp.MIN_MAX
121122
scaling_per_output_type = ScalingPerOutputType.GROUP
122123
zero_point_impl = RuntimeDynamicGroupZeroPoint
123124
zero_point_stats_impl = NegativeMinOrZero
@@ -129,7 +130,7 @@ class ShiftedUint8DynamicActPerTensorFloat(DynamicActProxyMixin, ShiftedUint8Act
129130
"""
130131
scaling_impl = RuntimeDynamicStatsScaling
131132
scaling_stats_input_view_shape_impl = OverTensorView
132-
scaling_stats_op = 'min_max'
133+
scaling_stats_op = StatsOp.MIN_MAX
133134
zero_point_impl = RuntimeDynamicStatsZeroPoint
134135
zero_point_stats_impl = NegativeMinOrZero
135136
dynamic_scaling_broadcastable_fn = lambda x, shape: x.view(SCALAR_SHAPE)
@@ -141,7 +142,7 @@ class ShiftedUint8DynamicActPerRowFloat(DynamicActProxyMixin, ShiftedUint8ActPer
141142
"""
142143
scaling_impl = RuntimeDynamicStatsScaling
143144
scaling_stats_input_view_shape_impl = OverOutputFeaturesView
144-
scaling_stats_op = 'min_max'
145+
scaling_stats_op = StatsOp.MIN_MAX
145146
scaling_per_output_channel = True
146147
zero_point_impl = RuntimeDynamicStatsZeroPoint
147148
zero_point_stats_impl = NegativeMinOrZero
@@ -154,7 +155,7 @@ class Fp8e4m3DynamicActPerGroupFloat(DynamicActProxyMixin, Fp8e4m3ActPerTensorFl
154155
proxy_class = GroupwiseActFloatQuantProxyFromInjector
155156
scaling_impl = RuntimeDynamicGroupStatsScaling
156157
scaling_per_output_type = ScalingPerOutputType.GROUP
157-
scaling_stats_op = 'min_max'
158+
scaling_stats_op = StatsOp.MAX
158159

159160

160161
class FP8e4m3OCPDynamicActPerRowFixedPoint(Fp8e4m3OCPActPerTensorFloat):
@@ -163,7 +164,7 @@ class FP8e4m3OCPDynamicActPerRowFixedPoint(Fp8e4m3OCPActPerTensorFloat):
163164
"""
164165
scaling_impl = RuntimeDynamicStatsScaling
165166
scaling_stats_input_view_shape_impl = OverOutputFeaturesView
166-
scaling_stats_op = 'min_max'
167+
scaling_stats_op = StatsOp.MAX
167168
scaling_per_output_channel = True
168169
restrict_scaling_type = RestrictValueType.POWER_OF_TWO
169170
restrict_value_float_to_int_impl = FloorSte
@@ -173,7 +174,7 @@ class FP8e4m3OCPDynamicActPerRowFixedPoint(Fp8e4m3OCPActPerTensorFloat):
173174
class FP8e4m3OCPDynamicActPerRowFloat(Fp8e4m3OCPActPerTensorFloat):
174175
scaling_impl = RuntimeDynamicStatsScaling
175176
scaling_stats_input_view_shape_impl = OverOutputFeaturesView
176-
scaling_stats_op = 'min_max'
177+
scaling_stats_op = StatsOp.MAX
177178
scaling_per_output_channel = True
178179
proxy_class = DynamicActFloatQuantProxyFromInjector
179180

@@ -185,7 +186,7 @@ class Fp8e4m3OCPDynamicActPerGroupFloat(DynamicActProxyMixin, Fp8e4m3OCPActPerTe
185186
proxy_class = GroupwiseActFloatQuantProxyFromInjector
186187
scaling_impl = RuntimeDynamicGroupStatsScaling
187188
scaling_per_output_type = ScalingPerOutputType.GROUP
188-
scaling_stats_op = 'min_max'
189+
scaling_stats_op = StatsOp.MAX
189190

190191

191192
class Fp8e4m3OCPWeightSymmetricGroupQuant(Fp8e4m3OCPWeightPerChannelFloat):
@@ -210,7 +211,7 @@ class Fp8e4m3OCPWeightPerChannelFloatMSE(MSESymmetricScale, Fp8e4m3OCPWeightPerC
210211
class FP8e4m3FNUZDynamicActPerRowFloat(Fp8e4m3FNUZActPerTensorFloat):
211212
scaling_impl = RuntimeDynamicStatsScaling
212213
scaling_stats_input_view_shape_impl = OverOutputFeaturesView
213-
scaling_stats_op = 'min_max'
214+
scaling_stats_op = StatsOp.MAX
214215
scaling_per_output_channel = True
215216
proxy_class = DynamicActFloatQuantProxyFromInjector
216217

0 commit comments

Comments
 (0)