Skip to content

Commit 0b4fdb3

Browse files
committed
support embeddings
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 97f237e commit 0b4fdb3

File tree

5 files changed

+65
-7
lines changed

5 files changed

+65
-7
lines changed

src/compressed_tensors/transform/factory/base.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,8 @@ def input_hook(_, args):
117117
TransformLocation.WEIGHT_INPUT,
118118
TransformLocation.WEIGHT_OUTPUT,
119119
):
120-
assert isinstance(module, torch.nn.Linear)
121-
assert module.bias is None
122-
123120
# fuse transform into weight
121+
assert hasattr(module, "weight")
124122
with torch.no_grad(), align_module_device(module):
125123
update_offload_parameter(module, "weight", transform(module.weight))
126124

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def create_transform(self, module: Module, args: TransformArgs):
5151
:param module: parent module that transform will be applied to
5252
:param args: defines how the transform will be applied to the module
5353
"""
54-
assert isinstance(module, Linear)
54+
assert hasattr(module, "weight")
5555
size = get_transform_size(module, args.location, self.scheme.head_dim)
5656
dtype = module.weight.dtype
5757
device = get_offloaded_device(module)

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def create_transform(self, module: Module, args: TransformArgs):
5050
:param module: parent module that transform will be applied to
5151
:param args: defines how the transform will be applied to the module
5252
"""
53-
assert isinstance(module, Linear)
53+
assert hasattr(module, "weight")
5454
size = get_transform_size(module, args.location, self.scheme.head_dim)
5555
dtype = module.weight.dtype
5656
device = get_offloaded_device(module)

src/compressed_tensors/transform/utils/matrix.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ def get_transform_size(
3939
size = module.in_features
4040
else:
4141
size = module.out_features
42+
elif isinstance(module, torch.nn.Embedding):
43+
if location in (TransformLocation.INPUT, TransformLocation.WEIGHT_INPUT):
44+
size = module.num_embeddings
45+
else:
46+
size = module.embedding_dim
4247
else:
4348
raise NotImplementedError(f"Transforms on {type(module)} are not supported")
4449

@@ -64,7 +69,12 @@ def apply_transform_weight(
6469
:param value: value to apply weight to
6570
:param location: determines how weight should be applied
6671
:param module_type: result of type(module), passed in to determine application of
67-
weight transform
72+
weight transform. This is needed because torch uses convention:
73+
- torch.nn.Linear(in_features, out_features) has weight shape
74+
(out_features, in_features)
75+
- torch.nn.Embedding(num_embeddings, embedding_dim) has weight shape
76+
(num_embeddings, embedding_dim)
77+
The transform has to account for Linear's transposed weights
6878
:return: value after weight has been applied
6979
"""
7080
fn, axis = _get_transform_method(module_type, location)
@@ -139,6 +149,24 @@ def _get_transform_method(
139149
fn = lambda weight, value: value @ weight
140150
axis = -1
141151

152+
# similar derivation to torch.nn.Linear, but `y = (x W)`
153+
if module_type == torch.nn.Embedding:
154+
if location == TransformLocation.INPUT:
155+
fn = lambda weight, value: value @ weight
156+
axis = -1
157+
158+
elif location == TransformLocation.WEIGHT_INPUT:
159+
fn = lambda weight, value: weight @ value
160+
axis = -1
161+
162+
elif location == TransformLocation.WEIGHT_OUTPUT:
163+
fn = lambda weight, value: value @ weight
164+
axis = -1
165+
166+
elif location == TransformLocation.OUTPUT:
167+
fn = lambda weight, value: value @ weight
168+
axis = -1
169+
142170
if fn is None:
143171
raise NotImplementedError(
144172
f"Applying transforms to {module_type} {location} is not supported"

tests/test_transform/factory/test_correctness.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
@pytest.mark.parametrize("head_dim", (None, 2, 4))
3232
def test_correctness_linear(type, randomized, head_dim):
3333
size = (4, 8)
34-
module = torch.nn.Linear(*size, bias=True)
34+
module = torch.nn.Linear(*size, bias=False)
3535
scheme = TransformScheme(type=type, randomized=randomized, head_dim=head_dim)
3636
factory = TransformFactory.from_scheme(scheme, name="")
3737

@@ -56,6 +56,38 @@ def test_correctness_linear(type, randomized, head_dim):
5656
assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
5757

5858

59+
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
60+
@pytest.mark.parametrize("randomized", (True, False))
61+
@pytest.mark.parametrize("embed_loc", ("weight_output", "output"))
62+
@pytest.mark.parametrize("linear_loc", ("input", "weight_input"))
63+
def test_correctness_embedding(type, randomized, embed_loc, linear_loc):
64+
model = torch.nn.Sequential(
65+
torch.nn.Embedding(2, 4),
66+
torch.nn.Linear(4, 8, bias=False),
67+
)
68+
69+
input = torch.randint(high=1, low=0, size=(17, 5, 2))
70+
true_output = model(input)
71+
72+
config = TransformConfig(
73+
config_groups={
74+
"": TransformScheme(
75+
type=type,
76+
randomized=randomized,
77+
apply=[
78+
TransformArgs(targets="Embedding", location=embed_loc),
79+
TransformArgs(targets="Linear", location=linear_loc, inverse=True),
80+
],
81+
)
82+
}
83+
)
84+
apply_transform_config(model, config)
85+
86+
# compare outputs
87+
output = model(input)
88+
assert torch.allclose(true_output, output, atol=1e-5, rtol=0.0)
89+
90+
5991
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
6092
@pytest.mark.parametrize("randomized", (True, False))
6193
def test_correctness_model(type, randomized, model_apply, offload=False):

0 commit comments

Comments
 (0)