1111from brevitas .core .scaling .runtime import RuntimeDynamicGroupStatsScaling
1212from brevitas .core .stats import AbsMinMax
1313from brevitas .core .stats import NegativeMinOrZero
14+ from brevitas .core .stats import StatsOp
1415from brevitas .core .stats .stats_op import HalfQuadraticOptimizerZeroPoint
1516from brevitas .core .stats .stats_wrapper import SCALAR_SHAPE
1617from brevitas .core .zero_point import RuntimeDynamicGroupZeroPoint
@@ -72,7 +73,7 @@ class Int8DynamicActPerTensorFloat(DynamicActProxyMixin, Int8ActPerTensorFloat):
7273 """
7374 scaling_impl = RuntimeDynamicStatsScaling
7475 scaling_stats_input_view_shape_impl = OverTensorView
75- scaling_stats_op = 'min_max'
76+ scaling_stats_op = StatsOp . MAX
7677 dynamic_scaling_broadcastable_fn = lambda x , shape : x .view (SCALAR_SHAPE )
7778
7879
@@ -82,7 +83,7 @@ class Fp8e4m3FNUZDynamicActPerTensorFloat(Fp8e4m3FNUZActPerTensorFloat):
8283 """
8384 scaling_impl = RuntimeDynamicStatsScaling
8485 scaling_stats_input_view_shape_impl = OverTensorView
85- scaling_stats_op = 'min_max'
86+ scaling_stats_op = StatsOp . MAX
8687 dynamic_scaling_broadcastable_fn = lambda x , shape : x .view (SCALAR_SHAPE )
8788
8889
@@ -92,7 +93,7 @@ class Int8DynamicActPerRowFloat(DynamicActProxyMixin, Int8ActPerTensorFloat):
9293 """
9394 scaling_impl = RuntimeDynamicStatsScaling
9495 scaling_stats_input_view_shape_impl = OverOutputFeaturesView
95- scaling_stats_op = 'min_max'
96+ scaling_stats_op = StatsOp . MAX
9697 scaling_per_output_channel = True
9798
9899
@@ -107,7 +108,7 @@ class Int8DynamicActPerGroupFloat(DynamicActProxyMixin, Int8ActPerTensorFloat):
107108 """
108109 proxy_class = GroupwiseActQuantProxyFromInjector
109110 scaling_impl = RuntimeDynamicGroupStatsScaling
110- scaling_stats_op = 'min_max'
111+ scaling_stats_op = StatsOp . MAX
111112 scaling_per_output_type = ScalingPerOutputType .GROUP
112113
113114
@@ -117,7 +118,7 @@ class ShiftedUint8DynamicActPerGroupFloat(DynamicActProxyMixin, ShiftedUint8ActP
117118 """
118119 proxy_class = GroupwiseActQuantProxyFromInjector
119120 scaling_impl = RuntimeDynamicGroupStatsScaling
120- scaling_stats_op = 'min_max'
121+ scaling_stats_op = StatsOp . MIN_MAX
121122 scaling_per_output_type = ScalingPerOutputType .GROUP
122123 zero_point_impl = RuntimeDynamicGroupZeroPoint
123124 zero_point_stats_impl = NegativeMinOrZero
@@ -129,7 +130,7 @@ class ShiftedUint8DynamicActPerTensorFloat(DynamicActProxyMixin, ShiftedUint8Act
129130 """
130131 scaling_impl = RuntimeDynamicStatsScaling
131132 scaling_stats_input_view_shape_impl = OverTensorView
132- scaling_stats_op = 'min_max'
133+ scaling_stats_op = StatsOp . MIN_MAX
133134 zero_point_impl = RuntimeDynamicStatsZeroPoint
134135 zero_point_stats_impl = NegativeMinOrZero
135136 dynamic_scaling_broadcastable_fn = lambda x , shape : x .view (SCALAR_SHAPE )
@@ -141,7 +142,7 @@ class ShiftedUint8DynamicActPerRowFloat(DynamicActProxyMixin, ShiftedUint8ActPer
141142 """
142143 scaling_impl = RuntimeDynamicStatsScaling
143144 scaling_stats_input_view_shape_impl = OverOutputFeaturesView
144- scaling_stats_op = 'min_max'
145+ scaling_stats_op = StatsOp . MIN_MAX
145146 scaling_per_output_channel = True
146147 zero_point_impl = RuntimeDynamicStatsZeroPoint
147148 zero_point_stats_impl = NegativeMinOrZero
@@ -154,7 +155,7 @@ class Fp8e4m3DynamicActPerGroupFloat(DynamicActProxyMixin, Fp8e4m3ActPerTensorFl
154155 proxy_class = GroupwiseActFloatQuantProxyFromInjector
155156 scaling_impl = RuntimeDynamicGroupStatsScaling
156157 scaling_per_output_type = ScalingPerOutputType .GROUP
157- scaling_stats_op = 'min_max'
158+ scaling_stats_op = StatsOp . MAX
158159
159160
160161class FP8e4m3OCPDynamicActPerRowFixedPoint (Fp8e4m3OCPActPerTensorFloat ):
@@ -163,7 +164,7 @@ class FP8e4m3OCPDynamicActPerRowFixedPoint(Fp8e4m3OCPActPerTensorFloat):
163164 """
164165 scaling_impl = RuntimeDynamicStatsScaling
165166 scaling_stats_input_view_shape_impl = OverOutputFeaturesView
166- scaling_stats_op = 'min_max'
167+ scaling_stats_op = StatsOp . MAX
167168 scaling_per_output_channel = True
168169 restrict_scaling_type = RestrictValueType .POWER_OF_TWO
169170 restrict_value_float_to_int_impl = FloorSte
@@ -173,7 +174,7 @@ class FP8e4m3OCPDynamicActPerRowFixedPoint(Fp8e4m3OCPActPerTensorFloat):
173174class FP8e4m3OCPDynamicActPerRowFloat (Fp8e4m3OCPActPerTensorFloat ):
174175 scaling_impl = RuntimeDynamicStatsScaling
175176 scaling_stats_input_view_shape_impl = OverOutputFeaturesView
176- scaling_stats_op = 'min_max'
177+ scaling_stats_op = StatsOp . MAX
177178 scaling_per_output_channel = True
178179 proxy_class = DynamicActFloatQuantProxyFromInjector
179180
@@ -185,7 +186,7 @@ class Fp8e4m3OCPDynamicActPerGroupFloat(DynamicActProxyMixin, Fp8e4m3OCPActPerTe
185186 proxy_class = GroupwiseActFloatQuantProxyFromInjector
186187 scaling_impl = RuntimeDynamicGroupStatsScaling
187188 scaling_per_output_type = ScalingPerOutputType .GROUP
188- scaling_stats_op = 'min_max'
189+ scaling_stats_op = StatsOp . MAX
189190
190191
191192class Fp8e4m3OCPWeightSymmetricGroupQuant (Fp8e4m3OCPWeightPerChannelFloat ):
@@ -210,7 +211,7 @@ class Fp8e4m3OCPWeightPerChannelFloatMSE(MSESymmetricScale, Fp8e4m3OCPWeightPerC
210211class FP8e4m3FNUZDynamicActPerRowFloat (Fp8e4m3FNUZActPerTensorFloat ):
211212 scaling_impl = RuntimeDynamicStatsScaling
212213 scaling_stats_input_view_shape_impl = OverOutputFeaturesView
213- scaling_stats_op = 'min_max'
214+ scaling_stats_op = StatsOp . MAX
214215 scaling_per_output_channel = True
215216 proxy_class = DynamicActFloatQuantProxyFromInjector
216217
0 commit comments