Skip to content

Commit d7dfc27

Browse files
rename mx example py, markdown update is still WIP
Signed-off-by: cliu-us <[email protected]>
1 parent 303454c commit d7dfc27

File tree

1 file changed

+110
-0
lines changed

1 file changed

+110
-0
lines changed

examples/MX/simple_mx_example.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright The FMS Model Optimizer Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Simple example using a toy model to demo how to trigger mx in fms-mo."""
15+
16+
# Third Party
17+
import numpy as np
18+
import torch
19+
import torch.nn.functional as F
20+
21+
22+
class ResidualMLP(torch.nn.Module):
23+
def __init__(self, hidden_size, device="cuda"):
24+
super(ResidualMLP, self).__init__()
25+
26+
self.layernorm = torch.nn.LayerNorm(hidden_size, device=device)
27+
self.dense_4h = torch.nn.Linear(hidden_size, 4 * hidden_size, device=device)
28+
self.dense_h = torch.nn.Linear(4 * hidden_size, hidden_size, device=device)
29+
self.dummy = torch.nn.Linear(hidden_size, hidden_size, device=device)
30+
# add a dummy layer because by default we skip 1st/last, if there are only 2 layers, all will be skipped
31+
32+
def forward(self, inputs):
33+
norm_outputs = self.layernorm(inputs)
34+
35+
# MLP
36+
proj_outputs = self.dense_4h(norm_outputs)
37+
proj_outputs = F.gelu(proj_outputs)
38+
mlp_outputs = self.dense_h(proj_outputs)
39+
mlp_outputs = self.dummy(mlp_outputs)
40+
41+
# Residual Connection
42+
outputs = inputs + mlp_outputs
43+
44+
return outputs
45+
46+
47+
if __name__ == "__main__":
48+
# Third Party
49+
from tabulate import tabulate
50+
51+
# Local
52+
from fms_mo import qconfig_init, qmodel_prep
53+
54+
HIDDEN_DIM = 128
55+
x = np.random.randn(16, HIDDEN_DIM)
56+
x = torch.tensor(x, dtype=torch.float32, device="cuda")
57+
results = {"dtype": [], "output[0, :5]": [], "||ref - out_dtype||_2": []}
58+
59+
# --- Test 0. Run MLP as is
60+
mlp = ResidualMLP(HIDDEN_DIM)
61+
# mlp.to("cuda")
62+
with torch.no_grad():
63+
out = mlp(x)
64+
results["dtype"].append("fp32")
65+
results["output[0, :5]"].append(out[0, :5].tolist())
66+
results["||ref - out_dtype||_2"].append("-")
67+
print(mlp)
68+
69+
# --- Test 1. fms-mo qmodel_prep, replace Linear with our QLinear
70+
qcfg = qconfig_init()
71+
qcfg["nbits_a"] = 8
72+
qcfg["nbits_w"] = 8
73+
model = qmodel_prep(mlp, x, qcfg)
74+
with torch.no_grad():
75+
out_dtype = model(x)
76+
results["dtype"].append("fms_int8")
77+
results["output[0, :5]"].append(out_dtype[0, :5].tolist())
78+
results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item())
79+
# print(model)
80+
81+
qcfg["nbits_a"] = 4
82+
qcfg["nbits_w"] = 4
83+
mlp = ResidualMLP(HIDDEN_DIM)
84+
model = qmodel_prep(mlp, x, qcfg)
85+
with torch.no_grad():
86+
out_dtype = model(x)
87+
results["dtype"].append("fms_int4")
88+
results["output[0, :5]"].append(out_dtype[0, :5].tolist())
89+
results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item())
90+
print(model)
91+
92+
# --- Test 2. now change mapping to MX
93+
# NOTE simply use qa_mode or qw_mode to trigger the use of mx, e.g. use "mx_" prefixed mode,
94+
# qcfg["mapping"] and other qcfg["mx_specs"] content will be updated automatically
95+
96+
for dtype_to_test in ["int8", "int4", "fp8_e4m3", "fp8_e5m2", "fp4_e2m1"]:
97+
qcfg["qw_mode"] = f"mx_{dtype_to_test}"
98+
qcfg["qa_mode"] = f"mx_{dtype_to_test}"
99+
mlp = ResidualMLP(HIDDEN_DIM) # fresh model
100+
model = qmodel_prep(mlp, x, qcfg)
101+
with torch.no_grad():
102+
out_dtype = model(x)
103+
results["dtype"].append(f"mx{dtype_to_test}")
104+
results["output[0, :5]"].append(out_dtype[0, :5].tolist())
105+
results["||ref - out_dtype||_2"].append(torch.norm(out - out_dtype).item())
106+
print(model)
107+
108+
print(tabulate(results, headers="keys"))
109+
110+
print("DONE!")

0 commit comments

Comments
 (0)