|
22 | 22 | from array_api_extra._lib._utils._compat import device as get_device |
23 | 23 | from array_api_extra._lib._utils._typing import Array, Device |
24 | 24 |
|
| 25 | +# some xp backends are untyped |
25 | 26 | # mypy: disable-error-code=no-untyped-usage |
26 | 27 |
|
27 | 28 |
|
@@ -330,6 +331,74 @@ def test_xp(self, xp: ModuleType): |
330 | 331 | xp_assert_equal(kron(a, b, xp=xp), k) |
331 | 332 |
|
332 | 333 |
|
| 334 | +class TestNUnique: |
| 335 | + def test_simple(self, xp: ModuleType): |
| 336 | + a = xp.asarray([[1, 1], [0, 2], [2, 2]]) |
| 337 | + xp_assert_equal(nunique(a), xp.asarray(3)) |
| 338 | + |
| 339 | + def test_empty(self, xp: ModuleType): |
| 340 | + a = xp.asarray([]) |
| 341 | + xp_assert_equal(nunique(a), xp.asarray(0)) |
| 342 | + |
| 343 | + def test_device(self, xp: ModuleType, device: Device): |
| 344 | + a = xp.asarray(0.0, device=device) |
| 345 | + assert get_device(nunique(a)) == device |
| 346 | + |
| 347 | + def test_xp(self, xp: ModuleType): |
| 348 | + a = xp.asarray([[1, 1], [0, 2], [2, 2]]) |
| 349 | + xp_assert_equal(nunique(a, xp=xp), xp.asarray(3)) |
| 350 | + |
| 351 | + |
| 352 | +@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no arange, no device") |
| 353 | +class TestPad: |
| 354 | + def test_simple(self, xp: ModuleType): |
| 355 | + a = xp.arange(1, 4) |
| 356 | + padded = pad(a, 2) |
| 357 | + xp_assert_equal(padded, xp.asarray([0, 0, 1, 2, 3, 0, 0])) |
| 358 | + |
| 359 | + def test_fill_value(self, xp: ModuleType): |
| 360 | + a = xp.arange(1, 4) |
| 361 | + padded = pad(a, 2, constant_values=42) |
| 362 | + xp_assert_equal(padded, xp.asarray([42, 42, 1, 2, 3, 42, 42])) |
| 363 | + |
| 364 | + def test_ndim(self, xp: ModuleType): |
| 365 | + a = xp.reshape(xp.arange(2 * 3 * 4), (2, 3, 4)) |
| 366 | + padded = pad(a, 2) |
| 367 | + assert padded.shape == (6, 7, 8) |
| 368 | + |
| 369 | + def test_mode_not_implemented(self, xp: ModuleType): |
| 370 | + a = xp.arange(3) |
| 371 | + with pytest.raises(NotImplementedError, match="Only `'constant'`"): |
| 372 | + pad(a, 2, mode="edge") # type: ignore[arg-type] # pyright: ignore[reportArgumentType] |
| 373 | + |
| 374 | + def test_device(self, xp: ModuleType, device: Device): |
| 375 | + a = xp.asarray(0.0, device=device) |
| 376 | + assert get_device(pad(a, 2)) == device |
| 377 | + |
| 378 | + def test_xp(self, xp: ModuleType): |
| 379 | + padded = pad(xp.asarray(0), 1, xp=xp) |
| 380 | + xp_assert_equal(padded, xp.asarray(0)) |
| 381 | + |
| 382 | + def test_tuple_width(self, xp: ModuleType): |
| 383 | + a = xp.reshape(xp.arange(12), (3, 4)) |
| 384 | + padded = pad(a, (1, 0)) |
| 385 | + assert padded.shape == (4, 5) |
| 386 | + |
| 387 | + padded = pad(a, (1, 2)) |
| 388 | + assert padded.shape == (6, 7) |
| 389 | + |
| 390 | + with pytest.raises((ValueError, RuntimeError)): |
| 391 | + pad(a, [(1, 2, 3)]) # type: ignore[list-item] # pyright: ignore[reportArgumentType] |
| 392 | + |
| 393 | + def test_list_of_tuples_width(self, xp: ModuleType): |
| 394 | + a = xp.reshape(xp.arange(12), (3, 4)) |
| 395 | + padded = pad(a, [(1, 0), (0, 2)]) |
| 396 | + assert padded.shape == (4, 6) |
| 397 | + |
| 398 | + padded = pad(a, [(1, 0), (0, 0)]) |
| 399 | + assert padded.shape == (4, 4) |
| 400 | + |
| 401 | + |
333 | 402 | @pytest.mark.skip_xp_backend(Backend.DASK_ARRAY, reason="no argsort") |
334 | 403 | @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no device") |
335 | 404 | class TestSetDiff1D: |
@@ -401,71 +470,3 @@ def test_device(self, xp: ModuleType, device: Device): |
401 | 470 |
|
402 | 471 | def test_xp(self, xp: ModuleType): |
403 | 472 | xp_assert_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0)) |
404 | | - |
405 | | - |
406 | | -@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no arange, no device") |
407 | | -class TestPad: |
408 | | - def test_simple(self, xp: ModuleType): |
409 | | - a = xp.arange(1, 4) |
410 | | - padded = pad(a, 2) |
411 | | - xp_assert_equal(padded, xp.asarray([0, 0, 1, 2, 3, 0, 0])) |
412 | | - |
413 | | - def test_fill_value(self, xp: ModuleType): |
414 | | - a = xp.arange(1, 4) |
415 | | - padded = pad(a, 2, constant_values=42) |
416 | | - xp_assert_equal(padded, xp.asarray([42, 42, 1, 2, 3, 42, 42])) |
417 | | - |
418 | | - def test_ndim(self, xp: ModuleType): |
419 | | - a = xp.reshape(xp.arange(2 * 3 * 4), (2, 3, 4)) |
420 | | - padded = pad(a, 2) |
421 | | - assert padded.shape == (6, 7, 8) |
422 | | - |
423 | | - def test_mode_not_implemented(self, xp: ModuleType): |
424 | | - a = xp.arange(3) |
425 | | - with pytest.raises(NotImplementedError, match="Only `'constant'`"): |
426 | | - pad(a, 2, mode="edge") # type: ignore[arg-type] # pyright: ignore[reportArgumentType] |
427 | | - |
428 | | - def test_device(self, xp: ModuleType, device: Device): |
429 | | - a = xp.asarray(0.0, device=device) |
430 | | - assert get_device(pad(a, 2)) == device |
431 | | - |
432 | | - def test_xp(self, xp: ModuleType): |
433 | | - padded = pad(xp.asarray(0), 1, xp=xp) |
434 | | - xp_assert_equal(padded, xp.asarray(0)) |
435 | | - |
436 | | - def test_tuple_width(self, xp: ModuleType): |
437 | | - a = xp.reshape(xp.arange(12), (3, 4)) |
438 | | - padded = pad(a, (1, 0)) |
439 | | - assert padded.shape == (4, 5) |
440 | | - |
441 | | - padded = pad(a, (1, 2)) |
442 | | - assert padded.shape == (6, 7) |
443 | | - |
444 | | - with pytest.raises((ValueError, RuntimeError)): |
445 | | - pad(a, [(1, 2, 3)]) # type: ignore[list-item] # pyright: ignore[reportArgumentType] |
446 | | - |
447 | | - def test_list_of_tuples_width(self, xp: ModuleType): |
448 | | - a = xp.reshape(xp.arange(12), (3, 4)) |
449 | | - padded = pad(a, [(1, 0), (0, 2)]) |
450 | | - assert padded.shape == (4, 6) |
451 | | - |
452 | | - padded = pad(a, [(1, 0), (0, 0)]) |
453 | | - assert padded.shape == (4, 4) |
454 | | - |
455 | | - |
456 | | -class TestNUnique: |
457 | | - def test_simple(self, xp: ModuleType): |
458 | | - a = xp.asarray([[1, 1], [0, 2], [2, 2]]) |
459 | | - xp_assert_equal(nunique(a), xp.asarray(3)) |
460 | | - |
461 | | - def test_empty(self, xp: ModuleType): |
462 | | - a = xp.asarray([]) |
463 | | - xp_assert_equal(nunique(a), xp.asarray(0)) |
464 | | - |
465 | | - def test_device(self, xp: ModuleType, device: Device): |
466 | | - a = xp.asarray(0.0, device=device) |
467 | | - assert get_device(nunique(a)) == device |
468 | | - |
469 | | - def test_xp(self, xp: ModuleType): |
470 | | - a = xp.asarray([[1, 1], [0, 2], [2, 2]]) |
471 | | - xp_assert_equal(nunique(a, xp=xp), xp.asarray(3)) |
|
0 commit comments