diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 1cc88123d0..e4360b762f 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -34,6 +34,7 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "TRITON_INTEL_ADVANCED_PATH", "TRITON_INTEL_AGGRESSIVE_DPAS_REUSE", "TRITON_INTEL_DO_NOT_SINK_INSTR_ACROSS_RGN", + "TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B", "TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT", "TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM", "TRITON_INTEL_ENABLE_INSTR_SCHED", diff --git a/test/TritonIntelGPU/blockptr_load.mlir b/test/TritonIntelGPU/blockptr_load.mlir index 63bae0b4c1..ff99bbf77f 100644 --- a/test/TritonIntelGPU/blockptr_load.mlir +++ b/test/TritonIntelGPU/blockptr_load.mlir @@ -1,4 +1,5 @@ -// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm +// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,LARGE-BLOCK-SIZE-TRANS-B +// RUN: TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B=1 triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,SMALL-BLOCK-SIZE-TRANS-B // CHECK-DAG: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {convergent, memory_effects = #llvm.memory_effects, no_unwind, will_return} // CHECK-DAG: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return} @@ -204,22 +205,25 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war %c0_i32 = arith.constant 0 : i32 %c32_i64 = arith.constant 32 : i64 %21 = tt.make_tensor_ptr %arg0, [%c64_i64, %c64_i64], [%c1_i64, %col_stride], [%c0_i32, %c0_i32] {order = array} : >> - // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () - // CHECK: %[[VAL_68:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32> - // CHECK: %[[VAL_69:.*]] = llvm.shufflevector %[[VAL_68]], %[[VAL_68]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32> - // CHECK: %[[VAL_71:.*]] = llvm.shufflevector %[[VAL_68]], %[[VAL_68]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32> - // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () - // CHECK: %[[VAL_103:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32> - // CHECK: %[[VAL_104:.*]] = llvm.shufflevector %[[VAL_103]], %[[VAL_103]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32> - // CHECK: %[[VAL_106:.*]] = llvm.shufflevector %[[VAL_103]], %[[VAL_103]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32> - // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () - // CHECK: %[[VAL_138:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32> - // CHECK: %[[VAL_139:.*]] = llvm.shufflevector %[[VAL_138]], %[[VAL_138]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32> - // CHECK: %[[VAL_141:.*]] = llvm.shufflevector %[[VAL_138]], %[[VAL_138]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32> - // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () - // CHECK: %[[VAL_173:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32> - // CHECK: %[[VAL_174:.*]] = llvm.shufflevector %[[VAL_173]], %[[VAL_173]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32> - // CHECK: %[[VAL_176:.*]] = llvm.shufflevector %[[VAL_173]], %[[VAL_173]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32> + // COM: One DPAS operand B per load instruction. + // SMALL-BLOCK-SIZE-TRANS-B-COUNT-8: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // COM: Two interleaved DPAS operand B per load instruction. Need to shuffle the loaded value to decompose the VNNI format DPAS operand B. + // LARGE-BLOCK-SIZE-TRANS-B: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_68:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32> + // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_69:.*]] = llvm.shufflevector %[[VAL_68]], %[[VAL_68]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32> + // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_71:.*]] = llvm.shufflevector %[[VAL_68]], %[[VAL_68]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32> + // LARGE-BLOCK-SIZE-TRANS-B: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_103:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32> + // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_104:.*]] = llvm.shufflevector %[[VAL_103]], %[[VAL_103]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32> + // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_106:.*]] = llvm.shufflevector %[[VAL_103]], %[[VAL_103]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32> + // LARGE-BLOCK-SIZE-TRANS-B: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_138:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32> + // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_139:.*]] = llvm.shufflevector %[[VAL_138]], %[[VAL_138]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32> + // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_141:.*]] = llvm.shufflevector %[[VAL_138]], %[[VAL_138]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32> + // LARGE-BLOCK-SIZE-TRANS-B: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> () + // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_173:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32> + // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_174:.*]] = llvm.shufflevector %[[VAL_173]], %[[VAL_173]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32> + // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_176:.*]] = llvm.shufflevector %[[VAL_173]], %[[VAL_173]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32> %45 = tt.load %21 {triton_intel_gpu.block_io = "column_major"} : !tt.ptr>> tt.return } diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index cf2475c41b..35eb540241 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -627,13 +627,19 @@ struct LoadOpConversion std::swap(tileHeight, tileWidth); - // We can decompose the matrix returned by transposed large 2d load - // when threads per warp < column size. Otherwise we have to load one - // operand per inst. - // Note: the tileHeight and numOperandsPer2DLoadM are the column size - // now. - numOperandsPer2DLoadM = - (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1; + if (triton::tools::getBoolEnv( + "TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B")) { + // Only load 1 operand per inst on row. + numOperandsPer2DLoadM = 1; + } else { + // We can decompose the matrix returned by transposed large 2d load + // when threads per warp < column size. Otherwise we have to load one + // operand per inst. + // Note: the tileHeight and numOperandsPer2DLoadM are the column size + // now. + numOperandsPer2DLoadM = + (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1; + } // The transpose 2d load only support 1 operand per inst on column. // (vBlocks = 1) numOperandsPer2DloadN = 1;