|
1 | 1 | using NNlib, Reactant, Enzyme
|
| 2 | +using Statistics |
2 | 3 |
|
3 | 4 | @testset "Activation Functions" begin
|
4 | 5 | sumabs2(f, x) = sum(abs2, f.(x))
|
@@ -381,6 +382,255 @@ end
|
381 | 382 | end
|
382 | 383 | end
|
383 | 384 |
|
| 385 | +# Adapted from https://github.com/FluxML/NNlib.jl/blob/1468582c4db5f18149cc8fff6fb4633c5debe5c5/test/testsuite/scatter.jl#L108 |
| 386 | +@testset "NNlib scatter" begin |
| 387 | + function test_scatter(dsts, srcs, idxs, res; dims) |
| 388 | + @testset "scatter Float32 $op" for op in (+, -, max, min, *, /, mean) |
| 389 | + for idx in values(idxs), dim in dims |
| 390 | + dst = copy(dsts[dim]) |
| 391 | + target_y = res[(op, dim, true)] |
| 392 | + src = srcs[(dim, true)] |
| 393 | + if op == / |
| 394 | + src = src .* 2.0f0 |
| 395 | + end |
| 396 | + |
| 397 | + y1 = @jit( |
| 398 | + NNlib.scatter!( |
| 399 | + op, Reactant.to_rarray(dst), Reactant.to_rarray(src), idx |
| 400 | + ) |
| 401 | + ) |
| 402 | + @test y1 ≈ target_y |
| 403 | + @test y1 isa ConcreteRArray{Float32,ndims(dst)} |
| 404 | + @test size(y1) == size(dsts[dim]) |
| 405 | + dst = copy(dsts[dim]) |
| 406 | + y2 = @jit( |
| 407 | + NNlib.scatter!( |
| 408 | + op, |
| 409 | + Reactant.to_rarray(dst), |
| 410 | + Reactant.to_rarray(src), |
| 411 | + Reactant.to_rarray(idx), |
| 412 | + ) |
| 413 | + ) |
| 414 | + @test y2 ≈ target_y |
| 415 | + @test y2 isa ConcreteRArray{Float32,ndims(dst)} |
| 416 | + @test size(y2) == size(dsts[dim]) |
| 417 | + |
| 418 | + target_y = res[(op, dim, false)] |
| 419 | + src = srcs[(dim, false)] |
| 420 | + if op == / |
| 421 | + src = src .* 2.0f0 |
| 422 | + end |
| 423 | + |
| 424 | + y3 = @jit(NNlib.scatter(op, Reactant.to_rarray(src), idx)) |
| 425 | + @test y3 ≈ target_y |
| 426 | + @test y3 isa ConcreteRArray{Float32,ndims(dst)} |
| 427 | + @test size(y3) == size(dsts[dim]) |
| 428 | + y4 = @jit( |
| 429 | + NNlib.scatter( |
| 430 | + op, |
| 431 | + Reactant.to_rarray(src), |
| 432 | + Reactant.to_rarray(idx); |
| 433 | + dstsize=size(dsts[dim]), |
| 434 | + ) |
| 435 | + ) |
| 436 | + @test y4 ≈ target_y |
| 437 | + @test y4 isa ConcreteRArray{Float32,ndims(dst)} |
| 438 | + @test size(y4) == size(dsts[dim]) |
| 439 | + |
| 440 | + ridx = Reactant.to_rarray(idx) |
| 441 | + if ridx isa Reactant.AbstractConcreteArray |
| 442 | + @test_throws ArgumentError @jit( |
| 443 | + NNlib.scatter(op, Reactant.to_rarray(src), ridx) |
| 444 | + ) |
| 445 | + else |
| 446 | + y5 = @jit(NNlib.scatter(op, Reactant.to_rarray(src), ridx)) |
| 447 | + @test y5 ≈ target_y |
| 448 | + @test y5 isa ConcreteRArray{Float32,ndims(dst)} |
| 449 | + @test size(y5) == size(dsts[dim]) |
| 450 | + end |
| 451 | + end |
| 452 | + end |
| 453 | + end |
| 454 | + |
| 455 | + @testset "scatter 1d src, 1d index => 1d output" begin |
| 456 | + #! format: off |
| 457 | + dsts = Dict( |
| 458 | + 0 => Float32[3, 4, 5, 6, 7] |
| 459 | + ) |
| 460 | + |
| 461 | + srcs = Dict( |
| 462 | + (0, true) => ones(Float32, 5), |
| 463 | + (0, false) => collect(Float32, 1:5), |
| 464 | + ) |
| 465 | + |
| 466 | + idxs = Dict( |
| 467 | + :int => [4, 2, 1, 5, 3], |
| 468 | + :tup => [(4,), (2,), (1,), (5,), (3,)], |
| 469 | + :car => CartesianIndex.([(4,), (2,), (1,), (5,), (3,)]), |
| 470 | + ) |
| 471 | + |
| 472 | + res = Dict( |
| 473 | + (+, 0, true) => Float32[4, 5, 6, 7, 8], |
| 474 | + (+, 0, false) => Float32[3, 2, 5, 1, 4], |
| 475 | + |
| 476 | + (-, 0, true) => Float32[2, 3, 4, 5, 6], |
| 477 | + (-, 0, false) => Float32[-3, -2, -5, -1, -4], |
| 478 | + |
| 479 | + (max, 0, true) => Float32[3, 4, 5, 6, 7], |
| 480 | + (max, 0, false) => Float32[3, 2, 5, 1, 4], |
| 481 | + |
| 482 | + (min, 0, true) => Float32[1, 1, 1, 1, 1], |
| 483 | + (min, 0, false) => Float32[3, 2, 5, 1, 4], |
| 484 | + |
| 485 | + (*, 0, true) => Float32[3, 4, 5, 6, 7], |
| 486 | + (*, 0, false) => Float32[3, 2, 5, 1, 4], |
| 487 | + |
| 488 | + (/, 0, true) => Float32[1.5, 2.0, 2.5, 3.0, 3.5], |
| 489 | + (/, 0, false) => Float32[1//6, 1//4, 1//10, 1//2, 1//8], |
| 490 | + |
| 491 | + (mean, 0, true) => Float32[4, 5, 6, 7, 8], |
| 492 | + (mean, 0, false) => Float32[3, 2, 5, 1, 4], |
| 493 | + ) |
| 494 | + #! format: on |
| 495 | + test_scatter(dsts, srcs, idxs, res; dims=[0]) |
| 496 | + end |
| 497 | + |
| 498 | + @testset "scatter 2d src, 1d index => 2d output" begin |
| 499 | + #! format: off |
| 500 | + dsts = Dict( |
| 501 | + 0 => Float32[3 3 4 4 5 |
| 502 | + 5 5 6 6 7] |
| 503 | + ) |
| 504 | + |
| 505 | + srcs = Dict( |
| 506 | + (0, true) => ones(Float32, 2, 5), |
| 507 | + (0, false) => ones(Float32, 2) * collect(1:5)', |
| 508 | + ) |
| 509 | + |
| 510 | + idxs = Dict( |
| 511 | + :int => [4, 2, 1, 5, 3], |
| 512 | + :tup => [(4,), (2,), (1,), (5,), (3,)], |
| 513 | + :car => CartesianIndex.([(4,), (2,), (1,), (5,), (3,)]), |
| 514 | + ) |
| 515 | + |
| 516 | + res = Dict( |
| 517 | + (+, 0, true) => Float32[4 4 5 5 6; |
| 518 | + 6 6 7 7 8], |
| 519 | + (+, 0, false) => Float32[3 2 5 1 4; |
| 520 | + 3 2 5 1 4], |
| 521 | + |
| 522 | + (-, 0, true) => Float32[2 2 3 3 4; |
| 523 | + 4 4 5 5 6], |
| 524 | + (-, 0, false) => Float32[-3 -2 -5 -1 -4; |
| 525 | + -3 -2 -5 -1 -4], |
| 526 | + |
| 527 | + (max, 0, true) => Float32[3 3 4 4 5; |
| 528 | + 5 5 6 6 7], |
| 529 | + (max, 0, false) => Float32[3 2 5 1 4; |
| 530 | + 3 2 5 1 4], |
| 531 | + |
| 532 | + (min, 0, true) => Float32[1 1 1 1 1; |
| 533 | + 1 1 1 1 1], |
| 534 | + (min, 0, false) => Float32[3 2 5 1 4; |
| 535 | + 3 2 5 1 4], |
| 536 | + |
| 537 | + (*, 0, true) => Float32[3 3 4 4 5; |
| 538 | + 5 5 6 6 7], |
| 539 | + (*, 0, false) => Float32[3 2 5 1 4; |
| 540 | + 3 2 5 1 4], |
| 541 | + |
| 542 | + (/, 0, true) => Float32[1.5 1.5 2.0 2.0 2.5; |
| 543 | + 2.5 2.5 3.0 3.0 3.5], |
| 544 | + (/, 0, false) => Float32[1//6 1//4 1//10 1//2 1//8; |
| 545 | + 1//6 1//4 1//10 1//2 1//8], |
| 546 | + |
| 547 | + (mean, 0, true) => Float32[4 4 5 5 6; |
| 548 | + 6 6 7 7 8], |
| 549 | + (mean, 0, false) => Float32[3 2 5 1 4; |
| 550 | + 3 2 5 1 4], |
| 551 | + ) |
| 552 | + #! format: on |
| 553 | + test_scatter(dsts, srcs, idxs, res; dims=[0]) |
| 554 | + end |
| 555 | + |
| 556 | + @testset "scatter 2d+3d src, 2d index => 1d+2d output" begin |
| 557 | + #! format: off |
| 558 | + dsts = Dict( |
| 559 | + 0 => Float32[3, 4, 5, 6, 7], |
| 560 | + 1 => Float32[3 3 4 4 5; |
| 561 | + 5 5 6 6 7], |
| 562 | + ) |
| 563 | + |
| 564 | + srcs = Dict( |
| 565 | + (0, true) => ones(Float32, 3, 4), |
| 566 | + (0, false) => ones(Float32, 3) * collect(1:4)', |
| 567 | + (1, true) => ones(Float32, 2, 3, 4), |
| 568 | + (1, false) => Float32[1, 2] .* reshape(ones(Float32, 3) * collect(1:4)', 1,3,4), |
| 569 | + ) |
| 570 | + |
| 571 | + idxs = Dict( |
| 572 | + :int => [1 2 3 4; |
| 573 | + 4 2 1 3; |
| 574 | + 3 5 5 3], |
| 575 | + :tup => [(1,) (2,) (3,) (4,); |
| 576 | + (4,) (2,) (1,) (3,); |
| 577 | + (3,) (5,) (5,) (3,)], |
| 578 | + :car => CartesianIndex.( |
| 579 | + [(1,) (2,) (3,) (4,); |
| 580 | + (4,) (2,) (1,) (3,); |
| 581 | + (3,) (5,) (5,) (3,)]), |
| 582 | + ) |
| 583 | + |
| 584 | + res = Dict( |
| 585 | + (+, 0, true) => Float32[5, 6, 9, 8, 9], |
| 586 | + (+, 1, true) => Float32[5 5 8 6 7; |
| 587 | + 7 7 10 8 9], |
| 588 | + (+, 0, false) => Float32[4, 4, 12, 5, 5], |
| 589 | + (+, 1, false) => Float32[4 4 12 5 5; |
| 590 | + 8 8 24 10 10], |
| 591 | + (-, 0, true) => Float32[1, 2, 1, 4, 5], |
| 592 | + (-, 1, true) => Float32[1 1 0 2 3; |
| 593 | + 3 3 2 4 5], |
| 594 | + (-, 0, false) => Float32[-4, -4, -12, -5, -5], |
| 595 | + (-, 1, false) => Float32[-4 -4 -12 -5 -5; |
| 596 | + -8 -8 -24 -10 -10], |
| 597 | + (max, 0, true) => Float32[3, 4, 5, 6, 7], |
| 598 | + (max, 1, true) => Float32[3 3 4 4 5; |
| 599 | + 5 5 6 6 7], |
| 600 | + (max, 0, false) => Float32[3, 2, 4, 4, 3], |
| 601 | + (max, 1, false) => Float32[3 2 4 4 3; |
| 602 | + 6 4 8 8 6], |
| 603 | + (min, 0, true) => Float32[1, 1, 1, 1, 1], |
| 604 | + (min, 1, true) => Float32[1 1 1 1 1; |
| 605 | + 1 1 1 1 1], |
| 606 | + (min, 0, false) => Float32[1, 2, 1, 1, 2], |
| 607 | + (min, 1, false) => Float32[1 2 1 1 2; |
| 608 | + 2 4 2 2 4], |
| 609 | + (*, 0, true) => Float32[3, 4, 5, 6, 7], |
| 610 | + (*, 1, true) => Float32[3 3 4 4 5; |
| 611 | + 5 5 6 6 7], |
| 612 | + (*, 0, false) => Float32[3, 4, 48, 4, 6], |
| 613 | + (*, 1, false) => Float32[3 4 48 4 6; |
| 614 | + 12 16 768 16 24], |
| 615 | + (/, 0, true) => Float32[0.75, 1., 0.3125, 1.5, 1.75], |
| 616 | + (/, 1, true) => Float32[0.75 0.75 0.25 1. 1.25; |
| 617 | + 1.25 1.25 0.375 1.5 1.75], |
| 618 | + (/, 0, false) => Float32[1//12, 1//16, 1//768, 1//16, 1//24], |
| 619 | + (/, 1, false) => Float32[1//12 1//16 1//768 1//16 1//24; |
| 620 | + 1//48 1//64 1//12288 1//64 1//96], |
| 621 | + (mean, 0, true) => Float32[4., 5., 6., 7., 8.], |
| 622 | + (mean, 1, true) => Float32[4. 4. 5. 5. 6.; |
| 623 | + 6. 6. 7. 7. 8.], |
| 624 | + (mean, 0, false) => Float32[2, 2, 3, 2.5, 2.5], |
| 625 | + (mean, 1, false) => Float32[2. 2. 3. 2.5 2.5; |
| 626 | + 4. 4. 6. 5. 5.], |
| 627 | + ) |
| 628 | + #! format: on |
| 629 | + |
| 630 | + test_scatter(dsts, srcs, idxs, res; dims=[0, 1]) |
| 631 | + end |
| 632 | +end |
| 633 | + |
384 | 634 | @testset "∇conv(D = $ndim)" for ndim in 1:3
|
385 | 635 | x_spatial_dim = 4
|
386 | 636 | batch_size = 2
|
|
0 commit comments