Skip to content

Commit 86ba555

Browse files
chengjunlu, anmyachev, whitneywhtsang
authored
[TritonGEN]: Update invalid block store signature. (#4686)
[TritonGEN]: Update invalid block store signature.
---------
Signed-off-by: Lu,Chengjun <[email protected]>
Co-authored-by: Anatoly Myachev <[email protected]>
Co-authored-by: Whitney Tsang <[email protected]>
1 parent aeb746d commit 86ba555

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

test/TritonGEN/tritongen-2Dblockstore-to-llvm.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,20 @@ llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base
6464

6565
// -----
6666

67+
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) {
68+
// CHECK: [[ELEM_BITS:%.*]] = llvm.mlir.constant(8 : i32) : i32
69+
// CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(32 : i32) : i32
70+
// CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(1 : i32) : i32
71+
// CHECK: [[VBLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
72+
// CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
73+
// CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
74+
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockWrite.v8i16({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[ELEM_BITS]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[VBLOCKS]], [[TRANSPOSE]], [[VNNI]], {{.*}})
75+
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=1, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
76+
llvm.return
77+
}
78+
79+
// -----
80+
6781
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) {
6882
// CHECK-COUNT-2: llvm.mlir.constant(1 : i32) : i32
6983
// CHECK: [[ElemSize:%.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -118,6 +132,20 @@ llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base
118132

119133
// -----
120134

135+
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) {
136+
// CHECK: [[ELEM_BITS:%.*]] = llvm.mlir.constant(16 : i32) : i32
137+
// CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(32 : i32) : i32
138+
// CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(8 : i32) : i32
139+
// CHECK: [[VBLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
140+
// CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
141+
// CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
142+
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockWrite.v8i16({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[ELEM_BITS]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[VBLOCKS]], [[TRANSPOSE]], [[VNNI]], {{.*}})
143+
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=16, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
144+
llvm.return
145+
}
146+
147+
// -----
148+
121149
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) {
122150
// CHECK: [[ELEM_BITS:%.*]] = llvm.mlir.constant(32 : i32) : i32
123151
// CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(4 : i32) : i32

third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,11 @@ static bool isSPVBuiltinAvailable(TritonGEN::Matrix2DBlockStoreOp op) {
166166
op.getTileWidth() == 8 && op.getVBlocks() == 1)
167167
return false;
168168

169+
// intel_sub_group_2d_block_write_16b_8r32x1c
170+
if (op.getElemSizeInBits() == 16 && op.getTileHeight() == 8 &&
171+
op.getTileWidth() == 32 && op.getVBlocks() == 1)
172+
return false;
173+
169174
// intel_sub_group_2d_block_write_32b_8r4x1c
170175
if (op.getElemSizeInBits() == 32 && op.getTileHeight() == 8 &&
171176
op.getTileWidth() == 4 && op.getVBlocks() == 1)
@@ -176,6 +181,12 @@ static bool isSPVBuiltinAvailable(TritonGEN::Matrix2DBlockStoreOp op) {
176181
op.getTileWidth() == 8 && op.getVBlocks() == 1)
177182
return false;
178183

184+
// FIXME: The following signature has a correctness issue.
185+
// intel_sub_group_2d_block_write_8b_1r32x1c
186+
if (op.getElemSizeInBits() == 8 && op.getTileHeight() == 1 &&
187+
op.getTileWidth() == 32 && op.getVBlocks() == 1)
188+
return false;
189+
179190
return true;
180191
}
181192

0 commit comments

Comments
 (0)