
Commit f4ec836

fix merge conflicts
Signed-off-by: Antoni Viros i Martin <[email protected]>
2 parents cf2082e + e8f35bb commit f4ec836

File tree

14 files changed: +126 -67 lines changed

README.md

Lines changed: 2 additions & 0 deletions
@@ -104,8 +104,10 @@ The following optional dependencies are available:
 - `gptq`: `GPTQModel` package for W4A16 quantization
 - `mx`: `microxcaling` package for MX quantization
 - `opt`: Shortcut for `fp8`, `gptq`, and `mx` installs
+- `aiu`: `ibm-fms` package for AIU model deployment
 - `torchvision`: `torch` package for image recognition training and inference
 - `triton`: `triton` package for matrix multiplication kernels
+- `examples`: Dependencies needed for examples
 - `visualize`: Dependencies for visualizing models and performance data
 - `test`: Dependencies needed for unit testing
 - `dev`: Dependencies needed for development

fms_mo/aiu_addons/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -21,8 +21,10 @@ def _infer_quantization_config(quant_config: dict) -> dict | None:
         and quant_config["config_groups"]["group_0"]["weights"]["num_bits"] == 8
     ):
         # First, import required FP8 linear classes from fms-mo
-        import fms_mo.aiu_addons.fp8.fp8_linear  # pylint: disable=unused-import
+        # Local
         import fms_mo.aiu_addons.fp8.fp8_adapter  # pylint: disable=unused-import
+        import fms_mo.aiu_addons.fp8.fp8_linear  # pylint: disable=unused-import
+
         # This is used by get_linear to decide whether a linear layer
         # will be quantized or not inside the model
         def fp8_linear_type(name: str) -> str:
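The fp8_adapter and fp8_linear imports above are kept even though pylint flags them as unused: importing those modules is what registers the FP8 adapter and linear implementations that get_linear later looks up. A minimal sketch of that import-for-side-effect registration pattern, using a hypothetical registry and class name rather than the actual fms-mo internals:

# Hypothetical registry; fms-mo's real registration machinery is not shown in this diff.
LINEAR_REGISTRY: dict = {}

def register_linear(key):
    """Class decorator that records an implementation under a string key."""
    def _wrap(cls):
        LINEAR_REGISTRY[key] = cls
        return cls
    return _wrap

@register_linear("fp8")
class FP8Linear:
    """Placeholder body; a real class would wrap an FP8 matmul."""

# Merely importing the module that defines FP8Linear fills LINEAR_REGISTRY,
# which is why the imports above matter despite never being referenced directly.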

fms_mo/aiu_addons/fp8/fp8_attn.py

Lines changed: 1 addition & 1 deletion
@@ -173,7 +173,7 @@ def _math_fp8_compute_op(
     )

     attn_weight = (
-        torch.ops.sendnn.scaled_bmm(
+        torch.ops.spyre.scaled_bmm(
             query,
             key_cache.transpose(-2, -1),
             q_scale,

fms_mo/aiu_addons/fp8/fp8_spyre_op.py

Lines changed: 4 additions & 2 deletions
@@ -27,7 +27,7 @@


 @torch.library.custom_op("spyre::scaled_bmm", mutates_args=())
-def sendnn_scaled_bmm(
+def spyre_scaled_bmm(
     mat1: Tensor,
     mat2: Tensor,
     scale1: Tensor,
@@ -43,6 +43,8 @@ def sendnn_scaled_bmm(
     assert (
         mat1.shape[:-2] == mat2.shape[:-2]
     ), "batch dimensions must match for mat1 and mat2"
+    assert scale1.numel() == 1, "only per-tensor scales supported"
+    assert scale2.numel() == 1, "only per-tensor scales supported"
     mat1 = mat1.view(-1, *mat1.shape[-2:])
     mat2 = mat2.view(-1, *mat2.shape[-2:])
     out = torch.empty(
@@ -62,7 +64,7 @@ def sendnn_scaled_bmm(
     return out.view(*mat1.shape[:-2], mat1.shape[1], mat2.shape[2])


-@sendnn_scaled_bmm.register_fake
+@spyre_scaled_bmm.register_fake
 def _(
     mat1: Tensor,
     mat2: Tensor,
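The renamed op above follows PyTorch's custom-op registration pattern: torch.library.custom_op registers an eager implementation under a namespaced name ("spyre::scaled_bmm"), and register_fake supplies a shape-and-dtype-only variant used during tracing and torch.compile. Below is a small self-contained sketch of the same pattern on a toy 2-D op; the "demo" namespace is made up, and the dequantize-then-matmul body is an assumed reference semantic, not the Spyre kernel:

# Third Party
import torch
from torch import Tensor


@torch.library.custom_op("demo::scaled_mm", mutates_args=())
def scaled_mm(mat1: Tensor, mat2: Tensor, scale1: Tensor, scale2: Tensor) -> Tensor:
    # Eager reference: rescale with per-tensor scales, then matmul.
    assert scale1.numel() == 1 and scale2.numel() == 1, "only per-tensor scales supported"
    return (mat1.float() * scale1) @ (mat2.float() * scale2)


@scaled_mm.register_fake
def _(mat1: Tensor, mat2: Tensor, scale1: Tensor, scale2: Tensor) -> Tensor:
    # Metadata-only implementation: correct shape and dtype, no computation.
    return torch.empty(mat1.shape[0], mat2.shape[1], dtype=torch.float32, device=mat1.device)


# Call sites reach the op through the torch.ops namespace, as fp8_attn.py does above.
out = torch.ops.demo.scaled_mm(
    torch.randn(4, 8), torch.randn(8, 3), torch.tensor(0.5), torch.tensor(2.0)
)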

fms_mo/aiu_addons/fp8/fp8_utils.py

Lines changed: 1 addition & 1 deletion
@@ -72,7 +72,7 @@ def __init__(
         self._scaled = scaled

     def __tensor_flatten__(self):
-        ctx = {"scaled", self._scaled}
+        ctx = {"scaled": self._scaled}
         return ["_data", "_scale"], ctx

     @staticmethod
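The one-character fix above matters because __tensor_flatten__ is part of PyTorch's tensor-subclass flatten/unflatten protocol: it must return the names of the inner tensors plus a metadata dict, and {"scaled", self._scaled} with a comma builds a set rather than the dict that __tensor_unflatten__ expects to read "scaled" back from. A toy subclass showing the round trip; the class body is a simplified assumption, not the fms-mo class this diff touches:

# Third Party
import torch


class ToyScaledTensor(torch.Tensor):
    # A real wrapper subclass would also implement __torch_dispatch__; omitted here.
    @staticmethod
    def __new__(cls, data, scale, scaled=True):
        return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype)

    def __init__(self, data, scale, scaled=True):
        self._data = data
        self._scale = scale
        self._scaled = scaled

    def __tensor_flatten__(self):
        # Inner tensor attribute names, plus a metadata dict (not a set).
        ctx = {"scaled": self._scaled}
        return ["_data", "_scale"], ctx

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        # The dict round-trips unchanged, so plain key lookup recovers the flag.
        return ToyScaledTensor(inner_tensors["_data"], inner_tensors["_scale"], ctx["scaled"])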

fms_mo/aiu_addons/gptq/gptq_aiu_adapter.py

Lines changed: 13 additions & 1 deletion
@@ -17,9 +17,21 @@
 from typing import Mapping

 # Third Party
-from fms.utils import serialization
 import torch

+# Local
+from fms_mo.utils.import_utils import available_packages
+
+if not available_packages["fms"]:
+    raise ImportError(
+        "AIU functionality requires ibm-fms to be installed."
+        "See https://github.com/foundation-model-stack/foundation-model-stack for details."
+    )
+
+# Third Party
+# pylint: disable=import-error,wrong-import-position
+from fms.utils import serialization
+

 def _gptq_qweights_transpose_aiu(
     input_sd: Mapping[str, torch.Tensor],
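This guard, repeated in the linear and i8i8 modules below, lets the AIU add-ons fail with a clear message instead of a bare ModuleNotFoundError when ibm-fms is absent. The contents of fms_mo.utils.import_utils are not part of this diff; the sketch below is a plausible reconstruction of such an availability map, assuming it is built with importlib lookups:

import importlib.util

# Hypothetical reconstruction of available_packages; the real implementation
# in fms_mo.utils.import_utils may differ.
available_packages = {
    name: importlib.util.find_spec(name) is not None
    for name in ("fms", "tensorboard", "torch")
}

if not available_packages["fms"]:
    raise ImportError(
        "AIU functionality requires ibm-fms to be installed. "
        "See https://github.com/foundation-model-stack/foundation-model-stack for details."
    )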

fms_mo/aiu_addons/gptq/gptq_aiu_linear.py

Lines changed: 13 additions & 1 deletion
@@ -18,6 +18,19 @@
 import math

 # Third Party
+import torch
+
+# Local
+from fms_mo.utils.import_utils import available_packages
+
+if not available_packages["fms"]:
+    raise ImportError(
+        "AIU functionality requires ibm-fms to be installed."
+        "See https://github.com/foundation-model-stack/foundation-model-stack for details."
+    )
+
+# Third Party
+# pylint: disable=import-error,wrong-import-position,ungrouped-imports
 from fms.modules.linear import (
     LinearModuleShardingInfo,
     LinearParameterShardingInfo,
@@ -27,7 +40,6 @@
 )
 from fms.modules.tp import ShardType, TPModule
 from fms.utils.gptq import GPTQLinearConfig
-import torch

 # Local
 from fms_mo.aiu_addons.gptq.gptq_aiu_op import register_aiu_gptq_op

fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py

Lines changed: 13 additions & 1 deletion
@@ -17,9 +17,21 @@
 from typing import Mapping

 # Third Party
-from fms.utils import serialization
 import torch

+# Local
+from fms_mo.utils.import_utils import available_packages
+
+if not available_packages["fms"]:
+    raise ImportError(
+        "AIU functionality requires ibm-fms to be installed."
+        "See https://github.com/foundation-model-stack/foundation-model-stack for details."
+    )
+
+# Third Party
+# pylint: disable=import-error,wrong-import-position
+from fms.utils import serialization
+

 def _int8_qparams_aiu(
     input_sd: Mapping[str, torch.Tensor],

fms_mo/aiu_addons/i8i8/i8i8_aiu_linear.py

Lines changed: 10 additions & 0 deletions
@@ -19,7 +19,17 @@
 from typing import Any, Callable, Optional, Union
 import copy

+# Local
+from fms_mo.utils.import_utils import available_packages
+
+if not available_packages["fms"]:
+    raise ImportError(
+        "AIU functionality requires ibm-fms to be installed."
+        "See https://github.com/foundation-model-stack/foundation-model-stack for details."
+    )
+
 # Third Party
+# pylint: disable=import-error,wrong-import-position,ungrouped-imports
 from fms.modules.linear import (
     LinearModuleShardingInfo,
     LinearParameterShardingInfo,

fms_mo/quant/ptq.py

Lines changed: 48 additions & 41 deletions
@@ -30,7 +30,6 @@
 import sys

 # Third Party
-from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 import numpy as np
 import pandas as pd
@@ -1449,49 +1448,57 @@ def ptq_mod_optim_lm(_model, m, layers, qcfg, optim_mode="both", **kwargs):
                 # show loss on pbar
                 pbar2.set_description(pbar_desc + f"{PTQloss:.6f}")

-                if isinstance(qcfg["tb_writer"], SummaryWriter) and isOutput:
-                    scalars2log = {}
-                    hist2log = {}
+                if available_packages["tensorboard"]:
+                    # Third Party
+                    from torch.utils.tensorboard import SummaryWriter

-                    for k, v in loss4plot.items():  # plot loss
-                        scalars2log[f"{mod_name}/PTQloss_{k}"] = v
-                    for k, v in m.named_buffers():  # plot cv, delta, zp, alpha, and lr
-                        if any(kb in k for kb in ["delta", "zero_point", "clip_val"]):
-                            if len(v.shape) > 0 and v.shape[0] > 1:  # perCh
-                                hist2log[f"{mod_name}/{k}"] = v
-                            else:
-                                scalars2log[f"{mod_name}/{k}"] = v
-                    for p, pname in zip(
-                        optim_a.param_groups[0]["params"], param_names[1]
-                    ):  # cva
-                        scalars2log[f"{mod_name}/{pname}"] = p.item()
-                    scalars2log[f"{mod_name}/LR_cv_a"] = optim_a.param_groups[0]["lr"]
-                    for p, pname in zip(
-                        optim_w.param_groups[0]["params"], param_names[0]
-                    ):  # weights
-                        hist2log[f"{mod_name}/{pname}"] = p
-                    scalars2log[f"{mod_name}/LR_w"] = optim_w.param_groups[0]["lr"]
-                    for p, pname in zip(
-                        optim_w.param_groups[1]["params"], param_names[2]
-                    ):  # cvw
-                        if "alpha" in pname:
-                            hist2log[f"{mod_name}/{pname}"] = p
-                        else:
+                    if isinstance(qcfg["tb_writer"], SummaryWriter) and isOutput:
+                        scalars2log = {}
+                        hist2log = {}
+
+                        for k, v in loss4plot.items():  # plot loss
+                            scalars2log[f"{mod_name}/PTQloss_{k}"] = v
+                        for k, v in m.named_buffers():  # plot cv, delta, zp, alpha, and lr
+                            if any(kb in k for kb in ["delta", "zero_point", "clip_val"]):
+                                if len(v.shape) > 0 and v.shape[0] > 1:  # perCh
+                                    hist2log[f"{mod_name}/{k}"] = v
+                                else:
+                                    scalars2log[f"{mod_name}/{k}"] = v
+                        for p, pname in zip(
+                            optim_a.param_groups[0]["params"], param_names[1]
+                        ):  # cva
                             scalars2log[f"{mod_name}/{pname}"] = p.item()
-                    scalars2log[f"{mod_name}/LR_cvw"] = optim_w.param_groups[1]["lr"]
-                    if "adaround" in qcfg["qw_mode"]:
-                        scalars2log[f"{mod_name}/AdaR_beta"] = (
-                            loss_func.temp_decay.curr_beta
-                        )
-                    for lidx, l in enumerate(layers):
-                        if not hasattr(l, "quantize_m1"):
-                            hist2log[f"{mod_name}/W{lidx}"] = l.weight
+                        scalars2log[f"{mod_name}/LR_cv_a"] = optim_a.param_groups[0][
+                            "lr"
+                        ]
+                        for p, pname in zip(
+                            optim_w.param_groups[0]["params"], param_names[0]
+                        ):  # weights
+                            hist2log[f"{mod_name}/{pname}"] = p
+                        scalars2log[f"{mod_name}/LR_w"] = optim_w.param_groups[0]["lr"]
+                        for p, pname in zip(
+                            optim_w.param_groups[1]["params"], param_names[2]
+                        ):  # cvw
+                            if "alpha" in pname:
+                                hist2log[f"{mod_name}/{pname}"] = p
+                            else:
+                                scalars2log[f"{mod_name}/{pname}"] = p.item()
+                        scalars2log[f"{mod_name}/LR_cvw"] = optim_w.param_groups[1][
+                            "lr"
+                        ]
+                        if "adaround" in qcfg["qw_mode"]:
+                            scalars2log[f"{mod_name}/AdaR_beta"] = (
+                                loss_func.temp_decay.curr_beta
+                            )
+                        for lidx, l in enumerate(layers):
+                            if not hasattr(l, "quantize_m1"):
+                                hist2log[f"{mod_name}/W{lidx}"] = l.weight

-                    # write every in one shot will mess up the folder, better write them one by one
-                    for n, v in scalars2log.items():
-                        qcfg["tb_writer"].add_scalar(n, v, Niter)
-                    for n, v in hist2log.items():
-                        qcfg["tb_writer"].add_histogram(n, v, Niter)
+                        # write every in one shot will mess up the folder, better write them one by one
+                        for n, v in scalars2log.items():
+                            qcfg["tb_writer"].add_scalar(n, v, Niter)
+                        for n, v in hist2log.items():
+                            qcfg["tb_writer"].add_histogram(n, v, Niter)

             for s in scheduler:
                 s.step()  # we set up scheduler based on Nouterloop, not inner
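The net effect of the ptq.py change is that tensorboard is no longer a hard dependency: SummaryWriter is imported inside the loop, only when available_packages["tensorboard"] reports the package is installed, and the existing isinstance check on qcfg["tb_writer"] still gates whether anything is written. A condensed sketch of that deferred-import guard, using illustrative names (log_scalars, step) rather than the actual function in ptq.py:

# Local
from fms_mo.utils.import_utils import available_packages


def log_scalars(qcfg: dict, scalars: dict, step: int) -> None:
    """Write scalar metrics to tensorboard only when it is installed and configured."""
    if not available_packages["tensorboard"]:
        return  # silently skip logging when tensorboard is absent
    # Third Party (imported lazily so the module imports without tensorboard)
    from torch.utils.tensorboard import SummaryWriter

    if isinstance(qcfg.get("tb_writer"), SummaryWriter):
        for name, value in scalars.items():
            qcfg["tb_writer"].add_scalar(name, value, step)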
