@@ -2828,122 +2828,6 @@ tt.func @remat_across_regions(%arg0: i1, %arg1: tensor<8x8xf32, #blocked>) {
28282828
28292829// -----
28302830
2831- #blocked = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [32 , 1 ], warpsPerCTA = [4 , 1 ], order = [0 , 1 ]}>
2832- #mma = #ttg.nvidia_mma <{versionMajor = 2 , versionMinor = 0 , warpsPerCTA = [4 , 1 ], instrShape = [16 , 16 ]}>
2833-
2834- module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , ttg.target = " cuda:100" } {
2835-
2836- // CHECK-LABEL: @hoist_one_conditional
2837- tt.func @hoist_one_conditional (
2838- %arg0: i1 ,
2839- %arg1: tensor <128 x32 x!tt.ptr <f32 >, #blocked >,
2840- %arg2: tensor <32 x128 xf32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>,
2841- %arg3: tensor <128 x128 xf32 , #mma >
2842- ) -> tensor <128 x128 xf32 , #mma > {
2843-
2844- // CHECK: arith.constant {{.*}} tensor<128x32xf32, #ttg.dot_op
2845- %cst = arith.constant dense <0.000000e+00 > : tensor <128 x32 xf32 , #blocked >
2846- // CHECK: scf.if
2847- %0 = scf.if %arg0 -> (tensor <128 x32 xf32 , #blocked >) {
2848- // CHECK-NEXT: [[RES:%.*]] = tt.load
2849- %3 = tt.load %arg1 : tensor <128 x32 x!tt.ptr <f32 >, #blocked >
2850- // CHECK-NEXT: ttg.convert_layout [[RES]]
2851- // CHECK-NEXT: yield
2852- scf.yield %3 : tensor <128 x32 xf32 , #blocked >
2853- } else {
2854- scf.yield %cst : tensor <128 x32 xf32 , #blocked >
2855- }
2856- // CHECK-NOT: ttg.convert_layout
2857- %1 = ttg.convert_layout %0 : tensor <128 x32 xf32 , #blocked > -> tensor <128 x32 xf32 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>>
2858- %2 = tt.dot %1 , %arg2 , %arg3 : tensor <128 x32 xf32 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>> * tensor <32 x128 xf32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>> -> tensor <128 x128 xf32 , #mma >
2859- tt.return %2 : tensor <128 x128 xf32 , #mma >
2860- }
2861-
2862- // CHECK-LABEL: @hoist_multiple_conditional
2863- tt.func @hoist_multiple_conditional (
2864- %arg0: i1 ,
2865- %arg1: i1 ,
2866- %arg2: tensor <128 x32 x!tt.ptr <f32 >, #blocked >,
2867- %arg3: tensor <128 x32 x!tt.ptr <f32 >, #blocked >,
2868- %arg4: tensor <32 x128 xf32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>,
2869- %arg5: tensor <128 x128 xf32 , #mma >
2870- ) -> tensor <128 x128 xf32 , #mma > {
2871- // CHECK-COUNT-1: ttg.convert_layout
2872- %cst0 = arith.constant dense <1.0 > : tensor <128 x32 xf32 , #blocked >
2873- %cst1 = arith.constant dense <2.0 > : tensor <128 x32 xf32 , #blocked >
2874- %0 = scf.if %arg0 -> (tensor <128 x32 xf32 , #blocked >) {
2875- %3 = tt.load %arg2 : tensor <128 x32 x!tt.ptr <f32 >, #blocked >
2876- scf.yield %3 : tensor <128 x32 xf32 , #blocked >
2877- } else {
2878- scf.yield %cst0 : tensor <128 x32 xf32 , #blocked >
2879- }
2880- %1 = scf.if %arg1 -> (tensor <128 x32 xf32 , #blocked >) {
2881- %4 = tt.load %arg3 : tensor <128 x32 x!tt.ptr <f32 >, #blocked >
2882- scf.yield %4 : tensor <128 x32 xf32 , #blocked >
2883- } else {
2884- scf.yield %cst1 : tensor <128 x32 xf32 , #blocked >
2885- }
2886- %2 = arith.addf %0 , %1 : tensor <128 x32 xf32 , #blocked >
2887- %3 = ttg.convert_layout %2 : tensor <128 x32 xf32 , #blocked > -> tensor <128 x32 xf32 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>>
2888- %4 = tt.dot %3 , %arg4 , %arg5 : tensor <128 x32 xf32 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>> * tensor <32 x128 xf32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>> -> tensor <128 x128 xf32 , #mma >
2889- tt.return %4 : tensor <128 x128 xf32 , #mma >
2890- }
2891-
2892- // CHECK-LABEL: @hoist_across_loop
2893- tt.func @hoist_across_loop (
2894- %arg0: i1 ,
2895- %arg1: tensor <128 x32 x!tt.ptr <f32 >, #blocked >,
2896- %arg2: tensor <32 x128 xf32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>>,
2897- %arg3: tensor <128 x128 xf32 , #mma >
2898- ) -> tensor <128 x128 xf32 , #mma > {
2899- // CHECK: arith.constant {{.*}} tensor<128x32xf32, #ttg.dot_op
2900- %cst = arith.constant dense <1.0 > : tensor <128 x32 xf32 , #blocked >
2901- %c0_i32 = arith.constant 0 : i32
2902- %c1_i32 = arith.constant 1 : i32
2903- %c32_i32 = arith.constant 32 : i32
2904- // CHECK: scf.for
2905- %0:2 = scf.for %i = %c0_i32 to %c32_i32 step %c1_i32 iter_args (%arg4 = %cst , %acc = %arg3 ) -> (tensor <128 x32 xf32 , #blocked >, tensor <128 x128 xf32 , #mma >) : i32 {
2906- // CHECK-NEXT: scf.if
2907- %1 = scf.if %arg0 -> (tensor <128 x32 xf32 , #blocked >) {
2908- // CHECK-NEXT: [[RES:%.*]] = tt.load
2909- // CHECK-NEXT: ttg.convert_layout [[RES]]
2910- %3 = tt.load %arg1 : tensor <128 x32 x!tt.ptr <f32 >, #blocked >
2911- scf.yield %3 : tensor <128 x32 xf32 , #blocked >
2912- } else {
2913- scf.yield %arg4 : tensor <128 x32 xf32 , #blocked >
2914- }
2915- // CHECK-NOT: ttg.convert_layout
2916- %2 = ttg.convert_layout %1 : tensor <128 x32 xf32 , #blocked > -> tensor <128 x32 xf32 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>>
2917- %3 = tt.dot %2 , %arg2 , %acc : tensor <128 x32 xf32 , #ttg.dot_op <{opIdx = 0 , parent = #mma , kWidth = 2 }>> * tensor <32 x128 xf32 , #ttg.dot_op <{opIdx = 1 , parent = #mma , kWidth = 2 }>> -> tensor <128 x128 xf32 , #mma >
2918- scf.yield %1 , %3 : tensor <128 x32 xf32 , #blocked >, tensor <128 x128 xf32 , #mma >
2919- }
2920- tt.return %0#1 : tensor <128 x128 xf32 , #mma >
2921- }
2922-
2923- // CHECK-LABEL: @chained_if
2924- tt.func @chained_if (%arg0: i1 , %arg1: i1 , %arg2: tensor <32 x32 x!tt.ptr <f32 >, #blocked >, %arg3: tensor <32 x32 x!tt.ptr <f32 >, #blocked >) -> tensor <32 x32 xf32 , #mma > {
2925- // CHECK-COUNT-1: ttg.convert_layout
2926- %cst = arith.constant dense <1.0 > : tensor <32 x32 xf32 , #blocked >
2927- %0 = scf.if %arg0 -> tensor <32 x32 xf32 , #blocked > {
2928- %anchor = tt.load %arg2 : tensor <32 x32 x!tt.ptr <f32 >, #blocked >
2929- scf.yield %anchor : tensor <32 x32 xf32 , #blocked >
2930- } else {
2931- scf.yield %cst : tensor <32 x32 xf32 , #blocked >
2932- }
2933- %1 = scf.if %arg1 -> tensor <32 x32 xf32 , #blocked > {
2934- %anchor = tt.load %arg3 : tensor <32 x32 x!tt.ptr <f32 >, #blocked >
2935- scf.yield %anchor : tensor <32 x32 xf32 , #blocked >
2936- } else {
2937- scf.yield %0 : tensor <32 x32 xf32 , #blocked >
2938- }
2939- %2 = ttg.convert_layout %1 : tensor <32 x32 xf32 , #blocked > -> tensor <32 x32 xf32 , #mma >
2940- tt.return %2 : tensor <32 x32 xf32 , #mma >
2941- }
2942-
2943- }
2944-
2945- // -----
2946-
29472831#linear = #ttg.linear <{register = [[1 , 0 ], [0 , 8 ], [0 , 16 ]], lane = [[2 , 0 ], [4 , 0 ], [8 , 0 ], [16 , 0 ], [0 , 1 ]], warp = [[0 , 2 ], [0 , 4 ]], block = []}>
29482832#blocked = #ttg.blocked <{sizePerThread = [2 , 4 ], threadsPerWarp = [16 , 2 ], warpsPerCTA = [1 , 4 ], order = [1 , 0 ]}>
29492833
0 commit comments