Skip to content

Commit 9f81399

Browse files
optimisea authored and facebook-github-bot committed
expose rounding_mode in quantization for performance (#4862)
Summary: X-link: facebookresearch/FBGEMM#1884 Pull Request resolved: #4862 X-link: meta-pytorch/torchrec#3368 Expose the rounding_mode for mx4 as it could impact the QPS. Previous work was done here. D62466094 ``` class RoundingMode(IntEnum): """Rounding options for quantization.""" nearest = 0 floor = 1 even = 2 stochastic = 3 ceil = 4 ``` https://fburl.com/code/8prz4mem Reviewed By: victor-eds Differential Revision: D82001579 fbshipit-source-id: 872cd8ba62292b95e568ece47ac09052f28ca59e
1 parent a0dd77b commit 9f81399

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

fbgemm_gpu/fbgemm_gpu/quantize_comm.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class QuantizationContext:
6666
row_dim: int = ROW_DIM_DEFAULT
6767
row_dim_quant: int = -1
6868
mx_group_size: int = MX_GROUP_SIZE_DEFAULT
69-
rounding_mode: RoundingMode = RoundingMode.even
69+
rounding_mode: Optional[RoundingMode] = RoundingMode.even
7070
padded_dim_sum_per_rank: Optional[List[int]] = None
7171

7272

@@ -167,6 +167,7 @@ def __init__(
167167
loss_scale: Optional[float] = None,
168168
row_dim: Optional[int] = None,
169169
is_fwd: bool = True,
170+
rounding_mode: Optional[RoundingMode] = None,
170171
) -> None:
171172
if loss_scale is not None:
172173
if comm_precision not in [SparseType.FP16, SparseType.BF16]:
@@ -183,8 +184,12 @@ def __init__(
183184
self._loss_scale = loss_scale
184185
self._is_fwd = is_fwd
185186
self._row_dim: int = -1 if row_dim is None else row_dim
187+
self._rounding_mode: Optional[RoundingMode] = rounding_mode
186188
if self._comm_precision == SparseType.MX4:
187189
self._row_dim = MX_GROUP_SIZE_DEFAULT if row_dim is None else row_dim
190+
self._rounding_mode = (
191+
RoundingMode.even if rounding_mode is None else rounding_mode
192+
)
188193

189194
def encode(
190195
self, input_tensor: torch.Tensor, ctx: Optional[QuantizationContext] = None
@@ -258,7 +263,9 @@ def create_context(self) -> Optional[QuantizationContext]:
258263
return QuantizationContext(self._row_dim)
259264
if self._comm_precision == SparseType.MX4:
260265
return QuantizationContext(
261-
row_dim=self._row_dim, mx_group_size=self._row_dim
266+
row_dim=self._row_dim,
267+
mx_group_size=self._row_dim,
268+
rounding_mode=self._rounding_mode,
262269
)
263270
# int8 rowwise is default
264271
return QuantizationContext()

0 commit comments

Comments (0)