Skip to content

Commit 4085613

Browse files
committed
skip offloading tests until transformers changes land
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 97345b0 commit 4085613

File tree

3 files changed

+26
-22
lines changed

3 files changed

+26
-22
lines changed

tests/test_transform/factory/test_correctness.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626

2727

2828
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
29-
@pytest.mark.parametrize("randomized", (True, False))
30-
def test_correctness_linear(type, randomized):
29+
@pytest.mark.parametrize("randomize", (True, False))
30+
def test_correctness_linear(type, randomize):
3131
size = (4, 8)
3232
module = torch.nn.Linear(*size, bias=True)
33-
scheme = TransformScheme(type=type, randomized=randomized)
33+
scheme = TransformScheme(type=type, randomize=randomize)
3434
factory = TransformFactory.from_scheme(scheme, name="")
3535

3636
input_tfm = factory.create_transform(
@@ -55,8 +55,8 @@ def test_correctness_linear(type, randomized):
5555

5656

5757
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
58-
@pytest.mark.parametrize("randomized", (True, False))
59-
def test_correctness_model(type, randomized, model_apply, offload=False):
58+
@pytest.mark.parametrize("randomize", (True, False))
59+
def test_correctness_model(type, randomize, model_apply, offload=False):
6060
# load model
6161
model = model_apply[0]
6262
if offload:
@@ -71,7 +71,7 @@ def test_correctness_model(type, randomized, model_apply, offload=False):
7171
# apply transforms
7272
config = TransformConfig(
7373
config_groups={
74-
"": TransformScheme(type=type, randomized=randomized, apply=model_apply[1])
74+
"": TransformScheme(type=type, randomize=randomize, apply=model_apply[1])
7575
}
7676
)
7777
apply_transform_config(model, config)
@@ -84,6 +84,6 @@ def test_correctness_model(type, randomized, model_apply, offload=False):
8484
@requires_gpu
8585
@requires_accelerate()
8686
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
87-
@pytest.mark.parametrize("randomized", (True, False))
88-
def test_correctness_model_offload(type, randomized, model_apply):
89-
test_correctness_model(type, randomized, model_apply, offload=True)
87+
@pytest.mark.parametrize("randomize", (True, False))
88+
def test_correctness_model_offload(type, randomize, model_apply):
89+
test_correctness_model(type, randomize, model_apply, offload=True)

tests/test_transform/factory/test_memory.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@
2929

3030

3131
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
32-
@pytest.mark.parametrize("randomized", (True, False))
32+
@pytest.mark.parametrize("randomize", (True, False))
3333
@pytest.mark.parametrize("requires_grad", (True, False))
34-
def test_memory_sharing(type, randomized, requires_grad, offload=False):
34+
def test_memory_sharing(type, randomize, requires_grad, offload=False):
3535
# load model (maybe with offloading)
3636
model = TransformableModel(2, 2, 4, 4, 8, 8)
3737
if offload:
@@ -42,7 +42,7 @@ def test_memory_sharing(type, randomized, requires_grad, offload=False):
4242
config_groups={
4343
"": TransformScheme(
4444
type=type,
45-
randomzied=randomized,
45+
randomize=randomize,
4646
requires_grad=requires_grad,
4747
apply=[
4848
TransformArgs(targets="Linear", location="input"),
@@ -84,9 +84,9 @@ def test_memory_sharing(type, randomized, requires_grad, offload=False):
8484
@requires_gpu
8585
@requires_accelerate()
8686
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
87-
@pytest.mark.parametrize("randomized", (True, False))
87+
@pytest.mark.parametrize("randomize", (True, False))
8888
def test_memory_sharing_offload(
8989
type,
90-
randomized,
90+
randomize,
9191
):
92-
test_memory_sharing(type, randomized, requires_grad=False, offload=True)
92+
test_memory_sharing(type, randomize, requires_grad=False, offload=True)

tests/test_transform/factory/test_serialization.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,20 @@
2020
apply_transform_config,
2121
)
2222
from compressed_tensors.utils import offloaded_dispatch
23-
from tests.test_transform.conftest import scheme_kwargs
2423
from tests.testing_utils import requires_accelerate, requires_gpu
2524

2625

27-
@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
28-
def test_serialization(scheme_kwargs, model_apply, tmp_path, offload=False):
26+
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
27+
@pytest.mark.parametrize("randomize", (True, False))
28+
def test_serialization(type, randomize, model_apply, tmp_path, offload=False):
2929
# get model, maybe offload
3030
model, apply = model_apply
3131
if offload:
3232
offloaded_dispatch(model, torch.device("cuda"))
3333

3434
# apply transforms to model
3535
config = TransformConfig(
36-
config_groups={"": TransformScheme(**scheme_kwargs, apply=apply)}
36+
config_groups={"": TransformScheme(type=type, randomize=randomize, apply=apply)}
3737
)
3838
apply_transform_config(model, config)
3939

@@ -43,8 +43,12 @@ def test_serialization(scheme_kwargs, model_apply, tmp_path, offload=False):
4343
# TODO: reload model
4444

4545

46+
@pytest.mark.skip(reason="Requires changes in upstream transformers")
47+
# https://github.com/huggingface/transformers/pull/39280
48+
# https://github.com/huggingface/transformers/pull/39263
4649
@requires_gpu
4750
@requires_accelerate()
48-
@pytest.mark.parametrize("scheme_kwargs", scheme_kwargs())
49-
def test_serialization_offload(scheme_kwargs, model_apply, tmp_path):
50-
test_serialization(scheme_kwargs, model_apply, tmp_path, offload=True)
51+
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
52+
@pytest.mark.parametrize("randomize", (True, False))
53+
def test_serialization_offload(type, randomize, model_apply, tmp_path):
54+
test_serialization(type, randomize, model_apply, tmp_path, offload=True)

0 commit comments

Comments
 (0)