[DOCUMENTS] Update the DPAS encoding documents. #2746
@@ -14,52 +14,168 @@ def DpasEncodingAttr : DistributedEncoding<"DpasEncoding", "intel_dpas_encoding"
let mnemonic = "dpas";

let description = [{
An encoding for the tensors distributed across the threads for the C and D operands of the XMX tensor core operation,
and for its corresponding A and B operand layouts with the DPAS encoding as parent.
The XMX tensor core operation is defined for matrix multiplication as: D = A*B + C.
The shape of the XMX tensor core operation is defined by the systolic depth, repeat count, execution size and operations per channel.
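To make the tile operation concrete, here is a minimal reference sketch in plain Python (an illustration of the semantics only, not the hardware pipeline or the Triton implementation):

```
# Reference semantics of one XMX tile operation, D = A*B + C.
# A is M x K, B is K x N, C and D are M x N, where M = repeat count,
# K = systolic depth * ops per channel, and N = execution size.
def dpas_reference(A, B, C):
    M, K, N = len(A), len(B), len(B[0])
    return [[C[i][j] + sum(A[i][k] * B[k][j] for k in range(K))
             for j in range(N)] for i in range(M)]
```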
The encoding is characterized by parameters:
- `repeatCount` which shall be in the range [1, 8]
- `systolicDepth` For PVC/ATSM, the size is 8.
- `executionSize` For PVC, the size is 16. For ATSM, the size is 8.
- `opsPerChannel` 4 for 8 bit scalar type, 2 for 16 bit scalar type, 1 for 32 bit scalar type.
- `warpsPerCTA`
- `sugGroupSize` valid sub group size is 8/16/32
The layout example repeat_count=8, systolic_depth=8, execution_size=16 and operands_per_chan=2 for warp size 32.
For A operand:
                                         systolic depth = 8
<------------------------------------------------------------------------------------------------->
opsPerChan=2
<--------->
t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3 t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7                      ^
t8 ... t8 t9 ... t9 t10 ... t10 t11 ... t11 t12 ... t12 t13 ... t13 t14 ... t14 t15 ... t15          |
t16 ... t16 t17 ... t17 t18 ... t18 t19 ... t19 t20 ... t20 t21 ... t21 t22 ... t22 t23 ... t23      |
t24 ... t24 t25 ... t25 t26 ... t26 t27 ... t27 t28 ... t28 t29 ... t29 t30 ... t30 t31 ... t31      |  repeat count <= 8
t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3 t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7                      |
t8 ... t8 t9 ... t9 t10 ... t10 t11 ... t11 t12 ... t12 t13 ... t13 t14 ... t14 t15 ... t15          |
t16 ... t16 t17 ... t17 t18 ... t18 t19 ... t19 t20 ... t20 t21 ... t21 t22 ... t22 t23 ... t23      |
t24 ... t24 t25 ... t25 t26 ... t26 t27 ... t27 t28 ... t28 t29 ... t29 t30 ... t30 t31 ... t31      v
For B operand:
                      execution size = 16
<------------------------------------------------------------->
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   ^             ^
.   .   .   .   .   .   .   .   .   .   .   .   .   .   .   .     | opsPerChan=2|
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   v             |
t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31                 |
.   .   .   .   .   .   .   .   .   .   .   .   .   .   .   .                   |
t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31                 |  systolic depth = 8
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15                 |
.   .   .   .   .   .   .   .   .   .   .   .   .   .   .   .                   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15                 |
t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31                 |
.   .   .   .   .   .   .   .   .   .   .   .   .   .   .   .                   |
t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31                 v
This pattern repeats every warpsPerTile[0] (resp. warpsPerTile[1]) blocks
along the row (resp. col) dimension.

- `opsPerChannel` 4 for 8 bit scalar type of A/B operands of DPAS instruction,
                  2 for 16 bit scalar type of A/B operands of DPAS instruction,
                  1 for 32 bit scalar type of A/B operands of DPAS instruction.
- `warpsPerCTA` indicates the distribution of the warps in the block. The order is [1, 0] for rank 2.
- `repCluster` indicates the cluster size of the repetitions of the DPAS tile.
- `threadsPerWarp_` AKA threadsPerWarp; the name threadsPerWarp_ is used to avoid conflicting
  with `getThreadsPerWarp` in the DistributedLayout interface. Currently only 16 is supported.
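Taken together, `repeatCount`, `systolicDepth`, `executionSize` and `opsPerChannel` fix the shapes of one DPAS tile. The following sketch (a hypothetical Python helper for illustration, not code from the Triton sources) derives them:

```
# Per-tile operand shapes implied by the DPAS encoding parameters.
def dpas_tile_shapes(repeat_count, systolic_depth, execution_size, ops_per_channel):
    M = repeat_count                      # rows of A, C and D
    K = systolic_depth * ops_per_channel  # columns of A, rows of B
    N = execution_size                    # columns of B, C and D
    return {"A": (M, K), "B": (K, N), "C/D": (M, N)}

print(dpas_tile_shapes(8, 8, 16, 2))
# {'A': (8, 16), 'B': (16, 16), 'C/D': (8, 16)} -- matches Example 1 below
```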
The values of the matrix are distributed across the threads in the subgroup in row-major order, as the sketch after this list illustrates:
- If the column size of the matrix is equal to the number of threads in the subgroup, one scalar represents one row of the matrix in a register.
- If the column size of the matrix is less than the number of threads in the subgroup, one scalar represents multiple rows of the matrix in a register.
- If the column size of the matrix is larger than the number of threads in the subgroup, one scalar represents a partial row of the matrix in a register.
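The three cases can be summarized by a single mapping. The sketch below is an inferred reading of this rule (an assumption, not code from the repository); it reproduces the thread assignments drawn in Examples 1-3:

```
# Which lane of a `threads`-wide subgroup carries element (row, col) of a
# matrix with `cols` columns, under the row-major distribution above.
def lane_of(row, col, cols, threads=16):
    chunk = max(1, cols // threads)  # consecutive elements held per lane
    return ((row * cols + col) // chunk) % threads

# cols == threads (Example 1): lane == col; one register holds one matrix row.
assert [lane_of(0, c, cols=16) for c in range(3)] == [0, 1, 2]
# cols < threads (Example 2): row 1 continues on lane 8 of the same register.
assert lane_of(1, 0, cols=8) == 8
# cols > threads (Example 3): each lane holds two consecutive elements.
assert [lane_of(0, c, cols=32) for c in range(4)] == [0, 0, 1, 1]
```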
Example 1: the column size of the matrix is 16 and the number of threads in the subgroup is 16.
The DPAS encoding has repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=2 and threadsPerWarp=16.
The layout for A operand:
             K = 16 (K = systolic depth * opsPerChan)
<----------------------------------------------------------------->
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   ^
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |  M = 8 (M = repeat count)
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   v
The layout for B operand:
                   N = 16 (N = execution size)
<----------------------------------------------------------------->
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   ^
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |  K = 16 (K = systolic depth * opsPerChan)
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   v
The layout for C operand and result D:
                   N = 16 (N = execution size)
<----------------------------------------------------------------->
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   ^
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |  M = 8 (M = repeat count)
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   |
t0  t1  t2  t3  t4  t5  t6  t7  t8  t9  t10 t11 t12 t13 t14 t15   v
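As a quick back-of-envelope check for Example 1 (plain arithmetic, not an API): dividing each tile by the 16 lanes gives the number of scalars each thread holds.

```
# Scalars held per lane in Example 1: tile elements / 16 lanes.
M, K, N, lanes = 8, 16, 16, 16
print({"A": M * K // lanes, "B": K * N // lanes, "C/D": M * N // lanes})
# {'A': 8, 'B': 16, 'C/D': 8}
```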
Example 2: the column size of the matrix is 8 and the number of threads in the subgroup is 16.
The DPAS encoding has repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=1 and threadsPerWarp=16.
The layout for A operand:
  K = 8 (K = systolic depth * opsPerChan)
<---------------------------------------->
t0  t1  t2  t3  t4  t5  t6  t7             ^
t8  t9  t10 t11 t12 t13 t14 t15            |
t0  t1  t2  t3  t4  t5  t6  t7             |
t8  t9  t10 t11 t12 t13 t14 t15            |
t0  t1  t2  t3  t4  t5  t6  t7             |  M = 8 (M = repeat count)
t8  t9  t10 t11 t12 t13 t14 t15            |
t0  t1  t2  t3  t4  t5  t6  t7             |
t8  t9  t10 t11 t12 t13 t14 t15            v
The layout for the B operand is like the one for opsPerChan=2, but with a K size of 8.
The layouts for the C and D operands are the same as for opsPerChan=2.
Example 3: the column size of the matrix is 32 and the number of threads in the subgroup is 16.
The DPAS encoding has repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=4 and threadsPerWarp=16.

The layout for A operand:
                                            K = 32 (K = systolic depth * opsPerChan)
<----------------------------------------------------------------------------------------------------------------------------------->
t0  t0  t1  t1  t2  t2  t3  t3  t4  t4  t5  t5  t6  t6  t7  t7  t8  t8  t9  t9  t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15   ^
t0  t0  t1  t1  t2  t2  t3  t3  t4  t4  t5  t5  t6  t6  t7  t7  t8  t8  t9  t9  t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15   |
t0  t0  t1  t1  t2  t2  t3  t3  t4  t4  t5  t5  t6  t6  t7  t7  t8  t8  t9  t9  t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15   |
t0  t0  t1  t1  t2  t2  t3  t3  t4  t4  t5  t5  t6  t6  t7  t7  t8  t8  t9  t9  t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15   |
t0  t0  t1  t1  t2  t2  t3  t3  t4  t4  t5  t5  t6  t6  t7  t7  t8  t8  t9  t9  t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15   |  M = 8 (M = repeat count)
t0  t0  t1  t1  t2  t2  t3  t3  t4  t4  t5  t5  t6  t6  t7  t7  t8  t8  t9  t9  t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15   |
t0  t0  t1  t1  t2  t2  t3  t3  t4  t4  t5  t5  t6  t6  t7  t7  t8  t8  t9  t9  t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15   |
t0  t0  t1  t1  t2  t2  t3  t3  t4  t4  t5  t5  t6  t6  t7  t7  t8  t8  t9  t9  t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15   v
The layout for the B operand is like the one for opsPerChan=2, but with a K size of 32.
The layouts for the C and D operands are the same as for opsPerChan=2.
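Across the three examples only `opsPerChannel` changes, and the K dimension scales with it while the systolic depth stays 8. A one-line check of the arithmetic:

```
# K = systolicDepth * opsPerChannel, for Examples 2, 1 and 3 respectively.
for ops in (1, 2, 4):
    print(f"opsPerChannel={ops}: A tile is 8 x {8 * ops}")
# opsPerChannel=1: A tile is 8 x 8
# opsPerChannel=2: A tile is 8 x 16
# opsPerChannel=4: A tile is 8 x 32
```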
The patterns illustrated above repeat every warpsPerTile[0] (resp. warpsPerTile[1]) blocks
along the row (resp. col) dimension, and the repetitions are clustered in groups of size repCluster to optimize memory access.
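For instance, under the encoding used in the `tt.dot` example below (8x16 C/D tile, repCluster = [2, 2], warpsPerCTA = [2, 2]), the warp/cluster pattern covering a [64, 128] output repeats twice along each dimension. A small assumed calculation:

```
# How often the warp/cluster pattern repeats over a 64 x 128 C/D tensor.
# Assumed arithmetic for illustration; tile = (repeatCount, executionSize).
shape, tile, rep_cluster, warps = (64, 128), (8, 16), (2, 2), (2, 2)
reps = [shape[d] // (tile[d] * rep_cluster[d] * warps[d]) for d in (0, 1)]
print(reps)  # [2, 2]
```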
Suppose we have a `tt.dot` operation with block shape [64, 128] = [64, 32] * [32, 128] in f16/bf16. Its input tensor layouts are defined as follows:
```
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 2]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#dpas, kWidth=2}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>

%d = tt.dot %a, %b, %c : tensor<64x32xf16, #dot_operand_a> * tensor<32x128xf16, #dot_operand_b> -> tensor<64x128xf32, #dpas>
```
The semantics of this `tt.dot` include the following GEMM tiling configuration:
                           warp[:0]  warp[:1]  warp[:0]  warp[:1]
                         |----^----|----^----|----^----|----^----|
                          repCluster[1]
                         <--------->
                         ┌────┬────┬────┬────┬────┬────┬────┬────┐
                         │W0R0│W0R1│W1R0│W1R1│W0R4│W0R5│W1R4│W1R5│
                         │W2R0│W2R1│W3R0│W3R1│W2R4│W2R5│W3R4│W3R5│
warpPerCTA = [[W0, W1],  ├────┼────┼────┼────┼────┼────┼────┼────┤
              [W2, W3]]  │W0R2│W0R3│W1R2│W1R3│W0R6│W0R7│W1R6│W1R7│
                         │W2R2│W2R3│W3R2│W3R3│W2R6│W2R7│W3R6│W3R7│
                         └────┴────┴────┴────┴────┴────┴────┴────┘
         -                ^   ┌────┬────┐   ┌────┬────┬────┬────┬────┬────┬────┬────┐
         |                |   │W0R0│W0R2│   │W0R0│W0R1│W1R0│W1R1│W0R4│W0R5│W1R4│W1R5│
         |                |   │W1R0│W1R2│   │    │    │    │    │    │    │    │    │
warp[0:] < repCluster[0]  |   ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
         |                |   │W0R1│W0R3│   │W0R2│W0R3│W1R2│W1R3│W0R6│W0R7│W1R6│W1R7│
         |                |   │W1R1│W1R3│   │    │    │    │    │    │    │    │    │
         -                v   ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
         |                    │W2R0│W2R2│   │W2R0│W2R1│W3R0│W3R1│W2R4│W2R5│W3R4│W3R5│
         |                    │W3R0│W3R2│   │    │    │    │    │    │    │    │    │
warp[1:] <                    ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
         |                    │W2R1│W2R3│   │W2R2│W2R3│W3R2│W3R3│W2R6│W2R7│W3R6│W3R7│
         |                    │W3R1│W3R3│   │    │    │    │    │    │    │    │    │
         -                    ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
         |                    │W0R4│W0R6│   │W0R8│W0R9│W1R8│W1R9│W0  │W0  │W1  │W1  │
         |                    │W1R4│W1R6│   │    │    │    │    │R12 │R13 │R12 │R13 │
warp[0:] <                    ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
         |                    │W0R5│W0R7│   │W0  │W0  │W1  │W1  │W0  │W0  │W1  │W1  │
         |                    │W1R5│W1R7│   │R10 │R11 │R10 │R11 │R14 │R15 │R14 │R15 │
         -                    ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
         |                    │W2R4│W2R6│   │W2R8│W2R9│W3R8│W3R9│W2  │W2  │W3  │W3  │
         |                    │W3R4│W3R6│   │    │    │    │    │R12 │R13 │R12 │R13 │
warp[1:] <                    ├────┼────┤   ├────┼────┼────┼────┼────┼────┼────┼────┤
         |                    │W2R5│W2R7│   │W2  │W2  │W3  │W3  │W2  │W2  │W3  │W3  │
         |                    │W3R5│W3R7│   │R10 │R11 │R10 │R11 │R14 │R15 │R14 │R15 │
         -                    └────┴────┘   └────┴────┴────┴────┴────┴────┴────┴────┘
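Consistent with the diagrams above, where each warp owns registers R0 through R15 of the C/D tensor, a quick assumed count of the DPAS tiles per warp:

```
# The 64 x 128 C/D tensor decomposes into 8 x 16 DPAS tiles, distributed
# over warpsPerCTA = [2, 2] warps; assumed arithmetic for illustration.
tiles = (64 // 8) * (128 // 16)  # 64 tiles in total
per_warp = tiles // (2 * 2)      # 16 tiles per warp -> registers R0..R15
print(tiles, per_warp)           # 64 16
```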
}];

let parameters = (
@@ -70,7 +186,7 @@ along the row (resp. col) dimension.
| "unsigned":$opsPerChannel, | ||
| ArrayRefParameter<"unsigned">:$warpsPerCTA__, | ||
| ArrayRefParameter<"unsigned">:$repCluster, | ||
| "unsigned":$subGroupSize | ||
| "unsigned":$threadsPerWarp_ | ||
| ); | ||
|
|
||
let extraClassDeclaration = extraDistributedDeclaration # [{