@@ -37,6 +37,34 @@ llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_
3737
3838// -----
3939
40+ llvm.func @triton_gen.2Dblockload (%ptr : !llvm.ptr <1 >, %base_width : i32 , %base_height : i32 , %base_pitch : i32 , %x : i32 , %y : i32 ) {
41+ // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r16x1cPU3AS1viiiDv2_iPh(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
42+ // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<8xi8>
43+ %0 = triton_gen.2Dblockload %ptr , %base_width , %base_height , %base_pitch , %x , %y {elem_size_in_bits =8 , tile_width =16 , tile_height =8 , v_blocks =1 , transpose =false , vnni_transform =false , cache_control =Default } : (!llvm.ptr <1 >, i32 , i32 , i32 , i32 , i32 ) -> vector <8 xi8 >
44+ llvm.return
45+ }
46+
47+ // -----
48+
49+ llvm.func @triton_gen.2Dblockload (%ptr : !llvm.ptr <1 >, %base_width : i32 , %base_height : i32 , %base_pitch : i32 , %x : i32 , %y : i32 ) {
50+ // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x1cPU3AS1viiiDv2_iPh(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
51+ // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi8>
52+ %0 = triton_gen.2Dblockload %ptr , %base_width , %base_height , %base_pitch , %x , %y {elem_size_in_bits =8 , tile_width =32 , tile_height =8 , v_blocks =1 , transpose =false , vnni_transform =false , cache_control =Default } : (!llvm.ptr <1 >, i32 , i32 , i32 , i32 , i32 ) -> vector <16 xi8 >
53+ llvm.return
54+ }
55+
56+ // -----
57+
58+ // COM: This case comes from the 06 tutorial of FP8 flash attention.
59+ llvm.func @triton_gen.2Dblockload (%ptr : !llvm.ptr <1 >, %base_width : i32 , %base_height : i32 , %base_pitch : i32 , %x : i32 , %y : i32 ) {
60+ // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r16x4cPU3AS1viiiDv2_iPh(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
61+ // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<32xi8>
62+ %0 = triton_gen.2Dblockload %ptr , %base_width , %base_height , %base_pitch , %x , %y {elem_size_in_bits =8 , tile_width =16 , tile_height =8 , v_blocks =4 , transpose =false , vnni_transform =false , cache_control =Default } : (!llvm.ptr <1 >, i32 , i32 , i32 , i32 , i32 ) -> vector <32 xi8 >
63+ llvm.return
64+ }
65+
66+ // -----
67+
4068llvm.func @triton_gen.2Dblockload (%ptr : !llvm.ptr <1 >, %base_width : i32 , %base_height : i32 , %base_pitch : i32 , %x : i32 , %y : i32 ) {
4169 // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
4270 // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<8xi16>
@@ -64,6 +92,15 @@ llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_
6492
6593// -----
6694
95+ llvm.func @triton_gen.2Dblockload (%ptr : !llvm.ptr <1 >, %base_width : i32 , %base_height : i32 , %base_pitch : i32 , %x : i32 , %y : i32 ) {
96+ // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r32x1cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
97+ // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi16>
98+ %0 = triton_gen.2Dblockload %ptr , %base_width , %base_height , %base_pitch , %x , %y {elem_size_in_bits =16 , tile_width =32 , tile_height =8 , v_blocks =1 , transpose =false , vnni_transform =false , cache_control =Default } : (!llvm.ptr <1 >, i32 , i32 , i32 , i32 , i32 ) -> vector <16 xi16 >
99+ llvm.return
100+ }
101+
102+ // -----
103+
67104llvm.func @triton_gen.2Dblockload (%ptr : !llvm.ptr <1 >, %base_width : i32 , %base_height : i32 , %base_pitch : i32 , %x : i32 , %y : i32 ) {
68105 // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_32b_8r8x1cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
69106 // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<4xi32>
@@ -101,12 +138,22 @@ llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_
101138
102139llvm.func @triton_gen.2Dblockload (%ptr : !llvm.ptr <1 >, %base_width : i32 , %base_height : i32 , %base_pitch : i32 , %x : i32 , %y : i32 ) {
103140 // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_32b_32r8x1cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
141+ // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
104142 %0 = triton_gen.2Dblockload %ptr , %base_width , %base_height , %base_pitch , %x , %y {elem_size_in_bits =32 , tile_width =8 , tile_height =32 , v_blocks =1 , transpose =false , vnni_transform =false , cache_control =Default } : (!llvm.ptr <1 >, i32 , i32 , i32 , i32 , i32 ) -> vector <16 xi32 >
105143 llvm.return
106144}
107145
108146// -----
109147
148+ llvm.func @triton_gen.2Dblockload (%ptr : !llvm.ptr <1 >, %base_width : i32 , %base_height : i32 , %base_pitch : i32 , %x : i32 , %y : i32 ) {
149+ // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_32b_8r2x1cPU3AS1viiiDv2_iPj(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
150+ // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<1xi32>
151+ %0 = triton_gen.2Dblockload %ptr , %base_width , %base_height , %base_pitch , %x , %y {elem_size_in_bits =32 , tile_width =2 , tile_height =8 , v_blocks =1 , transpose =false , vnni_transform =false , cache_control =Default } : (!llvm.ptr <1 >, i32 , i32 , i32 , i32 , i32 ) -> vector <1 xi32 >
152+ llvm.return
153+ }
154+
155+ // -----
156+
110157llvm.func @triton_gen.2Dblockload (%ptr : !llvm.ptr <1 >, %base_width : i32 , %base_height : i32 , %base_pitch : i32 , %x : i32 , %y : i32 ) {
111158 // CHECK: llvm.call spir_funccc @_Z40intel_sub_group_2d_block_read_8b_8r32x2cPU3AS1viiiDv2_iPt(%arg0, %arg1, %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
112159 // CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi16>
@@ -333,84 +380,3 @@ llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_
333380 %0 = triton_gen.2Dblockload %ptr , %base_width , %base_height , %base_pitch , %x , %y {elem_size_in_bits =8 , tile_width =32 , tile_height =8 , v_blocks =2 , transpose =false , vnni_transform =false , cache_control =Default } : (!llvm.ptr <1 >, i32 , i32 , i32 , i32 , i32 ) -> vector <16 xi16 >
334381 llvm.return
335382}
336-
337- // -----
338-
339- // CHECK: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v8i8
340- // CHECK-LABEL: llvm.func @matrix_2Dblockload
341- llvm.func @matrix_2Dblockload (%ptr : !llvm.ptr , %base_width : i32 , %base_height : i32 , %base_pitch : i32 , %x : i32 , %y : i32 ) {
342- // CHECK: [[ELEM_SIZE_IN_BITS:%.*]] = llvm.mlir.constant(8 : i32) : i32
343- // CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(16 : i32) : i32
344- // CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(8 : i32) : i32
345- // CHECK: [[VBLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
346- // CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
347- // CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
348- // CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v8i8({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[ELEM_SIZE_IN_BITS]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[VBLOCKS]], [[TRANSPOSE]], [[VNNI]], {{.*}})
349- %0 = triton_gen.2Dblockload %ptr , %base_width , %base_height , %base_pitch , %x , %y {elem_size_in_bits =8 , tile_width =16 , tile_height =8 , v_blocks =1 , transpose =false , vnni_transform =false , cache_control =Default } : (!llvm.ptr , i32 , i32 , i32 , i32 , i32 ) -> vector <8 xi8 >
350- llvm.return
351- }
352-
353- // -----
354-
355- // COM: This case come from the 06 tutorial of FP8 flash attention.
356- // CHECK: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v32i8
357- // CHECK-LABEL: llvm.func @matrix_2Dblockload
358- llvm.func @matrix_2Dblockload (%ptr : !llvm.ptr , %base_width : i32 , %base_height : i32 , %base_pitch : i32 , %x : i32 , %y : i32 ) {
359- // CHECK: [[ELEM_SIZE_IN_BITS:%.*]] = llvm.mlir.constant(8 : i32) : i32
360- // CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(16 : i32) : i32
361- // CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(8 : i32) : i32
362- // CHECK: [[VBLOCKS:%.*]] = llvm.mlir.constant(4 : i32) : i32
363- // CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
364- // CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
365- // CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v32i8({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[ELEM_SIZE_IN_BITS]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[VBLOCKS]], [[TRANSPOSE]], [[VNNI]], {{.*}})
366- %0 = triton_gen.2Dblockload %ptr , %base_width , %base_height , %base_pitch , %x , %y {elem_size_in_bits =8 , tile_width =16 , tile_height =8 , v_blocks =4 , transpose =false , vnni_transform =false , cache_control =Default } : (!llvm.ptr , i32 , i32 , i32 , i32 , i32 ) -> vector <32 xi8 >
367- llvm.return
368- }
369-
370- // -----
371-
372- // CHECK: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v16i8
373- // CHECK-LABEL: llvm.func @matrix_2Dblockload
374- llvm.func @matrix_2Dblockload (%ptr : !llvm.ptr , %base_width : i32 , %base_height : i32 , %base_pitch : i32 , %x : i32 , %y : i32 ) {
375- // CHECK: [[ELEM_SIZE_IN_BITS:%.*]] = llvm.mlir.constant(8 : i32) : i32
376- // CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(32 : i32) : i32
377- // CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(8 : i32) : i32
378- // CHECK: [[VBLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
379- // CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
380- // CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
381- // CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v16i8({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[ELEM_SIZE_IN_BITS]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[VBLOCKS]], [[TRANSPOSE]], [[VNNI]], {{.*}})
382- %0 = triton_gen.2Dblockload %ptr , %base_width , %base_height , %base_pitch , %x , %y {elem_size_in_bits =8 , tile_width =32 , tile_height =8 , v_blocks =1 , transpose =false , vnni_transform =false , cache_control =Default } : (!llvm.ptr , i32 , i32 , i32 , i32 , i32 ) -> vector <16 xi8 >
383- llvm.return
384- }
385-
386- // -----
387-
388- // CHECK: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v16i16
389- // CHECK-LABEL: llvm.func @matrix_2Dblockload
390- llvm.func @matrix_2Dblockload (%ptr : !llvm.ptr , %base_width : i32 , %base_height : i32 , %base_pitch : i32 , %x : i32 , %y : i32 ) {
391- // CHECK: [[ELEM_SIZE_IN_BITS:%.*]] = llvm.mlir.constant(16 : i32) : i32
392- // CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(32 : i32) : i32
393- // CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(8 : i32) : i32
394- // CHECK: [[VBLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
395- // CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
396- // CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
397- // CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v16i16({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[ELEM_SIZE_IN_BITS]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[VBLOCKS]], [[TRANSPOSE]], [[VNNI]], {{.*}})
398- %0 = triton_gen.2Dblockload %ptr , %base_width , %base_height , %base_pitch , %x , %y {elem_size_in_bits =16 , tile_width =32 , tile_height =8 , v_blocks =1 , transpose =false , vnni_transform =false , cache_control =Default } : (!llvm.ptr , i32 , i32 , i32 , i32 , i32 ) -> vector <16 xi16 >
399- llvm.return
400- }
401-
402- // -----
403-
404- // CHECK: llvm.func spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v1i32
405- // CHECK-LABEL: llvm.func @matrix_2Dblockload
406- llvm.func @matrix_2Dblockload (%ptr : !llvm.ptr , %base_width : i32 , %base_height : i32 , %base_pitch : i32 , %x : i32 , %y : i32 ) {
407- // CHECK: [[ELEM_SIZE_IN_BITS:%.*]] = llvm.mlir.constant(32 : i32) : i32
408- // CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(2 : i32) : i32
409- // CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(8 : i32) : i32
410- // CHECK: [[VBLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
411- // CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
412- // CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
413- // CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v1i32({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[ELEM_SIZE_IN_BITS]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[VBLOCKS]], [[TRANSPOSE]], [[VNNI]], {{.*}})
414- %0 = triton_gen.2Dblockload %ptr , %base_width , %base_height , %base_pitch , %x , %y {elem_size_in_bits =32 , tile_width =2 , tile_height =8 , v_blocks =1 , transpose =false , vnni_transform =false , cache_control =Default } : (!llvm.ptr , i32 , i32 , i32 , i32 , i32 ) -> vector <1 xi32 >
415- llvm.return
416- }
0 commit comments