Skip to content

Commit 01d1b36

Browse files
chengjunlu and etiotto authored
Add i64 type for Triton Gen 2D block io operation because hardware supports 64 bits data size (#4935)
This PR adds support for 64-bit integer (i64) data types to Triton Gen 2D block I/O operations to align with hardware capabilities that support 64-bit data sizes.

---------

Signed-off-by: Lu,Chengjun <[email protected]>
Signed-off-by: Ettore Tiotto <[email protected]>
Co-authored-by: Ettore Tiotto <[email protected]>
1 parent f86f280 commit 01d1b36

File tree

7 files changed

+70
-25
lines changed

7 files changed

+70
-25
lines changed
Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,42 @@
11
// RUN: env TRITON_INTEL_ADVANCED_PATH=1 triton-opt %s --convert-triton-intel-gpu-to-llvm --verify-diagnostics --split-input-file
22

33
module attributes {"ttig.support_sg_2d_block", "ttig.support_dpas", "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32} {
4-
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f32>, %arg1: i64, %arg2: i32) {
4+
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<i64>, %arg1: i64, %arg2: i32) {
55
%c1_i64 = arith.constant 1 : i64
66
%c0_i32 = arith.constant 0 : i32
7-
%22 = tt.make_tensor_ptr %arg0, [%arg1, %arg1], [%arg1, %c1_i64], [%arg2, %c0_i32] {order = array<i32: 1, 0>} : <tensor<2x32xf32>>
8-
// expected-error @+2 {{expecting elem_size_in_bits * tile_width * v_blocks <= 512}}
7+
%22 = tt.make_tensor_ptr %arg0, [%arg1, %arg1], [%arg1, %c1_i64], [%arg2, %c0_i32] {order = array<i32: 1, 0>} : <tensor<2x32xi64>>
8+
// expected-error @+2 {{expecting elem_size_in_bits * tile_width * v_blocks <= 1024}}
99
// expected-error @+1 {{failed to legalize operation 'ttig.prefetch'}}
10-
ttig.prefetch %22 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<2x32xf32>>
10+
ttig.prefetch %22 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<2x32xi64>>
1111
tt.return
1212
}
1313
}
1414

1515
// -----
1616

1717
module attributes {"ttig.support_sg_2d_block", "ttig.support_dpas", "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32} {
18-
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f32>, %arg1: i64, %arg2: i32) {
18+
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<i64>, %arg1: i64, %arg2: i32) {
1919
%c1_i64 = arith.constant 1 : i64
2020
%c0_i32 = arith.constant 0 : i32
21-
%22 = tt.make_tensor_ptr %arg0, [%arg1, %arg1], [%arg1, %c1_i64], [%arg2, %c0_i32] {order = array<i32: 1, 0>} : <tensor<2x32xf32>>
22-
// expected-error @+2 {{expecting elem_size_in_bits * tile_width * v_blocks <= 512}}
21+
%22 = tt.make_tensor_ptr %arg0, [%arg1, %arg1], [%arg1, %c1_i64], [%arg2, %c0_i32] {order = array<i32: 1, 0>} : <tensor<2x32xi64>>
22+
// expected-error @+2 {{expecting elem_size_in_bits * tile_width * v_blocks <= 1024}}
2323
// expected-error @+1 {{failed to legalize operation 'tt.load'}}
24-
%res = tt.load %22 {DotIdx = 0 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<2x32xf32>>
24+
%res = tt.load %22 {DotIdx = 0 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<2x32xi64>>
2525
tt.return
2626
}
2727
}
2828

2929
// -----
3030

3131
module attributes {"ttig.support_sg_2d_block", "ttig.support_dpas", "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32} {
32-
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f32>, %arg1: i64, %arg2: i32) {
32+
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<i64>, %arg1: i64, %arg2: i32) {
3333
%c1_i64 = arith.constant 1 : i64
3434
%c0_i32 = arith.constant 0 : i32
35-
%cst = arith.constant dense<0.000000e+00> : tensor<2x32xf32>
36-
%22 = tt.make_tensor_ptr %arg0, [%arg1, %arg1], [%arg1, %c1_i64], [%arg2, %c0_i32] {order = array<i32: 1, 0>} : <tensor<2x32xf32>>
37-
// expected-error @+2 {{expecting elem_size_in_bits * tile_width * v_blocks <= 512}}
35+
%cst = arith.constant dense<0> : tensor<2x32xi64>
36+
%22 = tt.make_tensor_ptr %arg0, [%arg1, %arg1], [%arg1, %c1_i64], [%arg2, %c0_i32] {order = array<i32: 1, 0>} : <tensor<2x32xi64>>
37+
// expected-error @+2 {{expecting elem_size_in_bits * tile_width * v_blocks <= 1024}}
3838
// expected-error @+1 {{failed to legalize operation 'tt.store'}}
39-
tt.store %22, %cst {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<2x32xf32>>
39+
tt.store %22, %cst {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<2x32xi64>>
4040
tt.return
4141
}
4242
}

test/TritonGEN/tritongen-2Dblockload-to-llvm.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,3 +801,18 @@ llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_
801801
llvm.return
802802
}
803803
}
804+
805+
// -----
806+
807+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
808+
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
809+
// CHECK: llvm.mlir.constant(8 : i32) : i32
810+
// CHECK: [[ElemSize:%.*]] = llvm.mlir.constant(8 : i32) : i32
811+
// CHECK-NEXT: [[TileWidth:%.*]] = llvm.mlir.constant(8 : i32) : i32
812+
// CHECK-NEXT: [[TileHeight:%.*]] = llvm.mlir.constant(4 : i32) : i32
813+
// CHECK-NEXT: [[VBlocks:%.*]] = llvm.mlir.constant(1 : i32) : i32
814+
// CHECK-NEXT: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[ElemSize]], [[TileWidth]], [[TileHeight]], [[VBlocks]], {{.*}}, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
815+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=64, tile_width=8, tile_height=4, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<2xi64>
816+
llvm.return
817+
}
818+
}

