@@ -54,7 +54,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
54
54
55
55
#dpas = #ttig.dpas <{repeatCount = 8 , systolicDepth = 8 , executionSize = 16 , opsPerChan = 4 , threadsPerWarp = 16 , warpsPerCTA = [4 , 4 ], repCluster = [2 , 2 ]}>
56
56
#dot_a = #ttg.dot_op <{opIdx = 0 , parent = #dpas , kWidth = 2 }>
57
- module attributes {ttig.support_sg_2d_block , " ttg.num-warps" = 16 : i32 } {
57
+ module attributes {ttig.support_sg_2d_block , " ttg.num-warps" = 16 : i32 , " ttg.threads-per-warp " = 16 : i32 } {
58
58
// CHECK-LABEL: @regular_pointer_block_io
59
59
tt.func public @regular_pointer_block_io (%arg0: !tt.ptr <i8 >) {
60
60
%0 = tt.make_range {end = 256 : i32 , start = 0 : i32 } : tensor <256 xi32 , #ttg.slice <{dim = 1 , parent = #dot_a }>>
@@ -69,7 +69,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
69
69
%9 = tt.splat %arg0 : !tt.ptr <i8 > -> tensor <256 x64 x!tt.ptr <i8 >, #dot_a >
70
70
%addr = tt.addptr %9 , %8 : tensor <256 x64 x!tt.ptr <i8 >, #dot_a >, tensor <256 x64 xi32 , #dot_a >
71
71
%cst = arith.constant dense <0 > : tensor <256 x64 xi8 , #dot_a >
72
- // CHECK-COUNT-32 : triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
72
+ // CHECK-COUNT-16 : triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
73
73
tt.store %addr , %cst {ttig.block_io = " row_major" } : tensor <256 x64 x!tt.ptr <i8 >, #dot_a >
74
74
75
75
tt.return
@@ -80,7 +80,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
80
80
81
81
#dpas = #ttig.dpas <{repeatCount = 8 , systolicDepth = 8 , executionSize = 16 , opsPerChan = 1 , threadsPerWarp = 16 , warpsPerCTA = [4 , 4 ], repCluster = [2 , 2 ]}>
82
82
#dot_a = #ttg.dot_op <{opIdx = 0 , parent = #dpas , kWidth = 1 }>
83
- module attributes {ttig.support_sg_2d_block , " ttg.num-warps" = 16 : i32 } {
83
+ module attributes {ttig.support_sg_2d_block , " ttg.num-warps" = 16 : i32 , " ttg.threads-per-warp " = 16 : i32 } {
84
84
// CHECK-LABEL: @regular_pointer_block_io
85
85
tt.func public @regular_pointer_block_io (%arg0: !tt.ptr <f32 >) {
86
86
%0 = tt.make_range {end = 256 : i32 , start = 0 : i32 } : tensor <256 xi32 , #ttg.slice <{dim = 1 , parent = #dot_a }>>
@@ -95,7 +95,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
95
95
%9 = tt.splat %arg0 : !tt.ptr <f32 > -> tensor <256 x64 x!tt.ptr <f32 >, #dot_a >
96
96
%addr = tt.addptr %9 , %8 : tensor <256 x64 x!tt.ptr <f32 >, #dot_a >, tensor <256 x64 xi32 , #dot_a >
97
97
%cst = arith.constant dense <0.000000e+00 > : tensor <256 x64 xf32 , #dot_a >
98
- // CHECK-COUNT-128 : triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 32, tile_width = 8, tile_height = 8, v_blocks = 1, cache_control = Default}
98
+ // CHECK-COUNT-64 : triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 32, tile_width = 8, tile_height = 8, v_blocks = 1, cache_control = Default}
99
99
tt.store %addr , %cst {ttig.block_io = " row_major" } : tensor <256 x64 x!tt.ptr <f32 >, #dot_a >
100
100
101
101
tt.return
@@ -106,7 +106,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
106
106
107
107
#dpas = #ttig.dpas <{repeatCount = 8 , systolicDepth = 8 , executionSize = 16 , opsPerChan = 1 , threadsPerWarp = 16 , warpsPerCTA = [4 , 4 ], repCluster = [2 , 2 ]}>
108
108
#dot_b = #ttg.dot_op <{opIdx = 1 , parent = #dpas , kWidth = 1 }>
109
- module attributes {ttig.support_sg_2d_block , " ttg.num-warps" = 16 : i32 } {
109
+ module attributes {ttig.support_sg_2d_block , " ttg.num-warps" = 16 : i32 , " ttg.threads-per-warp " = 16 : i32 } {
110
110
// CHECK-LABEL: @regular_pointer_block_io
111
111
tt.func public @regular_pointer_block_io (%arg0: !tt.ptr <f32 >) {
112
112
%0 = tt.make_range {end = 256 : i32 , start = 0 : i32 } : tensor <256 xi32 , #ttg.slice <{dim = 1 , parent = #dot_b }>>
@@ -121,7 +121,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
121
121
%9 = tt.splat %arg0 : !tt.ptr <f32 > -> tensor <256 x64 x!tt.ptr <f32 >, #dot_b >
122
122
%addr = tt.addptr %9 , %8 : tensor <256 x64 x!tt.ptr <f32 >, #dot_b >, tensor <256 x64 xi32 , #dot_b >
123
123
%cst = arith.constant dense <0.000000e+00 > : tensor <256 x64 xf32 , #dot_b >
124
- // CHECK-COUNT-128 : triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 32, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
124
+ // CHECK-COUNT-64 : triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 32, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
125
125
tt.store %addr , %cst {ttig.block_io = " row_major" } : tensor <256 x64 x!tt.ptr <f32 >, #dot_b >
126
126
127
127
tt.return
@@ -131,7 +131,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
131
131
// -----
132
132
133
133
#dpas = #ttig.dpas <{repeatCount = 8 , systolicDepth = 8 , executionSize = 16 , opsPerChan = 1 , threadsPerWarp = 16 , warpsPerCTA = [4 , 4 ], repCluster = [2 , 2 ]}>
134
- module attributes {ttig.support_sg_2d_block , " ttg.num-warps" = 16 : i32 } {
134
+ module attributes {ttig.support_sg_2d_block , " ttg.num-warps" = 16 : i32 , " ttg.threads-per-warp " = 16 : i32 } {
135
135
// CHECK-LABEL: @regular_pointer_block_io
136
136
tt.func public @regular_pointer_block_io (%arg0: !tt.ptr <f32 >) {
137
137
%0 = tt.make_range {end = 256 : i32 , start = 0 : i32 } : tensor <256 xi32 , #ttg.slice <{dim = 1 , parent = #dpas }>>
@@ -146,7 +146,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32} {
146
146
%9 = tt.splat %arg0 : !tt.ptr <f32 > -> tensor <256 x64 x!tt.ptr <f32 >, #dpas >
147
147
%addr = tt.addptr %9 , %8 : tensor <256 x64 x!tt.ptr <f32 >, #dpas >, tensor <256 x64 xi32 , #dpas >
148
148
%cst = arith.constant dense <0.000000e+00 > : tensor <256 x64 xf32 , #dpas >
149
- // CHECK-COUNT-32 : triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 32, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
149
+ // CHECK-COUNT-16 : triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 32, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
150
150
tt.store %addr , %cst {ttig.block_io = " row_major" } : tensor <256 x64 x!tt.ptr <f32 >, #dpas >
151
151
152
152
tt.return
0 commit comments