
Commit bcc8498

root authored and gulsumgudukbay committed
fix is_device_rocm function call
1 parent 9890a51 commit bcc8498

20 files changed: +48 -48 lines changed
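Why the one-character change matters: a bare reference like jtu.is_device_rocm evaluates to the function object itself, which is always truthy, so every guard written as "if jtu.is_device_rocm:" took the skip branch on all backends, not just ROCm. Calling the function, jtu.is_device_rocm(), branches on its actual boolean result. A minimal sketch of the pitfall (the is_device_rocm body below is a hypothetical stand-in, not jtu's real implementation):

def is_device_rocm() -> bool:
  # Hypothetical stand-in for jtu.is_device_rocm: report whether the
  # test backend is ROCm. Hard-coded here purely for illustration.
  return False

if is_device_rocm:    # Bug: the function object is always truthy, so this branch always runs.
  print("test skipped on every backend")

if is_device_rocm():  # Fix: call the function and branch on the returned bool.
  print("test skipped only when actually running on ROCm")

The same truthiness issue affected conditional expressions in the test parameterizations, e.g. seq_len=(32, 64) if jtu.is_device_rocm else (128, 384) in tests/pallas/gpu_ops_test.py, which previously always picked the ROCm-sized values.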

tests/debugging_primitives_test.py

Lines changed: 2 additions & 2 deletions
@@ -806,7 +806,7 @@ def f2(x):
     self._assertLinesEqual(output(), "hello: 0\nhello: 1\nhello: 2\nhello: 3\n")
 
   def test_unordered_print_with_pjit(self):
-    if jtu.is_device_rocm:
+    if jtu.is_device_rocm():
       self.skipTest("Skip on ROCm: tests/debugging_primitives_test.py::DebugPrintParallelTest::test_unordered_print_with_pjit")
     def f(x):
       debug_print("{}", x, ordered=False)
@@ -843,7 +843,7 @@ def f(x):
     self.assertEqual(output(), "[0 1 2 3 4 5 6 7]\n")
 
   def test_unordered_print_of_pjit_of_while(self):
-    if jtu.is_device_rocm:
+    if jtu.is_device_rocm():
       self.skipTest("Skip on ROCm: tests/debugging_primitives_test.py::DebugPrintParallelTest::test_unordered_print_of_pjit_of_while")
     def f(x):
       def cond(carry):

tests/export_harnesses_multi_platform_test.py

Lines changed: 2 additions & 2 deletions
@@ -79,7 +79,7 @@ def setUp(self):
       "operator to work around missing XLA support for pair-reductions")
   )
   def test_prim(self, harness: test_harnesses.Harness):
-    if jtu.is_device_rocm and "multi_array_shape" in harness.fullname:
+    if jtu.is_device_rocm() and "multi_array_shape" in harness.fullname:
       self.skipTest("Skip on ROCm: test_prim_multi_array_shape")
 
     if "eigh_" in harness.fullname:
@@ -194,7 +194,7 @@ def test_all_gather(self, *, dtype):
     self.export_and_compare_to_native(f, x)
 
   def test_random_with_threefry_gpu_kernel_lowering(self):
-    if jtu.is_device_rocm and jtu.get_rocm_version() > (6, 5):
+    if jtu.is_device_rocm() and jtu.get_rocm_version() > (6, 5):
       self.skipTest("Skip on ROCm: test_random_with_threefry_gpu_kernel_lowering")
 
     # On GPU we use a custom call for threefry2x32

tests/ffi_test.py

Lines changed: 3 additions & 3 deletions
@@ -173,7 +173,7 @@ def fun(x):
   @jtu.sample_product(shape=[(6, 5), (4, 5, 6)])
   @jtu.run_on_devices("gpu", "cpu")
   def test_ffi_call(self, shape):
-    if jtu.is_device_rocm and str(self).split()[0] == "test_ffi_call0":
+    if jtu.is_device_rocm() and str(self).split()[0] == "test_ffi_call0":
       self.skipTest("Skip on ROCm: test_ffi_call0")
 
     x = self.rng().randn(*shape).astype(np.float32)
@@ -189,7 +189,7 @@ def test_ffi_call(self, shape):
   )
   @jtu.run_on_devices("gpu", "cpu")
   def test_ffi_call_batching(self, shape, vmap_method):
-    if jtu.is_device_rocm:
+    if jtu.is_device_rocm():
       self.skipTest("Skip on ROCm: test_ffi_call_batching")
 
     shape = (10,) + shape
@@ -276,7 +276,7 @@ def test_invalid_result_type(self):
 
   @jtu.run_on_devices("gpu", "cpu")
   def test_shard_map(self):
-    if jtu.is_device_rocm:
+    if jtu.is_device_rocm():
       self.skipTest("Skip on ROCm: tests/ffi_test.py::FfiTest::test_shard_map")
     mesh = jtu.create_mesh((len(jax.devices()),), ("i",))
     x = self.rng().randn(8, 4, 5).astype(np.float32)

tests/lax_control_flow_test.py

Lines changed: 1 addition & 1 deletion
@@ -2599,7 +2599,7 @@ def testAssociativeScanOfBools(self):
   def testAssociativeScanSolvingRegressionTest(self, shape):
     # This test checks that the batching rule doesn't raise for a batch
     # sensitive function (solve).
-    if jtu.is_device_rocm:
+    if jtu.is_device_rocm():
       self.skipTest("Skip on ROCm: testAssociativeScanSolvingRegressionTest")
 
     ms = np.repeat(np.eye(2).reshape(1, 2, 2), shape, axis=0)

tests/lax_numpy_test.py

Lines changed: 3 additions & 3 deletions
@@ -1562,10 +1562,10 @@ def testTrimZerosNotOneDArray(self):
   )
   @jax.default_matmul_precision("float32")
   def testPoly(self, a_shape, dtype, rank):
-    if jtu.is_device_rocm and a_shape == (12,) and dtype in ( np.int32, np.int8 ) and rank == 2:
+    if jtu.is_device_rocm() and a_shape == (12,) and dtype in ( np.int32, np.int8 ) and rank == 2:
       self.skipTest(f"Skip on ROCm: testPoly: a_shape == (12,) and dtype == {dtype} and rank == 2")
 
-    if jtu.is_device_rocm and a_shape == (6,) and dtype == np.float32 and rank == 2:
+    if jtu.is_device_rocm() and a_shape == (6,) and dtype == np.float32 and rank == 2:
       self.skipTest("Skip on ROCm: testPoly: a_shape == (6,) and dtype == numpy.float32 and rank == 2")
 
     if dtype in (np.float16, jnp.bfloat16, np.int16):
@@ -2585,7 +2585,7 @@ def testDiagFlat(self, shape, dtype, k):
       a2_shape=one_dim_array_shapes,
   )
   def testPolyMul(self, a1_shape, a2_shape, dtype):
-    if jtu.is_device_rocm and str(self).split()[0] == "testPolyMul1":
+    if jtu.is_device_rocm() and str(self).split()[0] == "testPolyMul1":
       self.skipTest("Skip on ROCm: testPolyMul")
 
     rng = jtu.rand_default(self.rng())

tests/linalg_test.py

Lines changed: 3 additions & 3 deletions
@@ -874,7 +874,7 @@ def testTensordot(self, lhs_shape, rhs_shape, axes, dtype):
   @jax.default_matmul_precision("float32")
   def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian, algorithm):
     if algorithm is not None:
-      if jtu.is_device_rocm and algorithm == lax.linalg.SvdAlgorithm.JACOBI and dtype in {np.float32, np.complex64}:
+      if jtu.is_device_rocm() and algorithm == lax.linalg.SvdAlgorithm.JACOBI and dtype in {np.float32, np.complex64}:
         self.skipTest("Skip on ROCm: testSVD Jacobi tests")
       if hermitian:
         self.skipTest("Hermitian SVD doesn't support the algorithm parameter.")
@@ -1006,7 +1006,7 @@ def testNumpyQrModes(self, shape, dtype, mode):
   )
   @jax.default_matmul_precision("float32")
   def testQr(self, shape, dtype, full_matrices):
-    if jtu.is_device_rocm:
+    if jtu.is_device_rocm():
       self.skipTest("Skip on ROCm: tests/linalg_test.py::NumpyLinalgTest::testQr")
 
     if (jtu.test_device_matches(["cuda"]) and
@@ -1084,7 +1084,7 @@ def testQrInvalidDtypeCPU(self, shape=(5, 6), dtype=np.float16):
       dtype=float_types + complex_types,
   )
   def testQrBatching(self, shape, dtype):
-    if jtu.is_device_rocm:
+    if jtu.is_device_rocm():
       self.skipTest("Skip on ROCm: tests/linalg_test.py::NumpyLinalgTest::testQrBatching")
 
     rng = jtu.rand_default(self.rng())

tests/multiprocess_gpu_test.py

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@ def test_gpu_distributed_initialize(self):
 
   def test_distributed_jax_visible_devices(self):
     """Test jax_visible_devices works in distributed settings."""
-    if jtu.is_device_rocm:
+    if jtu.is_device_rocm():
       self.skipTest("Skip on ROCm: test_distributed_jax_visible_devices")
     if not jtu.test_device_matches(['gpu']):
       raise unittest.SkipTest('Tests only for GPU.')

tests/nn_test.py

Lines changed: 1 addition & 1 deletion
@@ -201,7 +201,7 @@ def fwd(a, b, is_ref=False):
       impl=['cudnn', 'xla'],
   )
   def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
-    if jtu.is_device_rocm and dtype == jnp.float16 and group_num == 4 and impl == 'xla':
+    if jtu.is_device_rocm() and dtype == jnp.float16 and group_num == 4 and impl == 'xla':
       self.skipTest("Skip on ROCm: testDotProductAttention[21,23]")
     if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("8.0", 8904):
       raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")

tests/pallas/gpu_attention_test.py

Lines changed: 2 additions & 2 deletions
@@ -104,7 +104,7 @@ def test_mqa(
       kv_seq_len,
       return_residuals,
   ):
-    if jtu.is_device_rocm and 'gfx950' in [d.compute_capability for d in jax.devices()]:
+    if jtu.is_device_rocm() and 'gfx950' in [d.compute_capability for d in jax.devices()]:
       self.skipTest("Skip on ROCm: test_mqa: LLVM ERROR: Do not know how to scalarize the result of this operator!")
     del kwargs
     normalize_output = not return_residuals
@@ -185,7 +185,7 @@ def test_gqa(
       kv_seq_len,
       return_residuals,
   ):
-    if jtu.is_device_rocm and 'gfx950' in [d.compute_capability for d in jax.devices()]:
+    if jtu.is_device_rocm() and 'gfx950' in [d.compute_capability for d in jax.devices()]:
       self.skipTest("Skip on ROCm: test_gqa: LLVM ERROR: Do not know how to scalarize the result of this operator!")
     del kwargs
     normalize_output = not return_residuals

tests/pallas/gpu_ops_test.py

Lines changed: 6 additions & 6 deletions
@@ -175,10 +175,10 @@ def test_fused_attention_fwd(
       use_fwd,
       use_segment_ids,
   ):
-    if jtu.is_device_rocm and 'gfx950' in [d.compute_capability for d in jax.devices()]:
+    if jtu.is_device_rocm() and 'gfx950' in [d.compute_capability for d in jax.devices()]:
       self.skipTest("Skip on ROCm: test_fused_attention_fwd: LLVM ERROR: Do not know how to scalarize the result of this operator!")
 
-    if jtu.is_device_rocm and batch_size == 2 and seq_len == 384 and num_heads == 8 and head_dim == 64 and block_sizes == (('block_q', 128), ('block_k', 128)) and causal and use_fwd and use_segment_ids:
+    if jtu.is_device_rocm() and batch_size == 2 and seq_len == 384 and num_heads == 8 and head_dim == 64 and block_sizes == (('block_q', 128), ('block_k', 128)) and causal and use_fwd and use_segment_ids:
       self.skipTest("Skip on ROCm: tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_fwd4")
     k1, k2, k3 = random.split(random.key(0), 3)
     q = random.normal(
@@ -232,7 +232,7 @@ def impl(q, k, v):
   # RESOURCE_EXHAUSTED: Shared memory size limit exceeded" error.
   @jtu.sample_product(
       batch_size=(1, 2),
-      seq_len=(32, 64) if jtu.is_device_rocm else (128, 384),
+      seq_len=(32, 64) if jtu.is_device_rocm() else (128, 384),
       num_heads=(1, 2),
       head_dim=(32, 64, 128,),
       block_sizes=(
@@ -253,7 +253,7 @@ def impl(q, k, v):
             ("block_kv_dq", 32),
           ),
       )
-      if jtu.is_device_rocm else (
+      if jtu.is_device_rocm() else (
           (
             ("block_q", 128),
             ("block_k", 128),
@@ -295,9 +295,9 @@ def test_fused_attention_bwd(
   ):
     test_name = str(self).split()[0]
     skip_suffix_list = [4, 6, 7, 8, 9]
-    if jtu.is_device_rocm and 'gfx950' in [d.compute_capability for d in jax.devices()]:
+    if jtu.is_device_rocm() and 'gfx950' in [d.compute_capability for d in jax.devices()]:
       self.skipTest("Skip on ROCm: test_fused_attention_bwd: LLVM ERROR: Do not know how to scalarize the result of this operator!")
-    if jtu.is_device_rocm and self.INTERPRET and any(test_name.endswith(str(suffix)) for suffix in skip_suffix_list):
+    if jtu.is_device_rocm() and self.INTERPRET and any(test_name.endswith(str(suffix)) for suffix in skip_suffix_list):
       self.skipTest("Skip on ROCm: tests/pallas/gpu_ops_test.py::FusedAttentionTest::test_fused_attention_bwd[4, 6, 7, 8, 9]")
 
     k1, k2, k3 = random.split(random.key(0), 3)
