@@ -635,3 +635,73 @@ def test_broadcast(fresh_knobs):
635
635
} loc(#loc)
636
636
#loc = loc(unknown)
637
637
""" )
638
+
639
+
640
@gluon.jit
def math_kernel():
    """Kernel that exercises every ttgl math builtin.

    Results are deliberately discarded: the test only inspects the TTGIR
    that compilation emits, not any runtime output.
    """
    # Blocked layout: 1 element/thread, 32 lanes along dim 1, 4 warps along dim 0.
    layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0])
    # Float operands (distinct fill values so the IR constants are distinguishable).
    a = ttgl.full([16, 16], 1, ttgl.float32, layout)
    b = ttgl.full([16, 16], 2, ttgl.float32, layout)
    c = ttgl.full([16, 16], 4, ttgl.float32, layout)
    # Integer operands for umulhi.
    d = ttgl.full([16, 16], 1, ttgl.int32, layout)
    e = ttgl.full([16, 16], 1, ttgl.int32, layout)
    ttgl.umulhi(d, e)
    # Unary float ops.
    ttgl.exp(a)
    ttgl.exp2(a)
    ttgl.log(a)
    ttgl.log2(a)
    ttgl.cos(a)
    ttgl.sin(a)
    ttgl.sqrt(a)
    ttgl.sqrt_rn(a)  # round-to-nearest (precise) sqrt
    ttgl.rsqrt(a)
    ttgl.abs(a)
    # Binary float ops (fast and precise division variants).
    ttgl.fdiv(a, b)
    ttgl.div_rn(a, b)
    ttgl.erf(a)
    ttgl.floor(a)
    ttgl.ceil(a)
    # Ternary fused multiply-add.
    ttgl.fma(a, b, c)
666
+
667
def test_math(fresh_knobs):
    """Check that each ttgl math builtin lowers to the expected TTGIR op.

    Line info is disabled so every location collapses to the single
    loc(unknown) marker, keeping the expected IR stable.
    """
    knobs.compilation.disable_line_info = True

    h = math_kernel.warmup(sanitize_overflow=False, grid=(1, ))
    expecttest.assert_expected_inline(
        anonymize_ir(h.asm["source"]), """\
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @math_kernel() attributes {noinline = false} {
    %cst = arith.constant 1.000000e+00 : f32 loc(#loc)
    %cst_0 = arith.constant dense<1.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc)
    %cst_1 = arith.constant 2.000000e+00 : f32 loc(#loc)
    %cst_2 = arith.constant dense<2.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc)
    %cst_3 = arith.constant 4.000000e+00 : f32 loc(#loc)
    %cst_4 = arith.constant dense<4.000000e+00> : tensor<16x16xf32, #blocked> loc(#loc)
    %c1_i32 = arith.constant 1 : i32 loc(#loc)
    %cst_5 = arith.constant dense<1> : tensor<16x16xi32, #blocked> loc(#loc)
    %c1_i32_6 = arith.constant 1 : i32 loc(#loc)
    %cst_7 = arith.constant dense<1> : tensor<16x16xi32, #blocked> loc(#loc)
    %0 = tt.mulhiui %cst_5, %cst_7 : tensor<16x16xi32, #blocked> loc(#loc)
    %1 = math.exp %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
    %2 = math.exp2 %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
    %3 = math.log %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
    %4 = math.log2 %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
    %5 = math.cos %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
    %6 = math.sin %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
    %7 = math.sqrt %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
    %8 = tt.precise_sqrt %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
    %9 = math.rsqrt %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
    %10 = math.absf %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
    %11 = arith.divf %cst_0, %cst_2 : tensor<16x16xf32, #blocked> loc(#loc)
    %12 = tt.precise_divf %cst_0, %cst_2 : tensor<16x16xf32, #blocked> loc(#loc)
    %13 = math.erf %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
    %14 = math.floor %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
    %15 = math.ceil %cst_0 : tensor<16x16xf32, #blocked> loc(#loc)
    %16 = math.fma %cst_0, %cst_2, %cst_4 : tensor<16x16xf32, #blocked> loc(#loc)
    tt.return loc(#loc)
  } loc(#loc)
} loc(#loc)
#loc = loc(unknown)
""")
0 commit comments