@@ -1,4 +1,5 @@
 import expecttest
+from triton.runtime.jit import MockTensor
 import torch
 import pytest
 import re
@@ -600,3 +601,183 @@ def kernel():
 }
 }
 """)
+
+
+@gluon.jit
+def broadcast_kernel():
+    layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [2, 16], [4, 1], [1, 0])
+    a = ttgl.arange(0, 16, layout=ttgl.SliceLayout(0, layout))[None, :]
+    b = ttgl.arange(0, 16, layout=ttgl.SliceLayout(1, layout))[:, None]
+    0 + a + b
+
+
+def test_broadcast(fresh_knobs):
+    knobs.compilation.disable_line_info = True
+
+    h = broadcast_kernel.warmup(sanitize_overflow=False, grid=(1, ))
+    expecttest.assert_expected_inline(
+        anonymize_ir(h.asm["source"]), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @broadcast_kernel() attributes {noinline = false} {
+    %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> loc(#loc)
+    %2 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc)
+    %3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> loc(#loc)
+    %c0_i32 = arith.constant 0 : i32 loc(#loc)
+    %c0_i32_0 = arith.constant 0 : i32 loc(#loc)
+    %cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> loc(#loc)
+    %4 = arith.addi %cst, %1 : tensor<1x16xi32, #blocked> loc(#loc)
+    %5 = tt.broadcast %4 : tensor<1x16xi32, #blocked> -> tensor<16x16xi32, #blocked> loc(#loc)
+    %6 = tt.broadcast %3 : tensor<16x1xi32, #blocked> -> tensor<16x16xi32, #blocked> loc(#loc)
+    %7 = arith.addi %5, %6 : tensor<16x16xi32, #blocked> loc(#loc)
+    tt.return loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+#loc = loc(unknown)
+""")
+
+
+@gluon.jit
+def math_kernel():
+    layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0])
+    a = ttgl.full([16, 16], 1, ttgl.float32, layout)
+    b = ttgl.full([16, 16], 2, ttgl.float32, layout)
+    c = ttgl.full([16, 16], 4, ttgl.float32, layout)
+    d = ttgl.full([16, 16], 1, ttgl.int32, layout)
+    e = ttgl.full([16, 16], 1, ttgl.int32, layout)
+    ttgl.umulhi(d, e)
+    ttgl.exp(a)
+    ttgl.exp2(a)
+    ttgl.log(a)
+    ttgl.log2(a)
+    ttgl.cos(a)
+    ttgl.sin(a)
+    ttgl.sqrt(a)
+    ttgl.sqrt_rn(a)
+    ttgl.rsqrt(a)
+    ttgl.abs(a)
+    ttgl.fdiv(a, b)
+    ttgl.div_rn(a, b)
+    ttgl.erf(a)
+    ttgl.floor(a)
+    ttgl.ceil(a)
+    ttgl.fma(a, b, c)
+
+
+def test_math(fresh_knobs):
+    knobs.compilation.disable_line_info = True
+
+    h = math_kernel.warmup(sanitize_overflow=False, grid=(1, ))
+    expecttest.assert_expected_inline(
+        anonymize_ir(h.asm["source"]), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @math_kernel() attributes {noinline = false} {
+    %cst = arith.constant 1.000000e+00 : f32 loc(#loc)
+    %cst_0 = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc)
+    %cst_1 = arith.constant 2.000000e+00 : f32 loc(#loc)
+    %cst_2 = arith.constant dense<2.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc)
+    %cst_3 = arith.constant 4.000000e+00 : f32 loc(#loc)
+    %cst_4 = arith.constant dense<4.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc)
+    %c1_i32 = arith.constant 1 : i32 loc(#loc)
+    %cst_5 = arith.constant dense<1> : tensor<16x16xi32, #blocked> loc(#loc)
+    %c1_i32_6 = arith.constant 1 : i32 loc(#loc)
+    %cst_7 = arith.constant dense<1> : tensor<16x16xi32, #blocked> loc(#loc)
+    %0 = tt.mulhiui %cst_5, %cst_7 : tensor<16x16xi32, #blocked> loc(#loc)
+    %1 = math.exp %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %2 = math.exp2 %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %3 = math.log %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %4 = math.log2 %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %5 = math.cos %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %6 = math.sin %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %7 = math.sqrt %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %8 = tt.precise_sqrt %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %9 = math.rsqrt %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %10 = math.absf %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %11 = arith.divf %cst_0, %cst_2 : tensor<16x16xf32, #blocked> loc(#loc)
+    %12 = tt.precise_divf %cst_0, %cst_2 : tensor<16x16xf32, #blocked> loc(#loc)
+    %13 = math.erf %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %14 = math.floor %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %15 = math.ceil %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
+    %16 = math.fma %cst_0, %cst_2, %cst_4 : tensor<16x16xf32, #blocked> loc(#loc)
+    tt.return loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+#loc = loc(unknown)
+""")
+
+
+@gluon.jit
+def pair_add(a0, a1, b0, b1):
+    return a0 + b0, a1 + b1
+
+
+@gluon.jit
+def reduce_kernel(out):
+    layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0])
+    a = ttgl.full([16, 16], 1, ttgl.float32, layout)
+    b = ttgl.full([16, 16], 2, ttgl.float32, layout)
+    s0 = ttgl.sum(a, 0)
+    ttgl.static_assert(s0.type.layout == ttgl.SliceLayout(0, layout))
+    s1 = ttgl.sum(a, 1)
+    ttgl.static_assert(s1.type.layout == ttgl.SliceLayout(1, layout))
+
+    scalar = ttgl.max(s0, 0)
+    ttgl.static_assert(scalar.type == ttgl.float32)
+
+    s1 = ttgl.convert_layout(s1, s0.type.layout)
+
+    pairs = ttgl.reduce((a, b), 0, pair_add)
+    ttgl.static_assert(pairs[0].type.layout == ttgl.SliceLayout(0, layout))
+    ttgl.static_assert(pairs[1].type.layout == ttgl.SliceLayout(0, layout))
+    result = scalar + s1 + pairs[0] + pairs[1]
+    tl.store(out + ttgl.arange(0, 16, s0.type.layout), result)
+
+
+def test_reduce(fresh_knobs):
+    knobs.compilation.disable_line_info = True
+
+    h = reduce_kernel.warmup(MockTensor(ttgl.float32), sanitize_overflow=False, grid=(1, ))
+    expecttest.assert_expected_inline(
+        anonymize_ir(h.asm["ttgir"]), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+#loc = loc(unknown)
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @reduce_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} {
+    %cst = arith.constant dense<2.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc)
+    %cst_0 = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc)
+    %0 = "tt.reduce"(%cst_0) <{axis = 0 : i32}> ({
+    ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown)):
+      %12 = arith.addf %arg1, %arg2 : f32 loc(#loc)
+      tt.reduce.return %12 : f32 loc(#loc)
+    }) : (tensor<16x16xf32, #blocked>) -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %1 = "tt.reduce"(%cst_0) <{axis = 1 : i32}> ({
+    ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown)):
+      %12 = arith.addf %arg1, %arg2 : f32 loc(#loc)
+      tt.reduce.return %12 : f32 loc(#loc)
+    }) : (tensor<16x16xf32, #blocked>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc)
+    %2 = "tt.reduce"(%0) <{axis = 0 : i32}> ({
+    ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown)):
+      %12 = arith.maxnumf %arg1, %arg2 : f32 loc(#loc)
+      tt.reduce.return %12 : f32 loc(#loc)
+    }) : (tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>>) -> f32 loc(#loc)
+    %3 = ttg.convert_layout %1 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %4:2 = "tt.reduce"(%cst_0, %cst) <{axis = 0 : i32}> ({
+    ^bb0(%arg1: f32 loc(unknown), %arg2: f32 loc(unknown), %arg3: f32 loc(unknown), %arg4: f32 loc(unknown)):
+      %12 = arith.addf %arg1, %arg3 : f32 loc(#loc)
+      %13 = arith.addf %arg2, %arg4 : f32 loc(#loc)
+      tt.reduce.return %12, %13 : f32, f32 loc(#loc)
+    }) : (tensor<16x16xf32, #blocked>, tensor<16x16xf32, #blocked>) -> (tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>>) loc(#loc)
+    %5 = tt.splat %2 : f32 -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %6 = arith.addf %5, %3 : tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %7 = arith.addf %6, %4#0 : tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %8 = arith.addf %7, %4#1 : tensor<16xf32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %9 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %10 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<16x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    %11 = tt.addptr %10, %9 : tensor<16x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #blocked}>>, tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    tt.store %11, %8 : tensor<16x!tt.ptr<f32>, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
+    tt.return loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+""")