|
31 | 31 | ) |
32 | 32 |
|
33 | 33 |
|
34 | | -def _gen_randomized_3d_data(low=1, high=32, dtype=np.float32): |
| 34 | +def _gen_randomized_3d_data(low=16, high=32, dtype=np.float32): |
35 | 35 | """Helper function to generate randomized 3d data for summary modules, min and |
36 | 36 | max dimensions for each axis are given by ``low`` and ``high``.""" |
37 | 37 |
|
@@ -182,7 +182,7 @@ def test_set_transformer(summary_dim, num_seeds, num_attention_blocks, num_induc |
182 | 182 |
|
183 | 183 |
|
184 | 184 | @pytest.mark.parametrize("summary_dim", [2, 9]) |
185 | | -@pytest.mark.parametrize("template_dim", [4, 8]) |
| 185 | +@pytest.mark.parametrize("template_dim", [32, 64]) |
186 | 186 | @pytest.mark.parametrize("num_attention_blocks", [1, 2]) |
187 | 187 | def test_time_series_transformer(summary_dim, template_dim, num_attention_blocks): |
188 | 188 | """Tests the fidelity of the ``TimeSeriesTransformer`` w.r.t. shape integrity |
@@ -216,7 +216,7 @@ def test_time_series_transformer(summary_dim, template_dim, num_attention_blocks |
216 | 216 | assert len(transformer.attention_blocks.layers) == num_attention_blocks |
217 | 217 |
|
218 | 218 | # Test non-permutation invariant |
219 | | - assert not np.allclose(out, out_perm, atol=1e-5) |
| 219 | + assert not np.allclose(out, out_perm, atol=1e-6) |
220 | 220 |
|
221 | 221 |
|
222 | 222 | @pytest.mark.parametrize("summary_dim", [13, 2]) |
|
0 commit comments