@@ -66,7 +66,7 @@ class QuantizationContext:
     row_dim: int = ROW_DIM_DEFAULT
     row_dim_quant: int = -1
     mx_group_size: int = MX_GROUP_SIZE_DEFAULT
-    rounding_mode: RoundingMode = RoundingMode.even
+    rounding_mode: Optional[RoundingMode] = RoundingMode.even
     padded_dim_sum_per_rank: Optional[List[int]] = None


@@ -167,6 +167,7 @@ def __init__(
         loss_scale: Optional[float] = None,
         row_dim: Optional[int] = None,
         is_fwd: bool = True,
+        rounding_mode: Optional[RoundingMode] = None,
     ) -> None:
         if loss_scale is not None:
             if comm_precision not in [SparseType.FP16, SparseType.BF16]:
@@ -183,8 +184,12 @@ def __init__(
         self._loss_scale = loss_scale
         self._is_fwd = is_fwd
         self._row_dim: int = -1 if row_dim is None else row_dim
+        self._rounding_mode: Optional[RoundingMode] = rounding_mode
         if self._comm_precision == SparseType.MX4:
             self._row_dim = MX_GROUP_SIZE_DEFAULT if row_dim is None else row_dim
+            self._rounding_mode = (
+                RoundingMode.even if rounding_mode is None else rounding_mode
+            )

     def encode(
         self, input_tensor: torch.Tensor, ctx: Optional[QuantizationContext] = None
@@ -258,7 +263,9 @@ def create_context(self) -> Optional[QuantizationContext]:
             return QuantizationContext(self._row_dim)
         if self._comm_precision == SparseType.MX4:
             return QuantizationContext(
-                row_dim=self._row_dim, mx_group_size=self._row_dim
+                row_dim=self._row_dim,
+                mx_group_size=self._row_dim,
+                rounding_mode=self._rounding_mode,
             )
         # int8 rowwise is default
         return QuantizationContext()
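For context, a minimal self-contained sketch of the defaulting behavior this diff introduces: when `comm_precision` is MX4 and the caller passes no `rounding_mode`, the codec falls back to `RoundingMode.even`, matching the previously hard-coded behavior; an explicitly chosen mode is stored and threaded into the `QuantizationContext` by `create_context`. The `Codec` class, the string precision tags, and the enum member values below are illustrative stand-ins, not the library's API; only `RoundingMode.even` appears in the diff itself.

```python
from enum import IntEnum
from typing import Optional


class RoundingMode(IntEnum):
    # Stand-in for the library's rounding-mode enum; values are illustrative.
    even = 2
    stochastic = 3


class Codec:
    # Hypothetical, trimmed-down codec mirroring the __init__ logic above.
    def __init__(
        self,
        comm_precision: str,
        rounding_mode: Optional[RoundingMode] = None,
    ) -> None:
        self._rounding_mode: Optional[RoundingMode] = rounding_mode
        if comm_precision == "MX4":
            # Preserve the pre-change default: MX4 rounds to nearest-even
            # unless the caller opts into a different mode.
            self._rounding_mode = (
                RoundingMode.even if rounding_mode is None else rounding_mode
            )


# MX4 with no explicit mode keeps the old round-to-even behavior.
assert Codec("MX4")._rounding_mode is RoundingMode.even
# An explicit mode is honored and would flow into create_context().
assert Codec("MX4", RoundingMode.stochastic)._rounding_mode is RoundingMode.stochastic
# Non-MX4 precisions leave the mode unset.
assert Codec("FP16")._rounding_mode is None
```

Widening the dataclass field to `Optional[RoundingMode]` while keeping its default of `RoundingMode.even` lets non-MX4 contexts carry `None` without changing existing MX4 behavior.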