test/TritonGEN/tritongen-2Dblockstore-to-llvm.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,18 @@ llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base
193193
llvm.return
194194
}
195195
}
196+
197+
// -----
198+
199+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
200+
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<2xi64>) {
201+
// CHECK: llvm.mlir.constant(0 : i32) : i32
202+
// CHECK: [[ElemSize:%.*]] = llvm.mlir.constant(8 : i32) : i32
203+
// CHECK-DAG: [[TileWidth:%.*]] = llvm.mlir.constant(8 : i32) : i32
204+
// CHECK-DAG: [[TileHeight:%.*]] = llvm.mlir.constant(4 : i32) : i32
205+
// CHECK-DAG: [[VBlocks:%.*]] = llvm.mlir.constant(1 : i32) : i32
206+
// CHECK-NEXT: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i([[ElemSize]], [[TileWidth]], [[TileHeight]], [[VBlocks]], [[DEST:%.*]], {{.*}}, %arg2, %arg3, {{.*}}) {{.*}} : (i32, i32, i32, i32, !llvm.ptr{{.*}}, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>) -> ()
207+
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits = 64, tile_width = 8, tile_height = 4, v_blocks = 1, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<2xi64>)
208+
llvm.return
209+
}
210+
}

test/TritonGEN/tritongen-invalid.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,8 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height
174174
// -----
175175

176176
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
177-
// expected-error @+1 {{'triton_gen.2Dblockload' op expecting elem_size_in_bits * tile_width * v_blocks <= 512}}
178-
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=4, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<32xi16>
177+
// expected-error @+1 {{'triton_gen.2Dblockload' op expecting elem_size_in_bits * tile_width * v_blocks <= 1024}}
178+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=32, tile_height=8, v_blocks=4, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<32xi16>
179179
llvm.return
180180
}
181181

@@ -501,8 +501,8 @@ llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_hei
501501
// -----
502502

503503
llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
504-
// expected-error @+1 {{'triton_gen.2Dblockprefetch' op expecting elem_size_in_bits * tile_width * v_blocks <= 512}}
505-
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
504+
// expected-error @+1 {{'triton_gen.2Dblockprefetch' op expecting elem_size_in_bits * tile_width * v_blocks <= 1024}}
505+
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=64, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
506506
llvm.return
507507
}
508508

test/TritonGEN/tritongen.mlir

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,21 +45,33 @@ llvm.func @triton_gen.cache_controls(%arg0: !llvm.ptr) {
4545
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
4646
// CHECK: llvm.func @triton_gen.2Dblockload(%arg0: !llvm.ptr, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) {
4747
// CHECK-NEXT: %0 = triton_gen.2Dblockload %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 1, transpose = false, vnni_transform = false, cache_control = Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<16xf16>
48+
// CHECK-NEXT: %1 = triton_gen.2Dblockload %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 {elem_size_in_bits = 32, tile_width = 16, tile_height = 16, v_blocks = 1, transpose = false, vnni_transform = false, cache_control = Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<16xf32>
49+
// CHECK-NEXT: %2 = triton_gen.2Dblockload %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 {elem_size_in_bits = 64, tile_width = 8, tile_height = 16, v_blocks = 1, transpose = false, vnni_transform = false, cache_control = Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi64>
4850
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=16, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<16xf16>
51+
%1 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=16, tile_height=16, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<16xf32>
52+
%2 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=64, tile_width=8, tile_height=16, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi64>
4953
llvm.return
5054
}
5155

52-
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<16xf32>) {
53-
// CHECK: llvm.func @triton_gen.2Dblockstore(%arg0: !llvm.ptr, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: vector<16xf32>) {
54-
// CHECK-NEXT: triton_gen.2Dblockstore %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6 {elem_size_in_bits = 32, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<16xf32>)
55-
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=32, tile_width=16, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<16xf32>)
56+
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val1 : vector<16xf16>, %stored_val2 : vector<16xf32>, %stored_val3 : vector<8xi64>) {
57+
// CHECK: llvm.func @triton_gen.2Dblockstore(%arg0: !llvm.ptr, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: vector<16xf16>, %arg7: vector<16xf32>, %arg8: vector<8xi64>) {
58+
// CHECK-NEXT: triton_gen.2Dblockstore %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6 {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<16xf16>)
59+
// CHECK-NEXT: triton_gen.2Dblockstore %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg7 {elem_size_in_bits = 32, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<16xf32>)
60+
// CHECK-NEXT: triton_gen.2Dblockstore %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg8 {elem_size_in_bits = 64, tile_width = 8, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<8xi64>)
61+
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val1 {elem_size_in_bits=16, tile_width=16, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<16xf16>)
62+
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val2 {elem_size_in_bits=32, tile_width=16, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<16xf32>)
63+
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val3 {elem_size_in_bits=64, tile_width=8, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<8xi64>)
5664
llvm.return
5765
}
5866

