[BUG] Investigate flaky JAX tests #1141

@shi-eric

Bug Description

The following multi-device tests are currently skipped in CI/CD because of intermittent failures in which the output expected from the second device is still all zeros when the test checks it:

  • test_ffi_jax_callable_pmap_multi_stage
  • test_ffi_jax_callable_pmap_multi_output
  • test_ffi_jax_callable_pmap_mul

Example:

======================================================================
test_ffi_jax_callable_pmap_multi_output (warp.tests.interop.test_jax.TestJax.test_ffi_jax_callable_pmap_multi_output)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/builds/omniverse/warp/warp/tests/unittest_utils.py", line 256, in test_func
    func(self, device, **kwargs)
  File "/builds/omniverse/warp/warp/tests/interop/test_jax.py", line 998, in test_ffi_jax_callable_pmap_multi_output
    assert_np_equal(np.asarray(out), ref)
  File "/builds/omniverse/warp/warp/tests/unittest_utils.py", line 247, in assert_np_equal
    np.testing.assert_array_equal(result, expect)
  File "/builds/omniverse/warp/.venv/lib/python3.12/site-packages/numpy/testing/_private/utils.py", line 1121, in assert_array_equal
    assert_array_compare(operator.__eq__, actual, desired, err_msg=err_msg,
  File "/builds/omniverse/warp/.venv/lib/python3.12/site-packages/numpy/testing/_private/utils.py", line 983, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Arrays are not equal
Mismatched elements: 524288 / 1048576 (50%)
First 5 mismatches are at indices:
 [1, 0]: 0.0 (ACTUAL), 2097153.0 (DESIRED)
 [1, 1]: 0.0 (ACTUAL), 2097157.0 (DESIRED)
 [1, 2]: 0.0 (ACTUAL), 2097161.0 (DESIRED)
 [1, 3]: 0.0 (ACTUAL), 2097165.0 (DESIRED)
 [1, 4]: 0.0 (ACTUAL), 2097169.0 (DESIRED)
Max absolute difference among violations: 4.194301e+06
Max relative difference among violations: 1.
 ACTUAL: array([[1.000000e+00, 5.000000e+00, 9.000000e+00, ..., 2.097141e+06,
        2.097145e+06, 2.097149e+06],
       [0.000000e+00, 0.000000e+00, 0.000000e+00, ..., 0.000000e+00,
        0.000000e+00, 0.000000e+00]], shape=(2, 524288), dtype=float32)
 DESIRED: array([[1.000000e+00, 5.000000e+00, 9.000000e+00, ..., 2.097141e+06,
        2.097145e+06, 2.097149e+06],
       [2.097153e+06, 2.097157e+06, 2.097161e+06, ..., 4.194293e+06,
        4.194297e+06, 4.194301e+06]], shape=(2, 524288), dtype=float32)
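
The excerpt above is consistent with the output being zero exactly in the row produced by the second device, while the first row matches. A minimal NumPy sketch for triaging such a mismatch follows. The helper name `zero_filled_shards` is hypothetical, and the `4*k + 1` reference pattern is an assumption inferred from the values printed in the log, not taken from the test source.

```python
import numpy as np


def zero_filled_shards(actual, expected):
    """Return indices along the leading axis where `actual` is entirely zero
    although `expected` has nonzero values there."""
    actual = np.asarray(actual)
    expected = np.asarray(expected)
    bad = []
    for i in range(actual.shape[0]):
        if not np.any(actual[i]) and np.any(expected[i]):
            bad.append(i)
    return bad


# Reference reconstructed from the log (assumption): shape (2, 524288), values 4*k + 1.
n_devices, per_device = 2, 524288
k = np.arange(n_devices * per_device, dtype=np.float32)
ref = (4.0 * k + 1.0).reshape(n_devices, per_device)

# Simulate the observed failure: the second device's shard is still zero.
out = ref.copy()
out[1] = 0.0

print(zero_filled_shards(out, ref))  # indices of the shards that were never filled
```

Reporting the shard index alongside the assertion message may make it easier to tell a missing device result apart from a numerically wrong one.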

Additionally, test_ffi_jax_kernel_autodiff_jit_of_grad_multi_output and test_ffi_jax_kernel_autodiff_jit_of_grad_simple are single-GPU tests that have failed on cuda:1:

Traceback (most recent call last):
  File "/builds/omniverse/warp/warp/tests/unittest_utils.py", line 256, in test_func
    func(self, device, **kwargs)
  File "/builds/omniverse/warp/warp/tests/interop/test_jax.py", line 1258, in test_ffi_jax_kernel_autodiff_jit_of_grad_multi_output
    assert_np_equal(np.asarray(da), ref_da)
  File "/builds/omniverse/warp/warp/tests/unittest_utils.py", line 247, in assert_np_equal
    np.testing.assert_array_equal(result, expect)
  File "/builds/omniverse/warp/.venv/lib/python3.12/site-packages/numpy/testing/_private/utils.py", line 1121, in assert_array_equal
    assert_array_compare(operator.__eq__, actual, desired, err_msg=err_msg,
  File "/builds/omniverse/warp/.venv/lib/python3.12/site-packages/numpy/testing/_private/utils.py", line 983, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Arrays are not equal
Mismatched elements: 1048576 / 1048576 (100%)
First 5 mismatches are at indices:
 [0]: 0.0 (ACTUAL), 2.0 (DESIRED)
 [1]: 0.0 (ACTUAL), 4.0 (DESIRED)
 [2]: 0.0 (ACTUAL), 6.0 (DESIRED)
 [3]: 0.0 (ACTUAL), 8.0 (DESIRED)
 [4]: 0.0 (ACTUAL), 10.0 (DESIRED)
Max absolute difference among violations: 2.097152e+06
Max relative difference among violations: 1.
 ACTUAL: array([0., 0., 0., ..., 0., 0., 0.], shape=(1048576,), dtype=float32)
 DESIRED: array([2.000000e+00, 4.000000e+00, 6.000000e+00, ..., 2.097148e+06,
       2.097150e+06, 2.097152e+06], shape=(1048576,), dtype=float32)
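
Because these failures are intermittent, one way to quantify them during investigation is to rerun a check many times and count how often it fails. The following is a generic sketch only: `count_failures` and the stand-in `fake_check` are hypothetical and are not part of the Warp test suite.

```python
def count_failures(check, runs=100):
    """Call `check()` `runs` times and return how many calls raised AssertionError."""
    failures = 0
    for _ in range(runs):
        try:
            check()
        except AssertionError:
            failures += 1
    return failures


# Stand-in for a flaky test body: fails on every 10th call (purely illustrative).
calls = {"n": 0}


def fake_check():
    calls["n"] += 1
    assert calls["n"] % 10 != 0, "simulated intermittent failure"


print(count_failures(fake_check, runs=100))
```

In practice, `check` would wrap the body of one of the affected tests, so that a failure rate can be compared before and after a candidate fix.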

System Information

No response

Metadata

Labels

  • bug: Something isn't working
  • interop: Interoperability of Warp with other libraries
