Skip to content

Commit 96ceeb3

Browse files
[TritonGEN] Update 2D block verifier (#4669)
- Invoked `verify2DBlockAddressPayloadRestriction` in `verify2DBlockHWRestriction` - Added 64-bit element size support to 2D block operations - Removed non-HW specific restrictions --------- Signed-off-by: Whitney Tsang <[email protected]>
1 parent 0169c00 commit 96ceeb3

File tree

6 files changed

+211
-114
lines changed

6 files changed

+211
-114
lines changed

test/TritonGEN/tritongen-2Dblockload-to-llvm.mlir

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_
126126
// CHECK: %[[VBLOCKS:.*]] = llvm.mlir.constant(2 : i32) : i32
127127
// CHECK: %[[TRANSPOSE:.*]] = llvm.mlir.constant(false) : i1
128128
// CHECK: %[[VNNI:.*]] = llvm.mlir.constant(false) : i1
129-
// CHECK: %[[VAL_68:.*]] = llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v16i8({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[ELEM_BITS]], %[[TILE_WIDTH]], %[[TILE_HEIGHT]], %[[VBLOCKS]], %[[TRANSPOSE]], %[[VNNI]], {{.*}})
129+
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v16i8({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[ELEM_BITS]], %[[TILE_WIDTH]], %[[TILE_HEIGHT]], %[[VBLOCKS]], %[[TRANSPOSE]], %[[VNNI]], {{.*}})
130130
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=16, tile_height=8, v_blocks=2, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi8>
131131
llvm.return
132132
}
@@ -155,7 +155,7 @@ llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_
155155
// CHECK: %[[VBLOCKS:.*]] = llvm.mlir.constant(4 : i32) : i32
156156
// CHECK: %[[TRANSPOSE:.*]] = llvm.mlir.constant(false) : i1
157157
// CHECK: %[[VNNI:.*]] = llvm.mlir.constant(false) : i1
158-
// CHECK: %[[VAL_68:.*]] = llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v32i16({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[ELEM_BITS]], %[[TILE_WIDTH]], %[[TILE_HEIGHT]], %[[VBLOCKS]], %[[TRANSPOSE]], %[[VNNI]], {{.*}})
158+
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v32i16({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[ELEM_BITS]], %[[TILE_WIDTH]], %[[TILE_HEIGHT]], %[[VBLOCKS]], %[[TRANSPOSE]], %[[VNNI]], {{.*}})
159159
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=8, tile_height=16, v_blocks=4, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<32xi16>
160160
llvm.return
161161
}
@@ -424,6 +424,20 @@ llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_
424424

425425
// -----
426426

427+
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
428+
// CHECK-COUNT-2: llvm.mlir.constant(1 : i32) : i32
429+
// CHECK: [[ElemSize:%.*]] = llvm.mlir.constant(1 : i32) : i32
430+
// CHECK-NEXT: [[TileWidth:%.*]] = llvm.mlir.constant(8 : i32) : i32
431+
// CHECK-NEXT: [[TileHeight:%.*]] = llvm.mlir.constant(32 : i32) : i32
432+
// CHECK-NEXT: [[VBlocks:%.*]] = llvm.mlir.constant(1 : i32) : i32
433+
// CHECK-NEXT: llvm.call spir_funccc @_Z41__spirv_Subgroup2DBlockLoadTransformINTELiiiiPU3AS1viiiDv2_iPv([[ElemSize]], [[TileWidth]], [[TileHeight]], [[VBlocks]], {{.*}}, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
434+
// CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<4xi32>
435+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=8, tile_height=32, v_blocks=1, transpose=false, vnni_transform=true, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<4xi32>
436+
llvm.return
437+
}
438+
439+
// -----
440+
427441
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
428442
// CHECK-COUNT-2: llvm.mlir.constant(1 : i32) : i32
429443
// CHECK: [[ElemSize:%.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -522,7 +536,7 @@ llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_
522536

523537
// -----
524538