5967
llvm.func @triton_gen.2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
6068
// CHECK: llvm.func @triton_gen.2Dblockprefetch(%arg0: !llvm.ptr, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32) {
69+
// CHECK-NEXT: triton_gen.2Dblockprefetch %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 {elem_size_in_bits = 16, tile_width = 8, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
6170
// CHECK-NEXT: triton_gen.2Dblockprefetch %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 {elem_size_in_bits = 32, tile_width = 8, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
71+
// CHECK-NEXT: triton_gen.2Dblockprefetch %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 {elem_size_in_bits = 64, tile_width = 8, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
72+
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=8, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
6273
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
74+
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=64, tile_width=8, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
6375
llvm.return
6476
}
6577

third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def TritonGEN_MatrixDPASOp : TritonGEN_Op<"dpas">,
154154
}
155155

156156
def TritonGEN_Matrix2DBlockLoadOp : TritonGEN_Op<"2Dblockload">,
157-
Results<(outs FixedVectorOfNonZeroRankOf<[TritonGEN_MatrixElemType]>:$res)>,
157+
Results<(outs FixedVectorOfNonZeroRankOf<[TritonGEN_MatrixElemType, AnyI64]>:$res)>,
158158
Arguments<(ins
159159
Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr,
160160
I32:$base_width,
@@ -180,6 +180,7 @@ def TritonGEN_Matrix2DBlockLoadOp : TritonGEN_Op<"2Dblockload">,
180180
$base_width, $base_height, $base_pitch - the shape of matrix
181181
$x, $y, $tile_width, $tile_height - the starting offsets and shape of the tile to load
182182
$elem_size_in_bits - the size in bits of the matrix element
183+
- 64 for f64, i64
183184
- 32 for f32, bf32
184185
- 16 for f16, int16, bf16
185186
- 8 for int8, int4, int2
@@ -217,7 +218,7 @@ def TritonGEN_Matrix2DBlockStoreOp : TritonGEN_Op<"2Dblockstore">,
217218
I32Attr:$tile_width,
218219
I32Attr:$tile_height,
219220
I32Attr:$v_blocks,
220-
FixedVectorOfNonZeroRankOf<[TritonGEN_MatrixElemType]>:$stored_val,
221+
FixedVectorOfNonZeroRankOf<[TritonGEN_MatrixElemType, AnyI64]>:$stored_val,
221222
DefaultValuedAttr<TritonGEN_StoreCacheControl, "::mlir::triton::TritonGEN::StoreCacheControl::DEFAULT">:$cache_control
222223
)> {
223224

@@ -230,6 +231,7 @@ def TritonGEN_Matrix2DBlockStoreOp : TritonGEN_Op<"2Dblockstore">,
230231
$base_width, $base_height, $base_pitch - the shape of the matrix
231232
$x, $y, $tile_width, $tile_height - the starting offsets and shape of the tile to store
232233
$elem_size_in_bits - the size in bits of the matrix element
234+
- 64 for f64, i64
233235
- 32 for f32, bf32
234236
- 16 for f16, int16, bf16
235237
- 8 for int8, int4, int2
@@ -274,6 +276,7 @@ def TritonGEN_Matrix2DBlockPrefetchOp : TritonGEN_Op<"2Dblockprefetch">,
274276
$base_width, $base_height, $base_pitch - the shape of the matrix
275277
$x, $y, $tile_width, $tile_height - the starting offsets and shape of tile to prefetch
276278
$elem_size_in_bits - the size in bits of the matrix element
279+
- 64 for f64, i64
277280
- 32 for f32, bf32
278281
- 16 for f16, int16, bf16
279282
- 8 for int8, int4, int2

third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,9 @@ template <typename Op> static LogicalResult verify2DBlockHWRestriction(Op op) {
102102

103103
uint32_t tileWidth = op.getTileWidth();
104104
uint32_t vBlocks = op.getVBlocks();
105-
if (elemSizeInBits * tileWidth * vBlocks > 512)
105+
if (elemSizeInBits * tileWidth * vBlocks > 1024)
106106
return op->emitOpError(
107-
"expecting elem_size_in_bits * tile_width * v_blocks <= 512");
107+
"expecting elem_size_in_bits * tile_width * v_blocks <= 1024");
108108

109109
assert(tileWidth >= 1 && tileWidth <= 64 &&
110110
"tile_width should be between 1 and 64");

0 commit comments

Comments
 (0)