Skip to content

Commit 2b073b6

Browse files
Fix torch gpu CI (#20696)
1 parent be1191f commit 2b073b6

File tree

3 files changed

+9
-1
lines changed

3 files changed

+9
-1
lines changed

.kokoro/github/ubuntu/gpu/build.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ then
7474

7575
# TODO: keras/src/export/export_lib_test.py update LD_LIBRARY_PATH
7676
pytest keras --ignore keras/src/applications \
77-
--ignore keras/src/export/export_lib_test.py \
7877
--cov=keras \
7978
--cov-config=pyproject.toml
8079

keras/src/export/export_lib_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ def get_model(type="sequential", input_shape=(10,), layer_list=None):
5757
),
5858
)
5959
@pytest.mark.skipif(testing.jax_uses_gpu(), reason="Leads to core dumps on CI")
60+
@pytest.mark.skipif(
61+
testing.torch_uses_gpu(), reason="Leads to core dumps on CI"
62+
)
6063
class ExportSavedModelTest(testing.TestCase):
6164
@parameterized.named_parameters(
6265
named_product(model_type=["sequential", "functional", "subclass"])
@@ -344,6 +347,9 @@ def test_jax_specific_kwargs(self, model_type, is_static, jax2tf_kwargs):
344347
reason="Export only currently supports the TF and JAX backends.",
345348
)
346349
@pytest.mark.skipif(testing.jax_uses_gpu(), reason="Leads to core dumps on CI")
350+
@pytest.mark.skipif(
351+
testing.torch_uses_gpu(), reason="Leads to core dumps on CI"
352+
)
347353
class ExportArchiveTest(testing.TestCase):
348354
@parameterized.named_parameters(
349355
named_product(model_type=["sequential", "functional", "subclass"])

keras/src/models/model_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,6 +1229,9 @@ def test_functional_deeply_nested_outputs_struct_losses(self):
12291229
@pytest.mark.skipif(
12301230
testing.jax_uses_gpu(), reason="Leads to core dumps on CI"
12311231
)
1232+
@pytest.mark.skipif(
1233+
testing.torch_uses_gpu(), reason="Leads to core dumps on CI"
1234+
)
12321235
def test_export(self):
12331236
import tensorflow as tf
12341237

0 commit comments

Comments
 (0)