
Commit 9b75c10

Merge pull request #147 from BrandonGroth/dep_cleanup
build: Make non-essential dependencies optional
2 parents: 4918fc9 + 07125ec

File tree: 9 files changed (+104 -59 lines)

README.md (+2 -0)

@@ -104,8 +104,10 @@ The following optional dependencies are available:
 - `gptq`: `GPTQModel` package for W4A16 quantization
 - `mx`: `microxcaling` package for MX quantization
 - `opt`: Shortcut for `fp8`, `gptq`, and `mx` installs
+- `aiu`: `ibm-fms` package for AIU model deployment
 - `torchvision`: `torch` package for image recognition training and inference
 - `triton`: `triton` package for matrix multiplication kernels
+- `examples`: Dependencies needed for examples
 - `visualize`: Dependencies for visualizing models and performance data
 - `test`: Dependencies needed for unit testing
 - `dev`: Dependencies needed for development
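With the new `aiu` and `examples` extras (defined in pyproject.toml below), these dependencies become opt-in: users who need them install via standard pip extras syntax, e.g. `pip install "fms-model-optimizer[aiu]"`, where the distribution name comes from the `opt` extra's self-reference in pyproject.toml.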

fms_mo/aiu_addons/gptq/gptq_aiu_adapter.py (+13 -1)

@@ -17,9 +17,21 @@
 from typing import Mapping
 
 # Third Party
-from fms.utils import serialization
 import torch
 
+# Local
+from fms_mo.utils.import_utils import available_packages
+
+if not available_packages["fms"]:
+    raise ImportError(
+        "AIU functionality requires ibm-fms to be installed. "
+        "See https://github.com/foundation-model-stack/foundation-model-stack for details."
+    )
+
+# Third Party
+# pylint: disable=import-error,wrong-import-position
+from fms.utils import serialization
+
 
 def _gptq_qweights_transpose_aiu(
     input_sd: Mapping[str, torch.Tensor],
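This guard is the pattern the PR applies across all four AIU addon modules: probe the optional package once, fail fast with an actionable message, and only then perform the late import behind a pylint suppression. A minimal self-contained sketch of the same idea, with `my_optional_pkg` and `package_available` as illustrative placeholders rather than names from the PR:

import importlib.util

def package_available(name: str) -> bool:
    # find_spec returns None when the package is not installed
    return importlib.util.find_spec(name) is not None

if not package_available("my_optional_pkg"):
    raise ImportError(
        "This feature requires my_optional_pkg to be installed. "
        "Install it with: pip install my_optional_pkg"
    )

# Safe to import only after the check above has passed.
# pylint: disable=wrong-import-position
import my_optional_pkg  # noqa: E402

Raising ImportError at module-import time, rather than letting a bare import fail, lets the message say exactly what to install and why.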

fms_mo/aiu_addons/gptq/gptq_aiu_linear.py (+13 -1)

@@ -18,6 +18,19 @@
 import math
 
 # Third Party
+import torch
+
+# Local
+from fms_mo.utils.import_utils import available_packages
+
+if not available_packages["fms"]:
+    raise ImportError(
+        "AIU functionality requires ibm-fms to be installed. "
+        "See https://github.com/foundation-model-stack/foundation-model-stack for details."
+    )
+
+# Third Party
+# pylint: disable=import-error,wrong-import-position,ungrouped-imports
 from fms.modules.linear import (
     LinearModuleShardingInfo,
     LinearParameterShardingInfo,
@@ -27,7 +40,6 @@
 )
 from fms.modules.tp import ShardType, TPModule
 from fms.utils.gptq import GPTQLinearConfig
-import torch
 
 # Local
 from fms_mo.aiu_addons.gptq.gptq_aiu_op import register_aiu_gptq_op

fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py (+13 -1)

@@ -17,9 +17,21 @@
 from typing import Mapping
 
 # Third Party
-from fms.utils import serialization
 import torch
 
+# Local
+from fms_mo.utils.import_utils import available_packages
+
+if not available_packages["fms"]:
+    raise ImportError(
+        "AIU functionality requires ibm-fms to be installed. "
+        "See https://github.com/foundation-model-stack/foundation-model-stack for details."
+    )
+
+# Third Party
+# pylint: disable=import-error,wrong-import-position
+from fms.utils import serialization
+
 
 def _int8_qparams_aiu(
     input_sd: Mapping[str, torch.Tensor],

fms_mo/aiu_addons/i8i8/i8i8_aiu_linear.py (+10 -0)

@@ -19,7 +19,17 @@
 from typing import Any, Callable, Optional, Union
 import copy
 
+# Local
+from fms_mo.utils.import_utils import available_packages
+
+if not available_packages["fms"]:
+    raise ImportError(
+        "AIU functionality requires ibm-fms to be installed. "
+        "See https://github.com/foundation-model-stack/foundation-model-stack for details."
+    )
+
 # Third Party
+# pylint: disable=import-error,wrong-import-position,ungrouped-imports
 from fms.modules.linear import (
     LinearModuleShardingInfo,
     LinearParameterShardingInfo,

fms_mo/quant/ptq.py (+48 -41)

@@ -30,7 +30,6 @@
 import sys
 
 # Third Party
-from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 import numpy as np
 import pandas as pd
@@ -1449,49 +1448,57 @@ def ptq_mod_optim_lm(_model, m, layers, qcfg, optim_mode="both", **kwargs):
                 # show loss on pbar
                 pbar2.set_description(pbar_desc + f"{PTQloss:.6f}")
 
-                if isinstance(qcfg["tb_writer"], SummaryWriter) and isOutput:
-                    scalars2log = {}
-                    hist2log = {}
+                if available_packages["tensorboard"]:
+                    # Third Party
+                    from torch.utils.tensorboard import SummaryWriter
 
-                    for k, v in loss4plot.items():  # plot loss
-                        scalars2log[f"{mod_name}/PTQloss_{k}"] = v
-                    for k, v in m.named_buffers():  # plot cv, delta, zp, alpha, and lr
-                        if any(kb in k for kb in ["delta", "zero_point", "clip_val"]):
-                            if len(v.shape) > 0 and v.shape[0] > 1:  # perCh
-                                hist2log[f"{mod_name}/{k}"] = v
-                            else:
-                                scalars2log[f"{mod_name}/{k}"] = v
-                    for p, pname in zip(
-                        optim_a.param_groups[0]["params"], param_names[1]
-                    ):  # cva
-                        scalars2log[f"{mod_name}/{pname}"] = p.item()
-                    scalars2log[f"{mod_name}/LR_cv_a"] = optim_a.param_groups[0]["lr"]
-                    for p, pname in zip(
-                        optim_w.param_groups[0]["params"], param_names[0]
-                    ):  # weights
-                        hist2log[f"{mod_name}/{pname}"] = p
-                    scalars2log[f"{mod_name}/LR_w"] = optim_w.param_groups[0]["lr"]
-                    for p, pname in zip(
-                        optim_w.param_groups[1]["params"], param_names[2]
-                    ):  # cvw
-                        if "alpha" in pname:
-                            hist2log[f"{mod_name}/{pname}"] = p
-                        else:
+                    if isinstance(qcfg["tb_writer"], SummaryWriter) and isOutput:
+                        scalars2log = {}
+                        hist2log = {}
+
+                        for k, v in loss4plot.items():  # plot loss
+                            scalars2log[f"{mod_name}/PTQloss_{k}"] = v
+                        for k, v in m.named_buffers():  # plot cv, delta, zp, alpha, and lr
+                            if any(kb in k for kb in ["delta", "zero_point", "clip_val"]):
+                                if len(v.shape) > 0 and v.shape[0] > 1:  # perCh
+                                    hist2log[f"{mod_name}/{k}"] = v
+                                else:
+                                    scalars2log[f"{mod_name}/{k}"] = v
+                        for p, pname in zip(
+                            optim_a.param_groups[0]["params"], param_names[1]
+                        ):  # cva
                             scalars2log[f"{mod_name}/{pname}"] = p.item()
-                    scalars2log[f"{mod_name}/LR_cvw"] = optim_w.param_groups[1]["lr"]
-                    if "adaround" in qcfg["qw_mode"]:
-                        scalars2log[f"{mod_name}/AdaR_beta"] = (
-                            loss_func.temp_decay.curr_beta
-                        )
-                    for lidx, l in enumerate(layers):
-                        if not hasattr(l, "quantize_m1"):
-                            hist2log[f"{mod_name}/W{lidx}"] = l.weight
+                        scalars2log[f"{mod_name}/LR_cv_a"] = optim_a.param_groups[0][
+                            "lr"
+                        ]
+                        for p, pname in zip(
+                            optim_w.param_groups[0]["params"], param_names[0]
+                        ):  # weights
+                            hist2log[f"{mod_name}/{pname}"] = p
+                        scalars2log[f"{mod_name}/LR_w"] = optim_w.param_groups[0]["lr"]
+                        for p, pname in zip(
+                            optim_w.param_groups[1]["params"], param_names[2]
+                        ):  # cvw
+                            if "alpha" in pname:
+                                hist2log[f"{mod_name}/{pname}"] = p
+                            else:
+                                scalars2log[f"{mod_name}/{pname}"] = p.item()
+                        scalars2log[f"{mod_name}/LR_cvw"] = optim_w.param_groups[1][
+                            "lr"
+                        ]
+                        if "adaround" in qcfg["qw_mode"]:
+                            scalars2log[f"{mod_name}/AdaR_beta"] = (
+                                loss_func.temp_decay.curr_beta
+                            )
+                        for lidx, l in enumerate(layers):
+                            if not hasattr(l, "quantize_m1"):
+                                hist2log[f"{mod_name}/W{lidx}"] = l.weight
 
-                    # writing everything in one shot will mess up the folder; better to write them one by one
-                    for n, v in scalars2log.items():
-                        qcfg["tb_writer"].add_scalar(n, v, Niter)
-                    for n, v in hist2log.items():
-                        qcfg["tb_writer"].add_histogram(n, v, Niter)
+                        # writing everything in one shot will mess up the folder; better to write them one by one
+                        for n, v in scalars2log.items():
+                            qcfg["tb_writer"].add_scalar(n, v, Niter)
+                        for n, v in hist2log.items():
+                            qcfg["tb_writer"].add_histogram(n, v, Niter)
 
         for s in scheduler:
            s.step()  # we set up scheduler based on Nouterloop, not inner
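ptq.py uses a second flavor of the same technique: rather than failing fast, it degrades gracefully, deferring the SummaryWriter import into the function body so fms_mo.quant.ptq imports cleanly without tensorboard and logging is simply skipped. A condensed sketch of that shape; `log_scalars` and its arguments are illustrative names, not the PR's actual code:

import importlib.util

TENSORBOARD_AVAILABLE = importlib.util.find_spec("tensorboard") is not None

def log_scalars(tb_writer, scalars: dict, step: int) -> None:
    # Silently skip logging when the optional package is absent
    if not TENSORBOARD_AVAILABLE:
        return
    # Deferred third-party import, mirroring the hunk above
    from torch.utils.tensorboard import SummaryWriter

    if isinstance(tb_writer, SummaryWriter):
        # write scalars one by one, as the original comment recommends
        for name, value in scalars.items():
            tb_writer.add_scalar(name, value, step)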

fms_mo/run_quant.py (+0 -7)

@@ -34,7 +34,6 @@
 
 # Third Party
 from datasets import load_from_disk
-from huggingface_hub.errors import HFValidationError
 from torch.cuda import OutOfMemoryError
 from transformers import AutoTokenizer
 import torch
@@ -353,12 +352,6 @@ def main():
         logger.error(traceback.format_exc())
         write_termination_log(f"Unable to load file: {e}")
         sys.exit(USER_ERROR_EXIT_CODE)
-    except HFValidationError as e:
-        logger.error(traceback.format_exc())
-        write_termination_log(
-            f"There may be a problem with loading the model. Exception: {e}"
-        )
-        sys.exit(USER_ERROR_EXIT_CODE)
     except (TypeError, ValueError, EnvironmentError) as e:
         logger.error(traceback.format_exc())
         write_termination_log(
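The removal is forced by the packaging change: with huggingface_hub now optional, the unconditional `from huggingface_hub.errors import HFValidationError` would itself fail whenever the package is absent. If the specific handler were ever wanted back, one way to keep it conditional is sketched below; this is a hedged alternative built on the new `available_packages` flag, not something this commit does:

# Sketch only: NOT in the commit, which simply drops the handler.
from fms_mo.utils.import_utils import available_packages

if available_packages["huggingface_hub"]:
    # Third Party
    from huggingface_hub.errors import HFValidationError
else:
    class HFValidationError(Exception):
        """Inert stand-in so `except HFValidationError` stays valid."""

try:
    ...  # model loading would go here
except HFValidationError as e:
    print(f"There may be a problem with loading the model. Exception: {e}")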

fms_mo/utils/import_utils.py (+1 -0)

@@ -32,6 +32,7 @@
     "fms",
     "triton",
     "torchvision",
+    "huggingface_hub",
 ]
 
 available_packages = {}
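For context, `available_packages` maps each listed name to an installed/not-installed flag that the guarded imports above consult. A plausible reconstruction of the probe loop (the actual implementation in import_utils.py may differ, and only the tail of the package list is visible in this hunk):

import importlib.util

# Assumed list name; the hunk shows only the last few entries.
packages_to_check = [
    "fms",
    "triton",
    "torchvision",
    "huggingface_hub",
]

available_packages = {}
for pkg in packages_to_check:
    # find_spec returns None when the package cannot be found
    available_packages[pkg] = importlib.util.find_spec(pkg) is not None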

pyproject.toml (+4 -8)

@@ -25,29 +25,25 @@ dependencies = [
     "numpy>=1.26.4,<2.3.0",
     "accelerate>=0.20.3,!=0.34,<1.9",
     "transformers>=4.45,<4.53",
-    "torch>=2.2.0,<2.6",
+    "torch>=2.2.0,<2.6",
     "tqdm>=4.66.2,<5.0",
     "datasets>=3.0.0,<4.0",
-    "ninja>=1.11.1.1,<2.0",
-    "tensorboard",
-    "notebook",
-    "evaluate",
-    "huggingface_hub",
     "pandas",
     "safetensors",
-    "ibm-fms>=0.0.8",
     "pkginfo>1.10",
 ]
 
 [project.optional-dependencies]
+examples = ["ninja>=1.11.1.1,<2.0", "evaluate", "huggingface_hub"]
 fp8 = ["llmcompressor"]
 gptq = ["Cython", "gptqmodel>=1.7.3"]
 mx = ["microxcaling>=1.1"]
 opt = ["fms-model-optimizer[fp8, gptq, mx]"]
+aiu = ["ibm-fms>=0.0.8"]
 torchvision = ["torchvision>=0.17"]
 flash-attn = ["flash-attn>=2.5.3,<3.0"]
 triton = ["triton>=3.0,<3.4"]
-visualize = ["matplotlib", "graphviz", "pygraphviz"]
+visualize = ["matplotlib", "graphviz", "pygraphviz", "tensorboard", "notebook"]
 dev = ["pre-commit>=3.0.4,<5.0"]
 test = ["pytest", "pillow"]
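Two details worth noting in the extras table: `opt` uses the standard self-referencing trick (a package may depend on itself with extras), so installing `fms-model-optimizer[opt]` resolves to the union of the `fp8`, `gptq`, and `mx` groups; and `aiu` is deliberately left out of `opt`, which is what keeps `ibm-fms` out of both the default install and the shortcut.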
