Commit e910d6d

andrewor14 authored and facebook-github-bot committed
Remove internal executorch dependency on torchao.quantization.subclass
Summary: This is a really old quantization API that we recently removed in torchao (D84842047). No one should be calling it anymore. For BC, let's just copy the base class into executorch for now. We should delete this in the future.

Test Plan: CI

Differential Revision: D84921134
1 parent 06ea3d6 commit e910d6d

File tree

4 files changed: +173 -2 lines changed

examples/models/llama/experimental/__init__.py

Whitespace-only changes.

examples/models/llama/experimental/load_gguf_q4_0.py

Lines changed: 2 additions & 1 deletion

@@ -26,7 +26,8 @@
 from executorch.extension.gguf_util.load_gguf import GGUFWeights, load_file
 from gguf import ReaderTensor
 from gguf.constants import GGMLQuantizationType
-from torchao.quantization.subclass import QuantizedLinearWeightBase
+
+from .torchao_subclass_legacy import QuantizedLinearWeightBase

 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)

examples/models/llama/experimental/subclass.py

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@
 #
 # This layout is handled internally in the tensor subclass.
 import torch
-from torchao.quantization.subclass import QuantizedLinearWeightBase
+from .torchao_subclass_legacy import QuantizedLinearWeightBase


 def down_size(size):
examples/models/llama/experimental/torchao_subclass_legacy.py

Lines changed: 170 additions & 0 deletions

@@ -0,0 +1,170 @@
import torch
from torch.utils._python_dispatch import return_and_correct_aliasing


aten = torch.ops.aten


class QuantizedLinearWeightBase(torch.Tensor):
    """
    *** LEGACY TORCHAO TENSOR SUBCLASS ***

    Note: this subclass no longer exists in torchao. No one should be importing or
    extending this subclass anymore. It is added back here just for internal
    executorch BC. DO NOT USE!

    Base quantized tensor subclass for quantized linear weights. When the from_float
    method is used to create an instance of any QuantizedLinearWeightBase, we assume
    the input weight is oriented the way it is in a normal linear op, i.e.
    out-channels x in-channels.

    The shape and dtype of the tensor subclass represent how the tensor subclass
    looks externally, regardless of the internal representation's type or
    orientation.
    """

    @staticmethod
    def __new__(cls, int_data, transposed, shape, *args, **kwargs):
        kwargs["device"] = int_data.device
        kwargs["layout"] = (
            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
        )
        assert "dtype" in kwargs
        assert not kwargs.get("requires_grad", False)
        kwargs["requires_grad"] = False
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(self, int_data, transposed, *args, **kwargs):
        self.int_data = int_data
        self.transposed = transposed

    @staticmethod
    def _quantized_op(act_mat, w_qtensor, bias):
        pass

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, "
            f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
        )

    def dequantize(self):
        pass

    def int_repr(self):
        pass

    def q_params(self):
        pass

    def half(self):
        return self.to(torch.float16)

    def _get_to_kwargs(self, *args, **kwargs):
        device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
        device = self.device if device is None else device
        dtype = self.dtype if dtype is None else dtype
        memory_format = (
            memory_format if memory_format is not None else torch.preserve_format
        )
        kwargs = {
            "device": device,
            "dtype": dtype,
            "memory_format": memory_format,
        }
        return kwargs

    def _apply_fn_to_data(self, fn):
        pass

    def _change_shape(self, shape):
        pass

    def __tensor_flatten__(self):
        pass

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        pass

    @classmethod
    def from_float(cls, input_float):
        pass

    # __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs

        if func is torch.nn.functional.linear:
            mat1, w_qtensor, bias = (
                args[0],
                args[1],
                args[2] if len(args) > 2 else None,
            )
            assert not w_qtensor.transposed
            return cls._quantized_op(mat1, w_qtensor, bias)

        try:
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)
        except Exception:
            print(f"ERR: subclass doesn't implement {func}")

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        # Two scenarios where we currently fall back to vanilla mm:
        # 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have
        #     a CPU implementation for consistency and to allow people to test
        # 2 - we're given non-floats - quantizing long to int8 is crazy
        if (
            func in [aten.mm.default, aten.addmm.default]
            and args[0].is_floating_point()
            and args[0].is_cuda
        ):
            if func == aten.addmm.default:
                assert args[1].shape[-1] == args[2].shape[0], (
                    f"need mat1 shape: {args[1].shape} final "
                    f"dim to match mat2 shape: {args[2].shape} first dim"
                )
                mat1, w_qtensor, bias = (
                    args[1],
                    args[2],
                    args[0],
                )
            else:
                assert args[0].shape[-1] == args[1].shape[0], (
                    f"need mat1 shape: {args[0].shape} final dim "
                    f"to match mat2 shape: {args[1].shape} first dim"
                )
                mat1, w_qtensor, bias = (
                    args[0],
                    args[1],
                    None if len(args) == 2 else args[2],
                )
            # Call the quantized op for the specific type of quantized tensor subclass.
            return cls._quantized_op(mat1, w_qtensor, bias)

        if func is aten.detach.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
            )

        if func is aten.clone.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
            )

        if func is aten.t.default:
            args[0].transposed = not args[0].transposed
            new = args[0]._change_shape(args[0].shape[::-1])
            return return_and_correct_aliasing(func, args, kwargs, new)

        if func is aten._to_copy.default:
            return return_and_correct_aliasing(
                func,
                args,
                kwargs,
                args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
            )
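
For context (not part of this commit): a concrete subclass of QuantizedLinearWeightBase is expected to carry the packed weight plus its quantization parameters and to override at least _quantized_op, dequantize, _apply_fn_to_data, and from_float. Below is a minimal, hypothetical sketch assuming a module placed next to torchao_subclass_legacy.py; the class name Int8WeightOnlyLinearWeight and its symmetric per-channel int8 scheme are illustrative assumptions, not an executorch or torchao API.

# Hypothetical sketch, for illustration only -- not part of this commit.
import torch

from .torchao_subclass_legacy import QuantizedLinearWeightBase


class Int8WeightOnlyLinearWeight(QuantizedLinearWeightBase):
    """Symmetric per-channel int8 weight-only quantization (illustrative)."""

    @staticmethod
    def __new__(cls, int_data, scales, transposed, shape, **kwargs):
        # The base __new__ asserts that an external dtype is provided.
        kwargs["dtype"] = kwargs.get("dtype", scales.dtype)
        return super().__new__(cls, int_data, transposed, shape, **kwargs)

    def __init__(self, int_data, scales, transposed, shape, **kwargs):
        self.scales = scales
        super().__init__(int_data, transposed)

    @staticmethod
    def _quantized_op(act_mat, w_qtensor, bias):
        # No custom int8 kernel here: dequantize and fall back to float linear.
        return torch.nn.functional.linear(act_mat, w_qtensor.dequantize(), bias)

    def dequantize(self):
        # int_data is out-channels x in-channels; scales has one entry per out-channel.
        return self.int_data.to(self.dtype) * self.scales.unsqueeze(-1)

    def _apply_fn_to_data(self, fn):
        return self.__class__(
            fn(self.int_data),
            fn(self.scales),
            self.transposed,
            self.shape,
            dtype=self.dtype,
        )

    @classmethod
    def from_float(cls, input_float):
        # Quantize a float out-channels x in-channels weight to int8, one scale per row.
        scales = input_float.abs().amax(dim=1).clamp(min=1e-8) / 127
        int_data = (
            torch.round(input_float / scales.unsqueeze(-1))
            .clamp(-128, 127)
            .to(torch.int8)
        )
        return cls(int_data, scales, False, input_float.shape, dtype=input_float.dtype)

With a sketch like this, something along the lines of linear.weight = torch.nn.Parameter(Int8WeightOnlyLinearWeight.from_float(linear.weight.detach()), requires_grad=False) would route subsequent torch.nn.functional.linear calls through the base class's __torch_function__ and into _quantized_op.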
