@@ -3789,20 +3789,36 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} {
3789
3789
3790
3790
// Layouts for the forward-hoist case: #blocked2 is the 2-D input layout,
// #blocked1 the intermediate 2-D layout, #blocked the final 3-D result layout.
#blocked = #ttg.blocked<{sizePerThread = [1, 1, 2], threadsPerWarp = [2, 16, 1], warpsPerCTA = [1, 1, 1], order = [2, 1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [2, 16], warpsPerCTA = [1, 1], order = [1, 0]}>
module attributes {"ttg.num-warps" = 1 : i32, ttg.target = "cuda:80"} {
  // CHECK-LABEL: join_forward
  tt.func @join_forward(%arg0: tensor<2x16xf32, #blocked2>) -> tensor<2x16x2xf32, #blocked> {
    // The convert_layout is expected to be pushed past the join, so the
    // output IR is join first, then a single layout conversion.
    // CHECK: tt.join
    // CHECK: ttg.convert_layout
    // CHECK: tt.return
    %0 = ttg.convert_layout %arg0 : tensor<2x16xf32, #blocked2> -> tensor<2x16xf32, #blocked1>
    %1 = tt.join %0, %0 : tensor<2x16xf32, #blocked1> -> tensor<2x16x2xf32, #blocked>
    tt.return %1 : tensor<2x16x2xf32, #blocked>
  }
}
// -----

// Layouts for the backward case: joining two #blocked operands yields
// #blocked2; the subsequent conversion to #blocked1 should fold away so the
// join result is returned directly.
#blocked = #ttg.blocked<{sizePerThread = [1, 32], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 32, 2], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [0, 1, 2]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 32, 2], threadsPerWarp = [32, 1, 1], warpsPerCTA = [4, 1, 1], order = [2, 0, 1]}>
module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:80"} {
  // CHECK-LABEL: join_backward
  tt.func @join_backward(%arg0: tensor<128x32xf16, #blocked>, %arg1: tensor<128x32xf16, #blocked>) -> tensor<128x32x2xf16, #blocked1> {
    // CHECK: %[[JOIN:.*]] = tt.join
    // CHECK: tt.return %[[JOIN]]
    %0 = tt.join %arg0, %arg1 : tensor<128x32xf16, #blocked> -> tensor<128x32x2xf16, #blocked2>
    %1 = ttg.convert_layout %0 : tensor<128x32x2xf16, #blocked2> -> tensor<128x32x2xf16, #blocked1>
    tt.return %1 : tensor<128x32x2xf16, #blocked1>
  }
}
// -----

3806
3822
#linear = #ttg.linear <{register = [[0 , 2 ], [64 , 0 ]], lane = [[1 , 0 ], [2 , 0 ], [4 , 0 ], [8 , 0 ], [16 , 0 ], [0 , 1 ]], warp = [[0 , 0 ], [32 , 0 ]], block = []}>
3807
3823
#linear1 = #ttg.linear <{register = [[0 , 2 ], [64 , 0 ]], lane = [[1 , 0 ], [2 , 0 ], [4 , 0 ], [8 , 0 ], [16 , 0 ], [0 , 1 ]], warp = [[32 , 0 ], [0 , 0 ]], block = []}>
3808
3824
#blocked = #ttg.blocked <{sizePerThread = [1 , 1 ], threadsPerWarp = [1 , 64 ], warpsPerCTA = [2 , 2 ], order = [1 , 0 ]}>
0 commit comments