Conversation

@chengjunlu (Contributor)

Loading a single DPAS B matrix per 2D block IO instruction from a column-major matrix in memory gives better performance for flash attention.

Unlike loads from a row-major matrix, the value returned by a single transposed 2D block IO that contains more than one DPAS B operand cannot be used as DPAS operands directly.

We have to shuffle the values in registers before passing them to the DPAS instruction, and IGC does not optimize this for now.
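A minimal sketch of why the shuffle is needed, as a host-side model (my reading of the per-lane layout, inferred from the shufflevector masks in the lit test below; all names here are illustrative, not the actual lowering):

```c++
#include <array>
#include <cstdint>
#include <utility>

// One lane's result of a 16-row transposed 2D block load: two 8-dword DPAS B
// operands arrive interleaved element-wise and cannot feed DPAS directly.
using Loaded = std::array<uint32_t, 16>;
using Operand = std::array<uint32_t, 8>;

// Deinterleave into the two operands, mirroring the lit-test masks.
std::pair<Operand, Operand> deinterleave(const Loaded &v) {
  Operand a{}, b{};
  for (int i = 0; i < 8; ++i) {
    a[i] = v[2 * i];     // shufflevector mask [0, 2, 4, ..., 14]
    b[i] = v[2 * i + 1]; // shufflevector mask [1, 3, 5, ..., 15]
  }
  return {a, b};
}
```

Loading a single operand per instruction avoids this deinterleave step entirely.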

@chengjunlu changed the title from "To load single DPAS B matrix instead of two per 2D block io instruction from transposed memory" to "To load single DPAS B matrix instead of two per 2D block io instruction from the transposed memory" on Nov 5, 2024
// CHECK: %[[VAL_173:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
// CHECK: %[[VAL_174:.*]] = llvm.shufflevector %[[VAL_173]], %[[VAL_173]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
// CHECK: %[[VAL_176:.*]] = llvm.shufflevector %[[VAL_173]], %[[VAL_173]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
// CHECK-COUNT-8: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()

Should we check that there is no llvm.shufflevector?


I think we should

@etiotto (Contributor) Nov 5, 2024

Actually, the generated code still contains a shufflevector instruction. For example:

    llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj(%30, %139, %33, %138, %146, %141) {arg_attrs = [{llvm.nonnull, llvm.readonly}, {}, {}, {}, {}, {llvm.nonnull, llvm.writeonly}], function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>, linkage = #llvm.linkage<external>, no_unwind, sym_name = "_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64, will_return} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
    %147 = llvm.load %141 : !llvm.ptr -> vector<8xi32>
    %148 = llvm.shufflevector %147, %147 [0, 1, 2, 3, 4, 5, 6, 7] : vector<8xi32> 
    %149 = llvm.bitcast %148 : vector<8xi32> to vector<16xf16>

Note: the remaining shufflevector instruction is a no-op (it yields the input vector unchanged), so it could be removed.


I have updated the lit test.

@alexbaden (Contributor) commented Nov 5, 2024

I am seeing a ~5% regression in AxBT performance and a ~2.5% regression in ATxBT performance with this change. Is that expected?

Before (geomean):

| Torch (ms) | Triton (ms) | Triton % Torch |
| --- | --- | --- |
| 0.316446402 | 0.393551599 | 80.41% |
| 0.338533627 | 0.66412835 | 50.97% |
| 0.318394668 | 1.600593032 | 19.89% |
| 0.456415746 | 1.986145424 | 22.98% |

After (geomean):

| Torch (ms) | Triton (ms) | Triton % Torch |
| --- | --- | --- |
| 0.314741015 | 0.388807784 | 80.95% |
| 0.337140929 | 0.751260794 | 44.88% |
| 0.313591856 | 1.608697493 | 19.49% |
| 0.451715958 | 2.195144218 | 20.58% |
LIBIGC1_VERSION=1.0.17791.9-1029
LEVEL_ZERO_VERSION=1.17.44.0-1022
AGAMA_VERSION=1029
GPU_DEVICE=Intel(R) Data Center GPU Max 1100
TORCH_VERSION=2.6.0
COMPILER_VERSION=2024.1.4

// now.
numOperandsPer2DLoadM =
(threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
// Only load 1 operand per inst on row.

Should add the reason (an explanation) for doing so.


The reason, as I understand it, is that by loading smaller 2D blocks we do not have to shuffle the results of the loads before "feeding" them to the DPAS instruction. @chengjunlu can you elaborate please.

@chengjunlu (Contributor, Author) Nov 6, 2024

The main reason is to reduce register spilling.
A large 2D load requires a large contiguous register space.
Unlike a non-transposed matrix, whose loaded values can be fed to the DPAS engine directly through a sub-region, transposed matrix operands have to be shuffled first.
The shuffle then requires additional temporary space, which makes register allocation more complicated.

For now the backend cannot generate a zero-spill binary if we use large 2D loads for transposed matrix operands (even with the latest internal IGC quick build).
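To tie this rationale back to the code under review, a hedged restatement of the heuristic (only `numOperandsPer2DLoadM`, `threadsPerWarp`, `tileHeight`, and `repCluster` appear in the diff; the function wrapper and the explicit transposed/non-transposed split are my paraphrase):

```c++
#include <cstddef>
#include <vector>

unsigned pickNumOperandsPer2DLoadM(bool isTransposedB, unsigned threadsPerWarp,
                                   unsigned tileHeight,
                                   const std::vector<unsigned> &repCluster) {
  const std::size_t rank = repCluster.size();
  if (isTransposedB) {
    // Transposed B operands must be shuffled before DPAS; a large load needs
    // a big contiguous register block plus temporary shuffle space, which IGC
    // currently cannot allocate without spilling. Load one operand per inst.
    return 1;
  }
  // Non-transposed operands feed DPAS directly from a sub-region of the
  // loaded values, so packing repCluster[rank - 1] operands per load is safe.
  return (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
}
```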


@etiotto (Contributor) commented Nov 5, 2024

I did a micro benchmark run, results: http://benchmarks.glados.intel.com/d/1pXX4hUSz/microbenchmarks?orgId=1&var-tag=ci%7Creorder_dpas_gen&var-bench=All&var-device=Intel%28R%29%20Data%20Center%20GPU%20Max%201550&var-compiler=triton&var-backend=All&var-baseline_backend=triton-ci-XPU%201550&var-target_backend=triton-ci-XPU%201550

I was expecting to see an improvement for GEMM with operand B transposed, but performance seems about the same as CI (without this PR). We should grab the latest IGC and test against that.


@alexbaden (Contributor)

The degradation appears to be due to the number of loads doubling for the B matrix. Previously we made four 2D block read transpose loads:

  call spir_func void @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj(i8 addrspace(1)* readonly %1, i32 10240, i32 4096, i32 10240, <2 x i32> %215, i32* %216) #0

now we make eight:

  call spir_func void @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj(i8 addrspace(1)* readonly %1, i32 10240, i32 4096, i32 10240, <2 x i32> %219, i32* %220) #0

Because the loads are smaller (half the number of rows), we need to do twice as many.
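Checking the arithmetic (reading the mangled suffix as element-size/rows/columns per call, which is my assumption): `32b_32r8x1c` reads a 32-row by 8-column block of 32-bit elements per call, while `32b_16r8x1c` reads 16 by 8, half the data, so covering the same B tile takes 8 calls instead of 4.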

Does something change with a newer IGC version that mitigates this?

// CHECK: %[[VAL_173:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
// CHECK: %[[VAL_174:.*]] = llvm.shufflevector %[[VAL_173]], %[[VAL_173]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
// CHECK: %[[VAL_176:.*]] = llvm.shufflevector %[[VAL_173]], %[[VAL_173]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
// CHECK-COUNT-8: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
Note: these are the changes that matter. The rest of the changes are whitespace differences.


@etiotto etiotto requested a review from alexbaden November 5, 2024 19:52
@etiotto (Contributor) commented Nov 6, 2024

> Does something change with a newer IGC version that mitigates this?

Yes, the rationale for this change is discussed here: #2628 (comment)

@chengjunlu chengjunlu force-pushed the chengjun/load_one_DPAS_B_T branch 2 times, most recently from 013c838 to ae0c986 Compare November 7, 2024 02:42
… transposed DPAS operand B. It can get better performance for flash attention because it works around a performance issue caused by register spilling.
@etiotto (Contributor) left a comment

Remove the env. variable. This feature should be enabled by default once it has been shown to give good performance for attention and for GEMM with B transposed on the latest dev IGC.

@chengjunlu please provide performance results for this PR:

  1. run with DEV IGC driver (annotate the exact build please) without your code (baseline)
  2. rerun on the same machine with the same IGC dev driver as in (1) and with your feature (new)

Do so for attention and for GEMM with B transpose. Then record the results in this PR or in the issue associated with this PR please.

"TRITON_INTEL_ADVANCED_PATH",
"TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
"TRITON_INTEL_DO_NOT_SINK_INSTR_ACROSS_RGN",
"TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B",

Remove

@chengjunlu (Contributor, Author)

I used two IGC driver versions to run the benchmarks:

  1. Rolling IGC agama 1032 for GEMM AxB.T, because there is a functional issue in the dev IGC when running GEMM AxB.T.
  2. The dev IGC for flash attention, because it includes the latest enhancements for flash attention.

For the GEMM benchmark, the large block IO size is 1.09x better than the small block IO size in geomean.
For the flash attention benchmark, the large block IO size reaches only 0.80x of the small block IO size in geomean.

The main reason the small block IO size performs better for flash attention is that it avoids the register spilling that occurs with the large block IO size.

Here are the detailed performance results for the small/large block IO size comparison:
GEMM AxB.T

| # | B | M | K | N | Tflops (small block IO) | Tflops (large block IO) | ratio (large/small) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 1 | 1024 | 1024 | 1024 | 47.892141 | 53.340376 | 1.113760523 |
| 1 | 1 | 2048 | 2048 | 2048 | 81.4289 | 92.454361 | 1.135399852 |
| 2 | 1 | 4096 | 4096 | 4096 | 106.56165 | 121.938171 | 1.144296949 |
| 3 | 1 | 8192 | 8192 | 8192 | 104.713196 | 116.605608 | 1.113571283 |
| 4 | 1 | 1 | 5120 | 13824 | 0.353559 | 0.351434 | 0.993989688 |
| 5 | 1 | 4 | 4096 | 12288 | 1.265171 | 1.264138 | 0.99918351 |
| 6 | 1 | 512 | 8192 | 8192 | 67.922074 | 74.869233 | 1.102281314 |
| 7 | 1 | 512 | 8192 | 32768 | 88.565153 | 99.414784 | 1.122504514 |
| 8 | 1 | 512 | 32768 | 8192 | 61.549843 | 69.834026 | 1.134593081 |
| 9 | 1 | 1024 | 16384 | 8192 | 71.986588 | 82.35006 | 1.143963928 |
| 10 | 1 | 1024 | 28672 | 8192 | 70.379633 | 71.105156 | 1.010308707 |
| 11 | 1 | 3072 | 4096 | 3072 | 96.256542 | 109.823859 | 1.140949557 |
| 12 | 1 | 4096 | 16384 | 8192 | 83.23504 | 98.109532 | 1.17870469 |
| 13 | 1 | 8192 | 16384 | 1024 | 75.496408 | 90.432862 | 1.19784324 |
| 14 | 1 | 8192 | 16384 | 4096 | 88.248834 | 98.666475 | 1.118048483 |
| 15 | 1 | 16384 | 1024 | 8192 | 122.578733 | 127.232374 | 1.037964506 |
| 16 | 1 | 16384 | 4096 | 8192 | 116.408825 | 122.843592 | 1.055277313 |
| 17 | 1 | 16384 | 8192 | 1024 | 105.810989 | 120.147341 | 1.135490199 |
| 18 | 1 | 16384 | 8192 | 4096 | 110.823567 | 123.97636 | 1.118682275 |
| 19 | 4 | 32768 | 128 | 4096 | 50.258884 | 50.281316 | 1.000446329 |
| 20 | 4 | 32768 | 4096 | 128 | 52.265309 | 61.017267 | 1.167452526 |
| 21 | 32 | 4096 | 4096 | 128 | 50.609409 | 53.296938 | 1.053103347 |
| 22 | 4096 | 8 | 128 | 16384 | 4.26713 | 4.285785 | 1.004371791 |
| 23 | 4096 | 8 | 16384 | 128 | 4.34589 | 4.354827 | 1.002056426 |

Flash Attention:

| # | Z | H | N_CTX | D_HEAD | CAUSAL | Tflops (small block IO) | Tflops (large block IO) | ratio (large/small) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 1 | 32 | 16384 | 64 | FALSE | 78.212688 | 61.394674 | 0.784970771 |
| 1 | 2 | 32 | 8192 | 64 | FALSE | 76.707298 | 59.493067 | 0.775585486 |
| 2 | 4 | 32 | 4096 | 64 | FALSE | 72.957003 | 57.576184 | 0.789179676 |
| 3 | 8 | 32 | 2048 | 64 | FALSE | 67.13312 | 54.010944 | 0.80453499 |
| 4 | 16 | 32 | 1024 | 64 | FALSE | 60.617356 | 50.008354 | 0.824984085 |
| 5 | 32 | 32 | 512 | 64 | FALSE | 62.716272 | 51.821513 | 0.826284971 |
| 6 | 4 | 48 | 1024 | 64 | FALSE | 70.047579 | 58.562413 | 0.836037645 |
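For what it's worth, the quoted geomeans follow directly from the ratio columns above; a minimal standalone sketch (ratio values copied from the two tables, everything else illustrative):

```c++
#include <cmath>
#include <cstdio>
#include <vector>

// Geometric mean via exp(mean(log r)); numerically safer than multiplying
// all the ratios together and taking the n-th root.
static double geomean(const std::vector<double> &rs) {
  double s = 0.0;
  for (double r : rs)
    s += std::log(r);
  return std::exp(s / rs.size());
}

int main() {
  // large/small ratio columns from the GEMM and flash attention tables above.
  std::vector<double> gemm = {
      1.113760523, 1.135399852, 1.144296949, 1.113571283, 0.993989688,
      0.99918351,  1.102281314, 1.122504514, 1.134593081, 1.143963928,
      1.010308707, 1.140949557, 1.17870469,  1.19784324,  1.118048483,
      1.037964506, 1.055277313, 1.135490199, 1.118682275, 1.000446329,
      1.167452526, 1.053103347, 1.004371791, 1.002056426};
  std::vector<double> attn = {0.784970771, 0.775585486, 0.789179676,
                              0.80453499,  0.824984085, 0.826284971,
                              0.836037645};
  std::printf("GEMM geomean (large/small): %.2f\n", geomean(gemm)); // ~1.09
  std::printf("attn geomean (large/small): %.2f\n", geomean(attn)); // ~0.81
}
```

This gives roughly 1.09 and 0.81, matching the quoted 1.09x and ~0.80x up to rounding.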

@etiotto (Contributor) commented Nov 12, 2024

So, in summary, the reason for not using large 2D reads is to make it easier for IGC to allocate registers without spills. But we do not have a good heuristic for register pressure in Triton. We need to discuss offline how to proceed on this.

@etiotto (Contributor) commented Nov 12, 2024

Discussed offline; we decided to merge this PR with the env. variable so that we can more easily run performance experiments.

@etiotto etiotto merged commit ee755e8 into main Nov 12, 2024
5 checks passed
@etiotto etiotto deleted the chengjun/load_one_DPAS_B_T branch November 12, 2024 14:26
Successfully merging this pull request may close these issues:

[Performance] Improve the flash attention performance on bottom-up optimization pipeline