
Commit 418593c

Make scaling type configurable for MoE training (#2642)
stack-info: PR: #2642, branch: danielvegamyhre/stack/26
1 parent 3c2e229 commit 418593c

File tree: 4 files changed, +173 -26 lines

test/prototype/moe_training/test_training.py

Lines changed: 110 additions & 7 deletions
@@ -12,7 +12,10 @@
 )
 
 from torchao.float8.float8_utils import compute_error
-from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
+from torchao.prototype.moe_training.conversion_utils import (
+    MoEScalingType,
+    MoETrainingConfig,
+)
 from torchao.quantization.quant_api import quantize_
 
 from .testing_utils import _validate_model_conversion
@@ -72,7 +75,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False
 
     # quantize test model
-    config = MoETrainingConfig()
+    config = MoETrainingConfig(scaling_type=MoEScalingType.FP8_ROWWISE)
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 
     # validate that only the experts were converted
@@ -99,7 +102,105 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
     # validate output
     out_sqnr = compute_error(out, ref_out)
-    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."
+    min_out_sqnr = 29.0
+    assert out_sqnr.item() >= min_out_sqnr, (
+        f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
+    )
+
+    # compute loss
+    labels = torch.ones_like(ref_out)
+    ref_loss = F.mse_loss(ref_out, labels)
+    out_loss = F.mse_loss(out, labels)
+
+    # backward pass
+    ref_loss.backward()
+    out_loss.backward()
+
+    # validate input gradient
+    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
+    min_input_grad_sqnr = 29.0
+    assert input_grad_sqnr.item() >= min_input_grad_sqnr, (
+        f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}."
+    )
+
+    # validate param gradients
+    min_param_grad_sqnr = 25.0
+    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
+        param_grad_sqnr = compute_error(param1.grad, param2.grad)
+        assert param_grad_sqnr.item() >= min_param_grad_sqnr, (
+            f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}."
+        )
+
+
+@pytest.mark.parametrize(
+    "target_fqns",
+    [
+        ["experts"],
+        ["does.not.exist"],
+    ],
+)
+def test_moe_mxfp8_training(target_fqns: list[str]):
+    block_size = 32
+
+    # Token groups must be divisible by 32 for mxfp8
+    set_token_group_alignment_size_m(block_size)
+
+    model_args = TransformerModelArgs(
+        moe_enabled=True,
+        num_experts=8,
+        dim=256,
+        multiple_of=block_size,
+        ffn_dim_multiplier=1.0,
+    )
+    init_std = 0.02
+    device = torch.device("cuda")
+
+    # reference bf16 MoE
+    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
+    torch.manual_seed(42)
+    ref_model.init_weights(init_std, device)
+
+    # target MoE for testing conversion
+    model = copy.deepcopy(ref_model)
+
+    # assert starting params are identical for both models
+    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
+        assert torch.equal(param1, param2)
+
+    # convert MoE to float8 training
+    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
+        for target_fqn in target_fqns:
+            if target_fqn in cur_fqn:
+                return True
+        return False
+
+    # quantize test model
+    config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
+    quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+
+    # validate that only the experts were converted
+    _validate_model_conversion(
+        model,
+        target_fqns=target_fqns,
+    )
+
+    # inputs
+    batch, seq, dim = 8, 2048, 256
+    ref_x = torch.randn(
+        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
+    )
+    x = ref_x.detach().clone().requires_grad_(True)
+
+    # forward pass
+    ref_out = ref_model(ref_x)
+    out = model(x)
+
+    # validate output
+    out_sqnr = compute_error(out, ref_out)
+    min_out_sqnr = 25.0
+    assert out_sqnr.item() >= min_out_sqnr, (
+        f"SQNR must be >= {min_out_sqnr}, got {out_sqnr.item()}."
+    )
 
     # compute loss
     labels = torch.ones_like(ref_out)
@@ -112,13 +213,15 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
 
     # validate input gradient
     input_grad_sqnr = compute_error(x.grad, ref_x.grad)
-    assert input_grad_sqnr.item() >= 30.0, (
-        f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}."
+    min_input_grad_sqnr = 25.0
+    assert input_grad_sqnr.item() >= min_input_grad_sqnr, (
+        f"SQNR must be >= {min_input_grad_sqnr}, got {input_grad_sqnr.item()}."
     )
 
     # validate param gradients
+    min_param_grad_sqnr = 21.0
     for param1, param2 in zip(model.parameters(), ref_model.parameters()):
         param_grad_sqnr = compute_error(param1.grad, param2.grad)
-        assert param_grad_sqnr.item() >= 25.0, (
-            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
+        assert param_grad_sqnr.item() >= min_param_grad_sqnr, (
+            f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}."
         )
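
To summarize the conversion flow the new mxfp8 test exercises, here is a condensed usage sketch; the placeholder model and the filter-function name are illustrative and not part of the commit.

# Condensed usage sketch (names are illustrative): convert expert weights of a bf16
# MoE model so their grouped GEMMs run with the chosen scaling type during training.
from torch import nn

from torchao.prototype.moe_training.conversion_utils import (
    MoEScalingType,
    MoETrainingConfig,
)
from torchao.quantization.quant_api import quantize_


def expert_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    # Convert only submodules whose fully qualified name mentions "experts".
    return "experts" in cur_fqn


model = ...  # a bf16 MoE model, e.g. torchtitan's MoE(model_args) as in the test above
config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)  # FP8_ROWWISE is the default
quantize_(model, config=config, filter_fn=expert_filter_fn)
# Forward/backward passes then route the converted experts' grouped GEMMs through the mxfp8 path.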

torchao/prototype/moe_training/conversion_utils.py

Lines changed: 14 additions & 3 deletions
@@ -4,19 +4,24 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import logging
+from enum import Enum
 from typing import Callable, Optional
 
 from torch import nn
 
 from torchao.core.config import AOBaseConfig
-from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
 )
 
 logger: logging.Logger = logging.getLogger(__name__)
 
 
+class MoEScalingType(Enum):
+    FP8_ROWWISE = "fp8_rowwise"
+    MXFP8 = "mxfp8"
+
+
 class MoETrainingConfig(AOBaseConfig):
     """
     The MoETrainingConfig is specifically designed to be used on MoE models using
@@ -36,6 +41,10 @@ class MoETrainingConfig(AOBaseConfig):
     For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor.
     """
 
+    def __init__(self, scaling_type: MoEScalingType = MoEScalingType.FP8_ROWWISE):
+        super().__init__()
+        self.scaling_type = scaling_type
+
 
 @register_quantize_module_handler(MoETrainingConfig)
 def _moe_training_transform(
@@ -76,6 +85,8 @@ def _swap_params(
     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
+    from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
+
     if isinstance(module, nn.Parameter) and (
         module_filter_fn is None or module_filter_fn(module, "")
     ):
@@ -84,7 +95,7 @@ def _swap_params(
                 f"Does not support a root nn.Parameter with children: {module}"
             )
         if not isinstance(module.data, ScaledGroupedMMTensor):
-            new_data = ScaledGroupedMMTensor(module.data)
+            new_data = ScaledGroupedMMTensor(module.data, config.scaling_type)
             return nn.Parameter(new_data, requires_grad=module.requires_grad)
         return module
 
@@ -110,7 +121,7 @@ def post_order_traversal(
         for param_name, param in module.named_parameters(recurse=False):
             if not isinstance(param.data, ScaledGroupedMMTensor):
                 new_param = nn.Parameter(
-                    ScaledGroupedMMTensor(param.data),
+                    ScaledGroupedMMTensor(param.data, config.scaling_type),
                     requires_grad=param.requires_grad,
                 )
                 setattr(module, param_name, new_param)
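
In effect, the transform rewraps each matching parameter so the configured scaling type travels with the weight. A minimal sketch of that wrapping follows; the expert-weight shape is an illustrative assumption, not taken from the commit.

# Minimal sketch of what _swap_params produces for a single expert weight; the
# (num_experts, dim, hidden_dim) shape below is illustrative only.
import torch
from torch import nn

from torchao.prototype.moe_training.conversion_utils import MoEScalingType, MoETrainingConfig
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

config = MoETrainingConfig(scaling_type=MoEScalingType.MXFP8)
expert_weight = torch.randn(8, 256, 512, dtype=torch.bfloat16)
new_data = ScaledGroupedMMTensor(expert_weight, config.scaling_type)
new_param = nn.Parameter(new_data, requires_grad=True)  # mirrors the swap done above
assert new_data.scaling_type == MoEScalingType.MXFP8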

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 23 additions & 8 deletions
@@ -11,6 +11,7 @@
 
 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
+from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 from torchao.prototype.moe_training.kernels import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
@@ -30,6 +31,7 @@ def _scaled_grouped_mm(
     B_t: torch.Tensor,
     offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    scaling_type: MoEScalingType = MoEScalingType.FP8_ROWWISE,
 ) -> torch.Tensor:
     """
     This function performs dynamic float8 quantization with row-wise scaling
@@ -43,14 +45,27 @@ def _scaled_grouped_mm(
         offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor.
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
     """
-    # TODO: Remove once prototype is more mature. This is currently very useful for development and debugging.
-    logger.info("Using scaled_grouped_mm")
-    return _Float8GroupedMM.apply(
-        A,
-        B_t,
-        offs,
-        out_dtype,
-    )
+    # TODO: Remove logging once prototype is more mature. This is currently very useful for development and debugging.
+    if scaling_type == MoEScalingType.FP8_ROWWISE:
+        logger.info("Using fp8 rowwise scaled_grouped_mm")
+        return _Float8GroupedMM.apply(
+            A,
+            B_t,
+            offs,
+            out_dtype,
+        )
+    elif scaling_type == MoEScalingType.MXFP8:
+        logger.info("Using mxfp8 scaled_grouped_mm")
+        block_size = 32  # TODO: should we make this configurable? plumb it through in a config somehow?
+        return _MXFP8GroupedMM.apply(
+            A,
+            B_t,
+            offs,
+            block_size,
+            out_dtype,
+        )
+    else:
+        raise ValueError(f"Unsupported scaling type {scaling_type}")
 
 
 class _Float8GroupedMM(torch.autograd.Function):

torchao/prototype/moe_training/tensor.py

Lines changed: 26 additions & 8 deletions
@@ -16,6 +16,7 @@
 from torch.distributed.fsdp import MixedPrecisionPolicy
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
+from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -41,15 +42,17 @@ class ScaledGroupedMMTensor(torch.Tensor):
     differentiable _scaled_grouped_mm autograd function.
     """
 
+    scaling_type: MoEScalingType = MoEScalingType.FP8_ROWWISE
     grouped_mm_func_name = "_grouped_mm"
     offs_arg_name = "offs"
 
     @staticmethod
     def __new__(
         cls,
         tensor: torch.Tensor,
+        scaling_type: MoEScalingType,
     ):
-        return torch.Tensor._make_wrapper_subclass(
+        self = torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
             strides=tensor.stride(),
@@ -61,12 +64,16 @@ def __new__(
             pin_memory=tensor.is_pinned(),
             requires_grad=tensor.requires_grad,
         )
+        self.scaling_type = scaling_type
+        return self
 
     def __init__(
         self,
         tensor: torch.Tensor,
+        scaling_type: MoEScalingType,
     ):
         self._data = tensor
+        self.scaling_type = scaling_type
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
@@ -80,12 +87,20 @@ def __torch_function__(cls, func, types, args, kwargs={}):
             # used for shared experts. This is basically the grouped_mm
             # kernel handling a bmm.
             A, B = args[0], args[1]
+            assert not isinstance(A, ScaledGroupedMMTensor), (
+                "A should not be a ScaledGroupedMMTensor"
+            )
+            assert isinstance(B, ScaledGroupedMMTensor), (
+                "B should be a ScaledGroupedMMTensor"
+            )
+            scaling_type = B.scaling_type
             A_is_2d = A.dim() == 2
             B_is_3d = B.dim() == 3
             has_offs = kwargs.get(cls.offs_arg_name) is not None
             if A_is_2d and B_is_3d and has_offs:
                 return _scaled_grouped_mm(
                     *args,
+                    scaling_type=scaling_type,
                     **kwargs,
                 )
 
@@ -97,8 +112,9 @@ def __torch_function__(cls, func, types, args, kwargs={}):
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # detach is special case
+        scaling_type = args[0].scaling_type
         if func == torch.ops.aten.detach.default:
-            return ScaledGroupedMMTensor(args[0]._data)
+            return ScaledGroupedMMTensor(args[0]._data, scaling_type)
 
         # unwrap args/kwargs
         unwrap = lambda x: x._data if isinstance(x, ScaledGroupedMMTensor) else x
@@ -116,22 +132,22 @@ def __torch_dispatch__(cls, func, types, args, kwargs={}):
         # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass
         return pytree.tree_map_only(
             torch.Tensor,
-            lambda x: ScaledGroupedMMTensor(x),
+            lambda x: ScaledGroupedMMTensor(x, scaling_type),
             out,
         )
 
     def __repr__(self):
-        return f"ScaledGroupedMMTensor(data={self._data})"
+        return f"ScaledGroupedMMTensor(data={self._data}, scaling_type={self.scaling_type})"
 
     def __tensor_flatten__(self):
-        # Metadata is empty but needed to make the subclass traceable for torch.compile.
-        metadata = {}
+        metadata = {"scaling_type": self.scaling_type}
         return ["_data"], metadata
 
     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
         return ScaledGroupedMMTensor(
             inner_tensors["_data"],
+            flatten_spec["scaling_type"],
         )
 
     # fsdp hooks based on https://github.com/pytorch/pytorch/blob/20e40492b046b9287726d3ec656117e4dc38f0e2/test/distributed/_composable/fsdp/test_fully_shard_extensions.py#L81
@@ -158,14 +174,16 @@ def fsdp_post_all_gather(
     ):
         (data,) = all_gather_outputs
 
-        # For training step 1+, out=unshared param.
+        # For training step 1+, out=unsharded param.
         if out is not None:
            if isinstance(out, ScaledGroupedMMTensor):
                 out_data = out._data
+                out.scaling_type = self.scaling_type
             elif isinstance(out, DTensor) and isinstance(
                 out._local_tensor, ScaledGroupedMMTensor
             ):
                 out_data = out._local_tensor._data
+                out._local_tensor.scaling_type = self.scaling_type
             else:
                 raise RuntimeError(
                     f"expect out to be ScaledGroupedMMTensor or DTensor with local_tensor=ScaledGroupedMM, but got {type(out)}"
@@ -188,6 +206,6 @@ def fsdp_post_all_gather(
             return
 
         # For training step 0, out=None, so we need to return a new ScaledGroupedMMTensor.
-        output = ScaledGroupedMMTensor(data)
+        output = ScaledGroupedMMTensor(data, self.scaling_type)
         inner_tensors = (data,)
         return output, inner_tensors
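
The scaling type now also survives the flatten/unflatten round trip that keeps the subclass traceable for torch.compile. A small sketch of that round trip, following the methods changed above (the tensor shape is illustrative):

# Sketch of the flatten/unflatten round trip: scaling_type rides along in the
# subclass metadata instead of being dropped.
import torch

from torchao.prototype.moe_training.conversion_utils import MoEScalingType
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

t = ScaledGroupedMMTensor(torch.randn(8, 16, dtype=torch.bfloat16), MoEScalingType.MXFP8)
inner_names, metadata = t.__tensor_flatten__()  # (["_data"], {"scaling_type": ...})
inner_tensors = {name: getattr(t, name) for name in inner_names}
t2 = ScaledGroupedMMTensor.__tensor_unflatten__(inner_tensors, metadata, t.size(), t.stride())
assert t2.scaling_type == MoEScalingType.MXFP8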
