Skip to content

Commit 92b6a32

Browse files
[TritonGEN] Use the module's sub-group size instead of a hard-coded value of 16 in block load. (#4764)
[TritonGEN] Use the module's sub-group size instead of a hard-coded value of 16 in block load. --------- Signed-off-by: Lu,Chengjun <[email protected]> Co-authored-by: Whitney Tsang <[email protected]>
1 parent 9153bd0 commit 92b6a32

File tree

3 files changed

+16
-1
lines changed

3 files changed

+16
-1
lines changed

test/TritonGEN/tritongen-invalid.mlir

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -145,19 +145,23 @@ llvm.func @triton_gen.dpas(%c : vector<8xf32>, %a : vector<8xi16>, %b : vector<8
145145

146146
// -----
147147

148+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
148149
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
149150
// expected-error @+1 {{'triton_gen.2Dblockload' op result size of 256 bits does not match the expected size of 128 bits}}
150151
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<16xi16>
151152
llvm.return
152153
}
154+
}
153155

154156
// -----
155157

158+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
156159
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
157160
// expected-error @+1 {{'triton_gen.2Dblockload' op transpose and vnni_transform are mutually exclusive}}
158161
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=16, tile_height=8, v_blocks=1, transpose=true, vnni_transform=true, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<2xi32>
159162
llvm.return
160163
}
164+
}
161165

162166
// -----
163167

@@ -177,19 +181,23 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height
177181

178182
// -----
179183

184+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
180185
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
181186
// expected-error @+1 {{'triton_gen.2Dblockload' op transpose is only supported for 32 and 64 bit elements}}
182187
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=true, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16>
183188
llvm.return
184189
}
190+
}
185191

186192
// -----
187193

194+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
188195
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
189196
// expected-error @+1 {{'triton_gen.2Dblockload' op vnni_transform is only supported for 8 and 16 bit elements}}
190197
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=8, v_blocks=1, transpose=false, vnni_transform=true, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<4xi32>
191198
llvm.return
192199
}
200+
}
193201

194202
// -----
195203

@@ -316,11 +324,13 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height
316324

317325
// -----
318326

327+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
319328
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
320329
// expected-error @+1 {{'triton_gen.2Dblockload' op expecting result element type to be 32 bits}}
321330
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16>
322331
llvm.return
323332
}
333+
}
324334

325335
// -----
326336

test/TritonGEN/tritongen.mlir

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
// RUN: triton-opt %s -split-input-file -verify-diagnostics | FileCheck %s
22

3+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
34
llvm.func @triton_gen.barrier() {
45
// CHECK-LABEL: triton_gen.barrier
56
// CHECK: triton_gen.barrier {mem_fence = Local}
@@ -75,3 +76,4 @@ llvm.func @triton_gen.sub_group_block_write(%ptr : !llvm.ptr<3>, %val : i32) {
7576
triton_gen.sub_group_block_write %ptr, %val : !llvm.ptr<3>, i32
7677
llvm.return
7778
}
79+
}

third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,8 @@
1313
#include "llvm/ADT/STLExtras.h"
1414
#include <cstdint>
1515

16+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
17+
1618
using namespace mlir;
1719
using namespace mlir::triton;
1820

@@ -238,7 +240,8 @@ verify2DBlockLoadHWRestriction(TritonGEN::Matrix2DBlockLoadOp op) {
238240
VectorType resTy = op.getRes().getType();
239241
unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
240242
unsigned resSize = resTy.getNumElements() * resElemTySize;
241-
constexpr unsigned subgroupSize = 16;
243+
unsigned subgroupSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(
244+
op->getParentOfType<mlir::ModuleOp>());
242245
unsigned expectedSize = op.getElemSizeInBits() * op.getTileHeight() *
243246
op.getTileWidth() * op.getVBlocks() / subgroupSize;
244247
if (resSize != expectedSize)

0 commit comments

Comments
 (0)