Skip to content

Commit 35395fc

Browse files
committed
fix tests
Signed-off-by: Kyle Sayers <[email protected]>
1 parent c991db3 commit 35395fc

File tree

2 files changed

+15
-18
lines changed

2 files changed

+15
-18
lines changed

tests/test_transform/factory/test_correctness.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626

2727

2828
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
29-
@pytest.mark.parametrize("randomized", (True, False))
30-
def test_correctness_linear(type, randomized):
29+
@pytest.mark.parametrize("randomize", (True, False))
30+
def test_correctness_linear(type, randomize):
3131
size = (4, 8)
3232
module = torch.nn.Linear(*size, bias=True)
33-
scheme = TransformScheme(type=type, randomized=randomized)
33+
scheme = TransformScheme(type=type, randomize=randomize)
3434
factory = TransformFactory.from_scheme(scheme, name="")
3535

3636
input_tfm = factory.create_transform(
@@ -55,8 +55,8 @@ def test_correctness_linear(type, randomized):
5555

5656

5757
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
58-
@pytest.mark.parametrize("randomized", (True, False))
59-
def test_correctness_model(type, randomized, model_apply, offload=False):
58+
@pytest.mark.parametrize("randomize", (True, False))
59+
def test_correctness_model(type, randomize, model_apply, offload=False):
6060
# load model
6161
model = model_apply[0]
6262
if offload:
@@ -71,7 +71,7 @@ def test_correctness_model(type, randomized, model_apply, offload=False):
7171
# apply transforms
7272
config = TransformConfig(
7373
config_groups={
74-
"": TransformScheme(type=type, randomized=randomized, apply=model_apply[1])
74+
"": TransformScheme(type=type, randomize=randomize, apply=model_apply[1])
7575
}
7676
)
7777
apply_transform_config(model, config)
@@ -84,6 +84,6 @@ def test_correctness_model(type, randomized, model_apply, offload=False):
8484
@requires_gpu
8585
@requires_accelerate()
8686
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
87-
@pytest.mark.parametrize("randomized", (True, False))
88-
def test_correctness_model_offload(type, randomized, model_apply):
89-
test_correctness_model(type, randomized, model_apply, offload=True)
87+
@pytest.mark.parametrize("randomize", (True, False))
88+
def test_correctness_model_offload(type, randomize, model_apply):
89+
test_correctness_model(type, randomize, model_apply, offload=True)

tests/test_transform/factory/test_memory.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@
2929

3030

3131
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
32-
@pytest.mark.parametrize("randomized", (True, False))
32+
@pytest.mark.parametrize("randomize", (True, False))
3333
@pytest.mark.parametrize("requires_grad", (True, False))
34-
def test_memory_sharing(type, randomized, requires_grad, offload=False):
34+
def test_memory_sharing(type, randomize, requires_grad, offload=False):
3535
# load model (maybe with offloading)
3636
model = TransformableModel(2, 2, 4, 4, 8, 8)
3737
if offload:
@@ -42,7 +42,7 @@ def test_memory_sharing(type, randomized, requires_grad, offload=False):
4242
config_groups={
4343
"": TransformScheme(
4444
type=type,
45-
randomzied=randomized,
45+
randomize=randomize,
4646
requires_grad=requires_grad,
4747
apply=[
4848
TransformArgs(targets="Linear", location="input"),
@@ -84,9 +84,6 @@ def test_memory_sharing(type, randomized, requires_grad, offload=False):
8484
@requires_gpu
8585
@requires_accelerate()
8686
@pytest.mark.parametrize("type", ("hadamard", "random-hadamard"))
87-
@pytest.mark.parametrize("randomized", (True, False))
88-
def test_memory_sharing_offload(
89-
type,
90-
randomized,
91-
):
92-
test_memory_sharing(type, randomized, requires_grad=False, offload=True)
87+
@pytest.mark.parametrize("randomize", (True, False))
88+
def test_memory_sharing_offload(type, randomize):
89+
test_memory_sharing(type, randomize, requires_grad=False, offload=True)

0 commit comments

Comments
 (0)