525-
llvm.func @triton_gen.2Dblockload_(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
539+
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
526540
// CHECK: llvm.mlir.constant(4 : i32) : i32
527541
// CHECK: [[ElemSize:%.*]] = llvm.mlir.constant(4 : i32) : i32
528542
// CHECK-NEXT: [[TileWidth:%.*]] = llvm.mlir.constant(8 : i32) : i32
@@ -535,6 +549,19 @@ llvm.func @triton_gen.2Dblockload_(%ptr : !llvm.ptr<1>, %base_width : i32, %base
535549

536550
// -----
537551

552+
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
553+
// CHECK: llvm.mlir.constant(8 : i32) : i32
554+
// CHECK: [[ElemSize:%.*]] = llvm.mlir.constant(8 : i32) : i32
555+
// CHECK-NEXT: [[TileWidth:%.*]] = llvm.mlir.constant(4 : i32) : i32
556+
// CHECK-NEXT: [[TileHeight:%.*]] = llvm.mlir.constant(8 : i32) : i32
557+
// CHECK-NEXT: [[VBlocks:%.*]] = llvm.mlir.constant(1 : i32) : i32
558+
// CHECK-NEXT: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[ElemSize]], [[TileWidth]], [[TileHeight]], [[VBlocks]], {{.*}}, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
559+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=64, tile_width=4, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<4xi32>
560+
llvm.return
561+
}
562+
563+
// -----
564+
538565
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
539566
// CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv(
540567
// CHECK-SAME: triton_gen.DecorationCacheControlINTEL = #triton_gen.decoration_cache_control<#triton_gen.load_cache_control<0, Uncached, 4>, #triton_gen.load_cache_control<1, Uncached, 4>>

test/TritonGEN/tritongen-2Dblockprefetch-to-llvm.mlir

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -76,39 +76,39 @@ llvm.func @triton_gen.2Dblockprefetch(%ptr : !llvm.ptr<1>, %base_width : i32, %b
7676
// CHECK: [[BASEWIDTH:%.*]] = llvm.sub [[ADD]], [[ONE0]] : i32
7777
// CHECK: [[ELEM_BITS:%.*]] = llvm.mlir.constant(16 : i32) : i32
7878
// CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(8 : i32) : i32
79-
// CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(2 : i32) : i32
79+
// CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(1 : i32) : i32
8080
// CHECK: [[VBLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
8181
// CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
8282
// CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
8383
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid([[BASE_ALIGNED]], [[BASEWIDTH]], {{.*}}, [[X]], {{.*}}, [[ELEM_BITS]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[VBLOCKS]], [[TRANSPOSE]], [[VNNI]], {{.*}})
84+
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=8, tile_height=1, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
85+
llvm.return
86+
}
87+
88+
// -----
89+
90+
llvm.func @triton_gen.2Dblockprefetch(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
91+
// CHECK: [[ELEM_BITS:%.*]] = llvm.mlir.constant(16 : i32) : i32
92+
// CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(8 : i32) : i32
93+
// CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(2 : i32) : i32
94+
// CHECK: [[VBLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
95+
// CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
96+
// CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
97+
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[ELEM_BITS]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[VBLOCKS]], [[TRANSPOSE]], [[VNNI]], {{.*}})
8498
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=8, tile_height=2, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
8599
llvm.return
86100
}
87101

88102
// -----
89103

90104
llvm.func @triton_gen.2Dblockprefetch(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
91-
// CHECK: [[ONE0:%.*]] = llvm.mlir.constant(1 : i32) : i32
92-
// CHECK: [[PTR:%.*]] = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i64
93-
// CHECK: [[VAL_63:%.*]] = llvm.mlir.constant(-64 : i64) : i64
94-
// CHECK: [[VAL_64:%.*]] = llvm.and [[PTR]], [[VAL_63]] : i64
95-
// CHECK: [[VAL_65:%.*]] = llvm.inttoptr [[VAL_64]] : i64 to !llvm.ptr<1>
96-
// CHECK: [[CL:%.*]] = llvm.mlir.constant(63 : i64) : i64
97-
// CHECK: [[AND:%.*]] = llvm.and [[PTR]], [[CL]] : i64
98-
// CHECK: [[TRUNC:%.*]] = llvm.trunc [[AND]] : i64 to i32
99-
// CHECK: [[ADD:%.*]] = llvm.add %arg1, [[TRUNC]] : i32
100-
// CHECK: [[TWO:%.*]] = llvm.mlir.constant(2 : i32) : i32
101-
// CHECK: [[SHR:%.*]] = llvm.udiv [[TRUNC]], [[TWO]] : i32
102-
// CHECK: [[X:%.*]] = llvm.add %arg4, [[SHR]] : i32
103-
// CHECK: [[BASE_ALIGNED:%.*]] = llvm.ptrtoint [[VAL_65]] : !llvm.ptr<1> to i64
104-
// CHECK: [[BASEWIDTH:%.*]] = llvm.sub [[ADD]], [[ONE0]] : i32
105105
// CHECK: [[ELEM_BITS:%.*]] = llvm.mlir.constant(16 : i32) : i32
106106
// CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(8 : i32) : i32
107107
// CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(4 : i32) : i32
108108
// CHECK: [[VBLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
109109
// CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
110110
// CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
111-
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid([[BASE_ALIGNED]], [[BASEWIDTH]], {{.*}}, [[X]], {{.*}}, [[ELEM_BITS]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[VBLOCKS]], [[TRANSPOSE]], [[VNNI]], {{.*}})
111+
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[ELEM_BITS]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[VBLOCKS]], [[TRANSPOSE]], [[VNNI]], {{.*}})
112112
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=8, tile_height=4, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
113113
llvm.return
114114
}

test/TritonGEN/tritongen-2Dblockstore-to-llvm.mlir

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,33 @@
11
// RUN: triton-opt -convert-tritongen-to-llvm -split-input-file %s | FileCheck %s
22

