|
57 | 57 | WithModuleWithStringArg, |
58 | 58 | WithModuleWithStringKwarg, |
59 | 59 | WithModuleWithStringOutput, |
| 60 | + WithMultiHeadAttention, |
60 | 61 | WithNoTensorOutput, |
61 | 62 | WithRNN, |
62 | 63 | WithSideEffect, |
63 | 64 | WithSomeFrozenModule, |
| 65 | + WithTransformer, |
| 66 | + WithTransformerLarge, |
64 | 67 | ) |
65 | 68 | from utils.dict_assertions import assert_tensor_dicts_are_close |
66 | 69 | from utils.forward_backwards import ( |
|
118 | 121 | (WithModuleWithStringOutput, 32), |
119 | 122 | (WithModuleWithStringKwarg, 32), |
120 | 123 | (WithModuleWithHybridPyTreeKwarg, 32), |
| 124 | + (WithMultiHeadAttention, 32), |
| 125 | + param(WithTransformer, 32, marks=mark.filterwarnings("ignore:There is a performance drop")), |
121 | 126 | (FreeParam, 32), |
122 | 127 | (NoFreeParam, 32), |
123 | 128 | param(Cifar10Model, 16, marks=mark.slow), |
|
126 | 131 | param(GroupNormMobileNetV3Small, 3, marks=mark.slow), |
127 | 132 | param(SqueezeNet, 8, marks=mark.slow), |
128 | 133 | param(InstanceNormMobileNetV2, 2, marks=mark.slow), |
| 134 | + param( |
| 135 | + WithTransformerLarge, |
| 136 | + 8, |
| 137 | + marks=[mark.slow, mark.filterwarnings("ignore:There is a performance drop")], |
| 138 | + ), |
129 | 139 | ] |
130 | 140 |
|
131 | 141 |
|
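An aside on the `filterwarnings` marks used in the new parametrizations: pytest's `filterwarnings` mark takes the same `action:message` syntax as Python's `-W` option, so `"ignore:There is a performance drop"` silences, for that case only, any warning whose message starts with that text (presumably a PyTorch warning triggered by the transformer-based architectures). A minimal standalone sketch of the mechanism, with a hypothetical test name:

```python
import warnings

import pytest


@pytest.mark.filterwarnings("ignore:There is a performance drop")
def test_warning_is_suppressed():
    # Warnings whose message starts with the text after "ignore:" are
    # discarded for this test only; other tests still see them.
    warnings.warn("There is a performance drop because of ...", UserWarning)
```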
@@ -565,3 +575,42 @@ def test_batched_non_batched_equivalence(shape: list[int], batch_dim: int): |
565 | 575 | gramian2 = engine2.compute_gramian(output) |
566 | 576 |
|
567 | 577 | assert_close(gramian1, gramian2) |
| 578 | + |
| 579 | + |
| 580 | +@mark.parametrize(["architecture", "batch_size"], PARAMETRIZATIONS) |
| 581 | +def test_batched_non_batched_equivalence_2(architecture: type[ShapedModule], batch_size: int): |
| 582 | + """ |
| 583 | + Same as test_batched_non_batched_equivalence, but on real architectures, and thus only between |
| 584 | + batch_dim=0 and batch_dim=None. |
| 585 | +
|
| 586 | + If this test passes for some architecture but test_compute_gramian fails, it could be that |
| 587 | + get_used_params does not work for some module of that architecture. |
| 588 | + """ |
| 589 | + |
| 590 | + input_shapes = architecture.INPUT_SHAPES |
| 591 | + output_shapes = architecture.OUTPUT_SHAPES |
| 592 | + |
| 593 | + torch.manual_seed(0) |
| 594 | + model_0 = architecture().to(device=DEVICE) |
| 595 | + torch.manual_seed(0) |
| 596 | + model_none = architecture().to(device=DEVICE) |
| 597 | + |
| 598 | + engine_0 = Engine(model_0.modules(), batch_dim=0) |
| 599 | + engine_none = Engine(model_none.modules(), batch_dim=None) |
| 600 | + |
| 601 | + inputs = make_tensors(batch_size, input_shapes) |
| 602 | + targets = make_tensors(batch_size, output_shapes) |
| 603 | + loss_fn = make_mse_loss_fn(targets) |
| 604 | + |
| 605 | + torch.random.manual_seed(0) # Fix randomness for random models |
| 606 | + output = model_0(inputs) |
| 607 | + losses_0 = reduce_to_vector(loss_fn(output)) |
| 608 | + |
| 609 | + torch.random.manual_seed(0) # Fix randomness for random models |
| 610 | + output = model_none(inputs) |
| 611 | + losses_none = reduce_to_vector(loss_fn(output)) |
| 612 | + |
| 613 | + gramian_0 = engine_0.compute_gramian(losses_0) |
| 614 | + gramian_none = engine_none.compute_gramian(losses_none) |
| 615 | + |
| 616 | + assert_close(gramian_0, gramian_none, rtol=1e-4, atol=1e-5) |
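For intuition about what the new test compares: `Engine.compute_gramian(losses)` is used here as (what appears to be) the Gramian of the Jacobian of the loss vector with respect to the model's parameters, and the test asserts that this matrix does not depend on whether the engine exploits the batch dimension (`batch_dim=0`) or not (`batch_dim=None`). A rough plain-PyTorch sketch of that quantity, independent of the library's `Engine` (the model, shapes, and loop below are illustrative only):

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 3)
inputs = torch.randn(5, 4)   # batch of 5 samples
targets = torch.randn(5, 3)

# One scalar loss per sample, analogous to reduce_to_vector(loss_fn(output)).
losses = ((model(inputs) - targets) ** 2).mean(dim=1)

# Build the Jacobian of the losses w.r.t. all parameters, one row per loss.
params = [p for p in model.parameters() if p.requires_grad]
rows = []
for loss in losses:
    grads = torch.autograd.grad(loss, params, retain_graph=True)
    rows.append(torch.cat([g.reshape(-1) for g in grads]))
jacobian = torch.stack(rows)        # shape: (num_losses, num_params)

# The Gramian is a symmetric positive semi-definite (num_losses, num_losses) matrix.
gramian = jacobian @ jacobian.T
print(gramian)
</code>
```

However it is computed internally, this matrix should be the same whether the per-sample structure is exploited or not, which is exactly the equivalence the test checks (up to the `rtol`/`atol` tolerances).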