Commit 40a5c0c

chore: Added license to ffn_tmp.py and other minor fixes
Signed-off-by: Brandon Groth <[email protected]>
1 parent 9bab835 commit 40a5c0c

6 files changed, 35 additions and 93 deletions

examples/MX/ffn_tmp.py

Lines changed: 13 additions & 2 deletions
@@ -1,6 +1,17 @@
-# Third Party
+# Copyright The FMS Model Optimizer Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
-# from mx import Linear as Linear_mx # Need to amend mx's Linear class
 # Third Party
 import numpy as np
 import torch

fms_mo/modules/linear.py

Lines changed: 0 additions & 71 deletions
@@ -1964,77 +1964,6 @@ def extra_repr(self) -> str:
     "quantize_backprop": False,
 }
 
-# class QLinearMX(mx.Linear):
-#     """This is just a placeholder. Brandon is still working on it."""
-#     @classmethod
-#     def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
-#         """
-#         Converts a QLinear module to QLinearMX.
-
-#         Args:
-#             cls: The class of the QLinearMX to be created.
-#             fms_mo_qlinear: The QLinear module to be converted.
-#             kwargs: Additional keyword arguments.
-
-#         Returns:
-#             A QLinearMX object initialized with the weights and biases from the
-#             QLinear module.
-#         """
-#         mx_supported_formats = {
-#             "mx_fp8_e5m2",
-#             "mx_fp8_e4m3",
-#             "mx_fp4_e2m1",
-#             "mx_fp4",
-#             "mx_int8",
-#             "mx_int4",
-#             "mx_fp16",
-#             "mx_float16",
-#             "mx_bf16",
-#             "mx_bfloat16",
-#         }
-#         assert (
-#             fms_mo_qlinear.qa_mode in mx_supported_formats
-#             and fms_mo_qlinear.qw_mode in mx_supported_formats
-#         ), "Please check MX quantization mode settings!"
-#         a_elem_format = fms_mo_qlinear.qa_mode.removeprefix("mx_")
-#         w_elem_format = fms_mo_qlinear.qw_mode.removeprefix("mx_")

-#         block_size = kwargs.pop("block_size")
-#         mx_supported_block_sizes = {8, 16, 32, 64, 128}
-#         assert (
-#             block_size in mx_supported_block_sizes
-#         ), "Please check MX block size setting!"

-#         target_device = kwargs.get(
-#             "target_device", next(fms_mo_qlinear.parameters()).device
-#         )
-#         use_ptq = fms_mo_qlinear

-#         mx_specs = {
-#             "a_elem_format": a_elem_format,
-#             "w_elem_format": w_elem_format,
-#             "block_size": block_size,
-#             "bfloat": 16,
-#             "custom_cuda": True,
-#             # For quantization-aware finetuning, do backward pass in FP32
-#             "quantize_backprop": False,
-#         }

-#         # Create mx.Linear class from QLinear
-#         qlinear_mx = cls(
-#             in_features=fms_mo_qlinear.in_features,
-#             out_features=fms_mo_qlinear.out_features,
-#             bias=isinstance(fms_mo_qlinear.bias, torch.Tensor),
-#             mx_specs=fms_mo_qlinear.qcfg["mx_specs"],
-#             name=None,
-#         )

-#     def extra_repr(self) -> str:
-#         return (
-#             f"in={self.in_features}, out={self.out_features}, bias={self.bias is not None}, "
-#             f"mx_spec={self.mx_spec}"
-#         )

 class QLinearMX(torch.nn.Linear):
     """Modified from mx.linear class. Only mildly changed init() and add extra_repr.
     1. Add **kwargs to receive extra (unused) params passed from qmodel_prep
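
Note on the retained class: the docstring above says QLinearMX accepts **kwargs so that extra, unused parameters forwarded from qmodel_prep do not break construction. A minimal sketch of that pattern, assuming a generic nn.Linear subclass (class name, mx_specs handling, and the example keyword are illustrative assumptions, not fms-mo's actual implementation):

# Minimal sketch (assumption, not fms-mo code): an nn.Linear subclass whose
# __init__ quietly absorbs extra keyword arguments, as the QLinearMX docstring
# above describes for params forwarded by a prep routine.
import torch


class TolerantLinear(torch.nn.Linear):
    def __init__(self, in_features, out_features, bias=True, mx_specs=None, **kwargs):
        # Unknown kwargs (e.g. quantizer settings a caller might pass) are
        # accepted and ignored instead of raising TypeError.
        super().__init__(in_features, out_features, bias=bias)
        self.mx_specs = mx_specs or {}

    def extra_repr(self) -> str:
        return (
            f"in={self.in_features}, out={self.out_features}, "
            f"bias={self.bias is not None}, mx_specs={self.mx_specs}"
        )


# Usage: an unrecognized keyword such as 'qa_mode' is simply swallowed by **kwargs.
layer = TolerantLinear(16, 32, qa_mode="mx_fp8_e4m3")
print(layer)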

fms_mo/utils/qconfig_utils.py

Lines changed: 11 additions & 11 deletions
@@ -1182,18 +1182,18 @@ def check_config(config, model_dtype=None):
     )
 
     # If mapping is defined, check for MX classes
-    # Local
-    from fms_mo.modules.bmm import QBmmMX
-    from fms_mo.modules.linear import QLinearMX
-
-    mapping = config.get("mapping", None)
+    if available_packages["mx"]:
+        # Local
+        from fms_mo.modules.bmm import QBmmMX
+        from fms_mo.modules.linear import QLinearMX
 
-    # partial was used to init this mapping --> use .func pointer
-    if mapping is not None:
-        if not mapping[nn.Linear].func is QLinearMX:
-            raise ValueError("MX mapping for nn.Linear is not QLinearMX")
+        mapping = config.get("mapping", None)
 
-        if mapping["matmul_or_bmm"].func is QBmmMX:
-            raise ValueError("MX mapping for matmul_or_bmm is not QBmmMX")
+        # partial was used to init this mapping --> use .func pointer
+        if mapping is not None:
+            if not mapping[nn.Linear].func is QLinearMX:
+                raise ValueError("MX mapping for nn.Linear is not QLinearMX")
 
+            if mapping["matmul_or_bmm"].func is QBmmMX:
+                raise ValueError("MX mapping for matmul_or_bmm is not QBmmMX")
     # End mx_specs checks
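
A side note on the .func checks kept above: when a mapping entry is built with functools.partial, the mapping value is the partial object rather than the class, so identity has to be tested through the partial's .func attribute. A small standalone illustration of that mechanism (the demo class and mapping are assumptions for illustration, not fms-mo code):

# Standalone illustration (assumption, not fms-mo code) of why the config
# check reads mapping[nn.Linear].func: a functools.partial wraps the target
# class and exposes the wrapped callable via its .func attribute.
from functools import partial

from torch import nn


class QLinearDemo(nn.Linear):
    """Stand-in for QLinearMX, purely for illustration."""


mapping = {nn.Linear: partial(QLinearDemo, bias=False)}

print(mapping[nn.Linear] is QLinearDemo)       # False: the value is a partial
print(mapping[nn.Linear].func is QLinearDemo)  # True: .func recovers the class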

fms_mo/utils/torchscript_utils.py

Lines changed: 0 additions & 1 deletion
@@ -28,7 +28,6 @@
 import torch
 
 # Local
-# from fms_mo.modules import QBmm
 from fms_mo.quant.quantizers import transformers_prepare_input
 from fms_mo.utils.import_utils import available_packages
 from fms_mo.utils.utils import move_to, patch_torch_bmm, prepare_data_4_fwd

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -39,7 +39,6 @@ dependencies = [
     "safetensors",
     "ibm-fms>=0.0.8",
     "pkginfo>1.10",
-    # "mx @ git+https://github.com/microsoft/microxcaling.git"
 ]
 
 [project.optional-dependencies]

tests/models/test_mx.py

Lines changed: 11 additions & 7 deletions
@@ -4,16 +4,20 @@
 
 # Local
 from fms_mo import qmodel_prep
-from fms_mo.modules.bmm import QBmmMX
-from fms_mo.modules.linear import QLinearMX
 from fms_mo.utils.import_utils import available_packages
 from fms_mo.utils.qconfig_utils import check_config, set_mx_specs
 from tests.models.test_model_utils import delete_config, qmodule_error
 
-mx_qmodules = [
-    QLinearMX,
-    QBmmMX,
-]
+if available_packages["mx"]:
+    # Local
+    # pylint: disable=ungrouped-imports
+    from fms_mo.modules.bmm import QBmmMX
+    from fms_mo.modules.linear import QLinearMX
+
+    mx_qmodules = [
+        QLinearMX,
+        QBmmMX,
+    ]
 
 @pytest.mark.skipif(
     not available_packages["mx"],

@@ -92,7 +96,7 @@ def test_config_mx_error(
 
 @pytest.mark.skipif(
     not torch.cuda.is_available()
-    and not available_packages["mx"],
+    or not available_packages["mx"],
     reason="Skipped because CUDA or MX library was not available",
 )
 def test_residualMLP(
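
The second hunk flips the skip condition from "and" to "or": test_residualMLP needs both CUDA and the mx package, so it should be skipped when either one is missing, which is what the reason string already says. A quick standalone check of the two conditions (plain Python, independent of pytest):

# Truth-table comparison (illustrative, outside pytest) of the old and new
# skipif conditions. The test requires both CUDA and mx, so the skip should
# trigger when either is unavailable -- i.e. "or", not "and".
for has_cuda in (False, True):
    for has_mx in (False, True):
        old_skip = not has_cuda and not has_mx  # only skips when both are missing
        new_skip = not has_cuda or not has_mx   # skips when either is missing
        print(f"cuda={has_cuda!s:5} mx={has_mx!s:5} old={old_skip!s:5} new={new_skip}")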