3+
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) {
4+
// CHECK: [[ONE0:%.*]] = llvm.mlir.constant(1 : i32) : i32
5+
// CHECK: [[PTR:%.*]] = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i64
6+
// CHECK: [[VAL_63:%.*]] = llvm.mlir.constant(-64 : i64) : i64
7+
// CHECK: [[VAL_64:%.*]] = llvm.and [[PTR]], [[VAL_63]] : i64
8+
// CHECK: [[VAL_65:%.*]] = llvm.inttoptr [[VAL_64]] : i64 to !llvm.ptr<1>
9+
// CHECK: [[CL:%.*]] = llvm.mlir.constant(63 : i64) : i64
10+
// CHECK: [[AND:%.*]] = llvm.and [[PTR]], [[CL]] : i64
11+
// CHECK: [[TRUNC:%.*]] = llvm.trunc [[AND]] : i64 to i32
12+
// CHECK: [[ADD:%.*]] = llvm.add %arg1, [[TRUNC]] : i32
13+
// CHECK: [[ONE:%.*]] = llvm.mlir.constant(1 : i32) : i32
14+
// CHECK: [[SHR:%.*]] = llvm.udiv [[TRUNC]], [[ONE]] : i32
15+
// CHECK: [[X:%.*]] = llvm.add %arg4, [[SHR]] : i32
16+
// CHECK: [[BASE_ALIGNED:%.*]] = llvm.ptrtoint [[VAL_65]] : !llvm.ptr<1> to i64
17+
// CHECK: [[BASEWIDTH:%.*]] = llvm.sub [[ADD]], [[ONE0]] : i32
18+
// CHECK: [[ELEM_BITS:%.*]] = llvm.mlir.constant(8 : i32) : i32
19+
// CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(8 : i32) : i32
20+
// CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(8 : i32) : i32
21+
// CHECK: [[VBLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
22+
// CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
23+
// CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
24+
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockWrite.v8i16([[BASE_ALIGNED]], [[BASEWIDTH]], {{.*}}, [[X]], {{.*}}, [[ELEM_BITS]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[VBLOCKS]], [[TRANSPOSE]], [[VNNI]], {{.*}})
25+
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=8, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
26+
llvm.return
27+
}
28+
29+
// -----
30+
331
// CHECK: llvm.func spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i(i32, i32, i32, i32, !llvm.ptr {llvm.nonnull, llvm.readonly}, !llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>) attributes {no_unwind, will_return}
432

533
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi8>) {
@@ -49,6 +77,34 @@ llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base
4977

5078
// -----
5179

80+
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) {
81+
// CHECK: [[ELEM_BITS:%.*]] = llvm.mlir.constant(16 : i32) : i32
82+
// CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(32 : i32) : i32
83+
// CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(1 : i32) : i32
84+
// CHECK: [[VBLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
85+
// CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
86+
// CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
87+
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockWrite.v8i16({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[ELEM_BITS]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[VBLOCKS]], [[TRANSPOSE]], [[VNNI]], {{.*}})
88+
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=16, tile_width=32, tile_height=1, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
89+
llvm.return
90+
}
91+
92+
// -----
93+
94+
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) {
95+
// CHECK: [[ELEM_BITS:%.*]] = llvm.mlir.constant(16 : i32) : i32
96+
// CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(8 : i32) : i32
97+
// CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(8 : i32) : i32
98+
// CHECK: [[VBLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
99+
// CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
100+
// CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
101+
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockWrite.v8i16({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[ELEM_BITS]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[VBLOCKS]], [[TRANSPOSE]], [[VNNI]], {{.*}})
102+
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=16, tile_width=8, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
103+
llvm.return
104+
}
105+
106+
// -----
107+
52108
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) {
53109
// CHECK: llvm.mlir.constant(2 : i32) : i32
54110
// CHECK: [[ElemSize:%.*]] = llvm.mlir.constant(2 : i32) : i32
@@ -62,6 +118,34 @@ llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base
62118

63119
// -----
64120

121+
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) {
122+
// CHECK: [[ELEM_BITS:%.*]] = llvm.mlir.constant(32 : i32) : i32
123+
// CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(4 : i32) : i32
124+
// CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(8 : i32) : i32
125+
// CHECK: [[VBLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
126+
// CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
127+
// CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
128+
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockWrite.v8i16({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[ELEM_BITS]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[VBLOCKS]], [[TRANSPOSE]], [[VNNI]], {{.*}})
129+
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=32, tile_width=4, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
130+
llvm.return
131+
}
132+
133+
// -----
134+
135+
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) {
136+
// CHECK: [[ELEM_BITS:%.*]] = llvm.mlir.constant(32 : i32) : i32
137+
// CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(8 : i32) : i32
138+
// CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(8 : i32) : i32
139+
// CHECK: [[VBLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
140+
// CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
141+
// CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
142+
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockWrite.v8i16({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[ELEM_BITS]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[VBLOCKS]], [[TRANSPOSE]], [[VNNI]], {{.*}})
143+
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=32, tile_width=8, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
144+
llvm.return
145+
}
146+
147+
// -----
148+
65149
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi32>) {
66150
// CHECK: llvm.mlir.constant(4 : i32) : i32
67151
// CHECK: [[ElemSize:%.*]] = llvm.mlir.constant(4 : i32) : i32

0 commit comments

Comments
 (0)