From eda4c408f075d76bf5cdbdb0e93c6b4f93ce5bcf Mon Sep 17 00:00:00 2001
From: "Lu, Chengjun"
Date: Tue, 19 Nov 2024 11:57:43 +0000
Subject: [PATCH 1/8] Update the DPAS encoding documents.

---
 .../IR/TritonIntelGPUAttrDefs.td | 181 ++++++++++++++----
 1 file changed, 143 insertions(+), 38 deletions(-)

diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
index bb456eaf38..b6bf5b3bea 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
@@ -14,7 +14,8 @@ def DpasEncodingAttr : DistributedEncoding<"DpasEncoding", "intel_dpas_encoding"
 let mnemonic = "dpas";

 let description = [{
-An encoding for the tensors distributed across the threads for the C and D operands of XMX tensor core operation.
+An encoding for the tensors distributed across the threads for the C and D operands of the XMX tensor core operation,
+and for the corresponding A and B operand layouts that use the DPAS encoding as their parent.
The XMX tensor core operation is defined for matrix matmul as: D=A*B+C
The shape of the XMX tensor core operation is defined by the systolic depth, repeat count, execution size and operations per channel.
@@ -23,43 +24,147 @@ The encoding is characterized by parameters:
- `systolicDepth` For PVC/ATSM, the size is 8.
- `executionSize` For PVC, the size is 16. For ATSM, the size is 8.
- `opsPerChannel` 4 for 8 bit scalar type, 2 for 16 bit scalar type, 1 for 32 bit scalar type.
- - `warpsPerCTA`
- - `sugGroupSize` valid sub group size is 8/16/32
-
-
-The layout example repeat_count=8, systolic_depth=8, execution_size=16 and operands_per_chan=2 for warp size 32.
-For A operand:
- systolic depth = 8
-<------------------------------------------------------------------------------------------------->
-opsPerChan=2
-<--------->
-t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3 t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 ^
-t8 ... t8 t9 ... t9 t10 ... t10 t11 ... t11 t12 ... t12 t13 ... t13 t14 ... t14 t15 ... t15 |
-t16 ... t16 t17 ... t17 t18 ... t18 t19 ... t19 t20 ... t20 t21 ... t21 t22 ... t22 t23 ... t23 |
-t24 ... t24 t25 ... t25 t26 ... t26 t27 ... t27 t28 ... t28 t29 ... t29 t30 ... t30 t31 ... t31 | repeat count <= 8
-t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3 t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7 |
-t8 ... t8 t9 ... t9 t10 ... t10 t11 ... t11 t12 ... t12 t13 ... t13 t14 ... t14 t15 ... t15 |
-t16 ... t16 t17 ... t17 t18 ... t18 t19 ... t19 t20 ... t20 t21 ... t21 t22 ... t22 t23 ... t23 |
-t24 ... t24 t25 ... t25 t26 ... t26 t27 ... t27 t28 ... t28 t29 ... t29 t30 ... t30 t31 ... t31 v
-
-For B operand:
- execution size = 16
-<------------------------------------------------------------->
-t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 ^ ^
-. . . . . . . . . . . . . . . . | opsPerChan=2|
-t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v |
-t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 |
-. . . . . . . . . . . . . . . . |
-t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 | systolic depth = 8
-t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
-. . . . . . . . . . . . . . . . |
-t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
-t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 |
-. . . . . . . . . . . . . . . . 
|
-t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 v
-
-This pattern repeats every warpsPerTile[0] (resp. warpsPerTile[1]) blocks
-along the row (resp. col) dimension.
+ - `warpsPerCTA` indicates the distribution of the warps in the block. The order is [1, 0] for rank 2.
+ - `repCluster` indicates the cluster size of the repetitions of the DPAS tile.
+ - `subGroupSize` Currently only sub group size 16 is supported.
+
+The values of the matrix are distributed across the threads in the subgroup in row-major order.
+ - If the column size of the matrix is equal to the number of threads in the subgroup, a single value name represents a single row of the matrix.
+ - If the column size of the matrix is less than the number of threads in the subgroup, a single value name represents multiple rows of the matrix.
+ - If the column size of the matrix is larger than the number of threads in the subgroup, a single row of the matrix requires multiple value names.
+
+Example 1, the column size of the matrix is 16 and the number of threads in the subgroup is 16.
+The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=2 and subGroupSize=16.
+
+The layout for A operand:
+ K = 16 (K = systolic depth * opsPerChan)
+<---------------------------------------------------------------------------->
+
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 ^
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | M = 8 (repeat count)
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v
+
+The layout for B operand:
+ N = 16 (N = execution size)
+<---------------------------------------------------------------------------->
+
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 ^
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | K = 16 (K = systolic depth * opsPerChan)
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v
+
+The layout for C operand and result D:
+ N = 16 (N = execution size)
+<---------------------------------------------------------------------------->
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 ^
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | M = 8 (M=repeat count)
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v
+
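+As a cross-check of the distribution rule above, the mapping from a matrix element
+(row, col) of the A operand to its owning lane can be written as a small Python
+sketch (illustrative only, not part of the Triton API; the helper name
+`owner_lane_a` and the pack-into-16-bit assumption are inferred from the layout
+diagrams, not from the implementation). Examples 2 and 3 below instantiate the
+same rule:
+```
+def owner_lane_a(row, col, k, elem_bits, threads_per_warp=16):
+    # Scalars narrower than 16 bits are packed in pairs into 16-bit units,
+    # so one lane then owns two adjacent columns (see the opsPerChannel=4
+    # example below); 16 and 32 bit scalars are one unit each.
+    pack = 2 if elem_bits < 16 else 1
+    units_per_row = k // pack
+    unit_index = row * units_per_row + col // pack  # row-major over units
+    return unit_index % threads_per_warp
+
+assert owner_lane_a(3, 5, k=16, elem_bits=16) == 5   # Example 1: lane == col
+assert owner_lane_a(1, 5, k=8, elem_bits=32) == 13   # Example 2: odd rows use t8..t15
+assert owner_lane_a(0, 7, k=32, elem_bits=8) == 3    # Example 3: lane == col // 2
+```
+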
+Example 2, the column size of the matrix is 8 and the number of threads in the subgroup is 16.
+The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=1 and subGroupSize=16.
+
+The layout for A operand:
+ K = 8 (K = systolic depth * opsPerChan)
+<---------------------------------------->
+
+t0 t1 t2 t3 t4 t5 t6 t7 ^
+t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 |
+t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 | M = 8 (repeat count)
+t8 t9 t10 t11 t12 t13 t14 t15 |
+t0 t1 t2 t3 t4 t5 t6 t7 |
+t8 t9 t10 t11 t12 t13 t14 t15 v
+
+The layout for the B operand is like that of opsPerChan=2, but with K = 8.
+The layouts for the C and D operands are the same as those of opsPerChan=2.
+
+Example 3, the column size of the matrix is 32 and the number of threads in the subgroup is 16.
+The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=4 and subGroupSize=16.
+
+The layout for A operand:
+ K = 32 (K = systolic depth * opsPerChan)
+<----------------------------------------------------------------------------------------------------------------------------------->
+
+t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 ^
+t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 |
+t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 |
+t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 |
+t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 | M = 8 (repeat count)
+t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 |
+t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 |
+t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 v
+
+The layout for the B operand is like that of opsPerChan=2, but with K = 32.
+The layouts for the C and D operands are the same as those of opsPerChan=2.
+
+The pattern (illustrated above) repeats every warpsPerTile[0] (resp. warpsPerTile[1]) blocks
+along the row (resp. col) dimension, and the repetitions are clustered in groups of size repCluster to optimize memory access.
+
+Suppose we have a `tt.dot` operation of the block size [64, 128] = [64, 32] * [32, 128] of f16/bf16.
+The `warpsPerCTA` is set to [2, 2]. The number of repetitions of the DPAS tile per warp is: A=8, B=8, C,D=16.
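+These counts can be verified with a short Python sketch (illustrative only; the
+helper `dpas_reps` is not a Triton API). Each warp owns a [tile / warps] sub-block
+per dimension, and the repetition count is that sub-block divided by the DPAS
+instruction shape; in this configuration A is split across warps only along M,
+B only along N, and C/D along both:
+```
+def dpas_reps(tile, inst, warps):
+    # Per-warp sub-block, then the number of DPAS instruction tiles inside it.
+    per_warp = [t // w for t, w in zip(tile, warps)]
+    return (per_warp[0] // inst[0]) * (per_warp[1] // inst[1])
+
+assert dpas_reps([64, 32], (8, 16), [2, 1]) == 8     # A, inst shape 8x16
+assert dpas_reps([32, 128], (16, 16), [1, 2]) == 8   # B, inst shape 16x16
+assert dpas_reps([64, 128], (8, 16), [2, 2]) == 16   # C and D, inst shape 8x16
+```
+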
+The DPAS repetitions are distributed as follows:
+
+ warp[:0] warp[:1] warp[:0] warp[:1]
+ |----^----|----^----|----^----|----^----|
+ repCluster[1]
+ <--------->
+ ┌────┬────┬────┬────┬────┬────┬────┬────┐
+ │R0 │R1 │ │ │R4 │R5 │ │ │
+ │ │ │ │ │ │ │ │ │
+ ├────┼────┼────┼────┼────┼────┼────┼────┤
+ │R2 │R3 │ │ │R6 │R7 │ │ │
+ │ │ │ │ │ │ │ │ │
+ └────┴────┴────┴────┴────┴────┴────┴────┘
+
+ - ^ ┌────┬────┐ ┌────┬────┬────┬────┬────┬────┬────┬────┐
+ | | │R0 │R2 │ │R0 │R1 │ │ │R4 │R5 │ │ │
+ | | │ │ │ │ │ │ │ │ │ │ │ │
+ warp[0:] < repCluster[0] | ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
+ | | │R1 │R3 │ │R2 │R3 │ │ │R6 │R7 │ │ │
+ | | │ │ │ │ │ │ │ │ │ │ │ │
+ - v ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
+ | │ │ │ │ │ │ │ │ │ │ │ │
+ | │ │ │ │ │ │ │ │ │ │ │ │
+ warp[1:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
+ | │ │ │ │ │ │ │ │ │ │ │ │
+ | │ │ │ │ │ │ │ │ │ │ │ │
+ - ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
+ | │R4 │R6 │ │R8 │R9 │ │ │R12 │R13 │ │ │
+ | │ │ │ │ │ │ │ │ │ │ │ │
+ warp[0:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
+ | │R5 │R7 │ │R10 │R11 │ │ │R14 │R15 │ │ │
+ | │ │ │ │ │ │ │ │ │ │ │ │
+ - ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
+ | │ │ │ │ │ │ │ │ │ │ │ │
+ | │ │ │ │ │ │ │ │ │ │ │ │
+ warp[1:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
+ | │ │ │ │ │ │ │ │ │ │ │ │
+ | │ │ │ │ │ │ │ │ │ │ │ │
+ - └────┴────┘ └────┴────┴────┴────┴────┴────┴────┴────┘
+
 }];

 let parameters = (

From 7fc674a24864018038f2f6477428d5d05ca05b05 Mon Sep 17 00:00:00 2001
From: "Lu, Chengjun"
Date: Thu, 21 Nov 2024 11:42:31 +0000
Subject: [PATCH 2/8] Update the documents based on review comments.

---
 .../IR/TritonIntelGPUAttrDefs.td | 115 ++++++++++--------
 .../lib/Dialect/TritonIntelGPU/IR/Dialect.cpp | 14 +--
 .../ConvertLayoutOpToLLVM.cpp | 2 +-
 .../TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp | 2 +-
 .../intel/lib/TritonIntelGPUToLLVM/Utility.h | 3 +-
 .../OptimizeReductionLocality.cpp | 6 +-
 6 files changed, 78 insertions(+), 64 deletions(-)

diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
index b6bf5b3bea..9208769237 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
@@ -23,18 +23,21 @@ The encoding is characterized by parameters:
- `repeatCount` which shall be in the range [1, 8]
- `systolicDepth` For PVC/ATSM, the size is 8.
- `executionSize` For PVC, the size is 16. For ATSM, the size is 8.
- - `opsPerChannel` 4 for 8 bit scalar type, 2 for 16 bit scalar type, 1 for 32 bit scalar type.
+ - `opsPerChannel` 4 for 8 bit scalar type of A/B operands of DPAS instruction,
+ 2 for 16 bit scalar type of A/B operands of DPAS instruction,
+ 1 for 32 bit scalar type of A/B operands of DPAS instruction.
- `warpsPerCTA` indicates the distribution of the warps in the block. The order is [1, 0] for rank 2.
- `repCluster` indicates the cluster size of the repetitions of the DPAS tile.
- - `subGroupSize` Currently only sub group size 16 is supported.
+ - `threadsPerWarp_` AKA threadsPerWarp, use the name threadsPerWarp_ to avoid conflicting
+ with `getThreadsPerWarp` in the DistributedLayout interface. Currently only 16 is supported.
 
The values of the matrix are distributed across the threads in the subgroup in row-major order.
- - If the column size of the matrix is equal to the number of threads in the subgroup, a single value name represents a single row of the matrix.
- - If the column size of the matrix is less than the number of threads in the subgroup, a single value name represents multiple rows of the matrix.
- - If the column size of the matrix is larger than the number of threads in the subgroup, a single row of the matrix requires multiple value names.
+ - If the column size of the matrix is equal to the number of threads in the subgroup, one scalar in a register represents one row of the matrix.
+ - If the column size of the matrix is less than the number of threads in the subgroup, one scalar in a register represents multiple rows of the matrix.
+ - If the column size of the matrix is larger than the number of threads in the subgroup, one scalar in a register represents a partial row of the matrix.
 
Example 1, the column size of the matrix is 16 and the number of threads in the subgroup is 16.
-The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=2 and subGroupSize=16.
+The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=2 and threadsPerWarp=16.
 
The layout for A operand:
 K = 16 (K = systolic depth * opsPerChan)
<---------------------------------------------------------------------------->
@@ -83,7 +86,7 @@ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 |
t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v
 
Example 2, the column size of the matrix is 8 and the number of threads in the subgroup is 16.
-The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=1 and subGroupSize=16.
+The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=1 and threadsPerWarp=16.
 
The layout for A operand:
 K = 8 (K = systolic depth * opsPerChan)
@@ -102,7 +105,7 @@ The layout for the B operand is like that of opsPerChan=2, but with K = 8.
The layouts for the C and D operands are the same as those of opsPerChan=2.
 
Example 3, the column size of the matrix is 32 and the number of threads in the subgroup is 16.
-The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=4 and subGroupSize=16.
+The DPAS encoding of repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=4 and threadsPerWarp=16.
 
The layout for A operand:
 K = 32 (K = systolic depth * opsPerChan)
@@ -121,49 +124,57 @@ The layout for the B operand is like that of opsPerChan=2, but with K = 32.
The layouts for the C and D operands are the same as those of opsPerChan=2.
 
The pattern (illustrated above) repeats every warpsPerTile[0] (resp. warpsPerTile[1]) blocks
-along the row (resp. col) dimension, and the repetitions are clustered in groups of size repCluster to optimize memory access.
-
-Suppose we have a `tt.dot` operation of the block size [64, 128] = [64, 32] * [32, 128] of f16/bf16.
-The `warpsPerCTA` is set to [2, 2]. The number of repetitions of the DPAS tile per warp is: A=8, B=8, C,D=16.
-The DPAS repetitions are distributed as follows:
-
- warp[:0] warp[:1] warp[:0] warp[:1]
- |----^----|----^----|----^----|----^----|
- repCluster[1]
- <--------->
- ┌────┬────┬────┬────┬────┬────┬────┬────┐
- │R0 │R1 │ │ │R4 │R5 │ │ │
- │ │ │ │ │ │ │ │ │
- ├────┼────┼────┼────┼────┼────┼────┼────┤
- │R2 │R3 │ │ │R6 │R7 │ │ │
- │ │ │ │ │ │ │ │ │
- └────┴────┴────┴────┴────┴────┴────┴────┘
-
- - ^ ┌────┬────┐ ┌────┬────┬────┬────┬────┬────┬────┬────┐
- | | │R0 │R2 │ │R0 │R1 │ │ │R4 │R5 │ │ │
- | | │ │ │ │ │ │ │ │ │ │ │ │
- warp[0:] < repCluster[0] | ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
- | | │R1 │R3 │ │R2 │R3 │ │ │R6 │R7 │ │ │
- | | │ │ │ │ │ │ │ │ │ │ │ │
- - v ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
- | │ │ │ │ │ │ │ │ │ │ │ │
- | │ │ │ │ │ │ │ │ │ │ │ │
- warp[1:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
- | │ │ │ │ │ │ │ │ │ │ │ │
- | │ │ │ │ │ │ │ │ │ │ │ │
- - ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
- | │R4 │R6 │ │R8 │R9 │ │ │R12 │R13 │ │ │
- | │ │ │ │ │ │ │ │ │ │ │ │
- warp[0:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
- | │R5 │R7 │ │R10 │R11 │ │ │R14 │R15 │ │ │
- | │ │ │ │ │ │ │ │ │ │ │ │
- - ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
- | │ │ │ │ │ │ │ │ │ │ │ │
- | │ │ │ │ │ │ │ │ │ │ │ │
- warp[1:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
- | │ │ │ │ │ │ │ │ │ │ │ │
- | │ │ │ │ │ │ │ │ │ │ │ │
- - └────┴────┘ └────┴────┴────┴────┴────┴────┴────┴────┘
+along the row (resp. col) dimension, and the repetitions are clustered in groups of size repCluster to optimize memory access.
+
+Suppose we have a `tt.dot` operation of the block size [64, 128] = [64, 32] * [32, 128] of f16/bf16, with its input tensor layouts defined as follows:
+```
+#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 2]}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#dpas, kWidth=2}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
+
+%d = tt.dot %a, %b, %c : tensor<64x32xf16, #dot_operand_a> * tensor<32x128xf16, #dot_operand_b> -> tensor<64x128xf32, #dpas>
+```
+The semantics of this `tt.dot` imply the following GEMM tiling configuration:
+
+ warp[:0] warp[:1] warp[:0] warp[:1]
+ |----^----|----^----|----^----|----^----|
+ repCluster[1]
+ <--------->
+ ┌────┬────┬────┬────┬────┬────┬────┬────┐
+ │W0R0│W0R1│W1R0│W1R1│W0R4│W0R5│W1R4│W1R5│
+ │W2R0│W2R1│W3R0│W3R1│W2R4│W2R5│W3R4│W3R5│
+ warpsPerCTA = [[W0, W1], ├────┼────┼────┼────┼────┼────┼────┼────┤
+ [W2, W3]] │W0R2│W0R3│W1R2│W1R3│W0R6│W0R7│W1R6│W1R7│
+ │W2R2│W2R3│W3R2│W3R3│W2R6│W2R7│W3R6│W3R7│
+ └────┴────┴────┴────┴────┴────┴────┴────┘
+
+
+ - ^ ┌────┬────┐ ┌────┬────┬────┬────┬────┬────┬────┬────┐
+ | | │W0R0│W0R2│ │W0R0│W0R1│W1R0│W1R1│W0R4│W0R5│W1R4│W1R5│
+ | | │W1R0│W1R2│ │ │ │ │ │ │ │ │ │
+ warp[0:] < repCluster[0] | ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
+ | | │W0R1│W0R3│ │W0R2│W0R3│W1R2│W1R3│W0R6│W0R7│W1R6│W1R7│
+ | | │W1R1│W1R3│ │ │ │ │ │ │ │ │ │
+ - v ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
+ | │W2R0│W2R2│ │W2R0│W2R1│W3R0│W3R1│W2R4│W2R5│W3R4│W3R5│
+ | │W3R0│W3R2│ │ │ │ │ │ │ │ │ │
+ warp[1:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
+ | │W2R1│W2R3│ │W2R2│W2R3│W3R2│W3R3│W2R6│W2R7│W3R6│W3R7│
+ | │W3R1│W3R3│ │ │ │ │ │ │ │ │ │
+ - ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
+ | │W0R4│W0R6│ │W0R8│W0R9│W1R8│W1R9│W0 │W0 │W1 │W1 │
+ | │W1R4│W1R6│ │ │ │ │ │R12 │R13 │R12 │R13 │
+ warp[0:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
+ | │W0R5│W0R7│ │W0 │W0 │W1 │W1 │W0 │W0 │W1 │W1 │
+ | │W1R5│W1R7│ │R10 │R11 │R10 │R11 │R14 │R15 │R14 │R15 │
+ - ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
+ | │W2R4│W2R6│ │W2R8│W2R9│W3R8│W3R9│W2 │W2 │W3 │W3 │
+ | │W3R4│W3R6│ │ │ │ │ │R12 │R13 │R12 │R13 │
+ warp[1:] < ├────┼────┤ ├────┼────┼────┼────┼────┼────┼────┼────┤
+ | │W2R5│W2R7│ │W2 │W2 │W3 │W3 │W2 │W2 │W3 │W3 │
+ | │W3R5│W3R7│ │R10 │R11 │R10 │R11 │R14 │R15 │R14 │R15 │
+ - └────┴────┘ └────┴────┴────┴────┴────┴────┴────┴────┘
+
 }];
@@ -175,7 +186,7 @@ The DPAS repetitions are distributed as follows:
 "unsigned":$opsPerChannel,
 ArrayRefParameter<"unsigned">:$warpsPerCTA__,
 ArrayRefParameter<"unsigned">:$repCluster,
- "unsigned":$subGroupSize
+ "unsigned":$threadsPerWarp_
 );
 
 let extraClassDeclaration = extraDistributedDeclaration # [{
diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
index baa0e3e347..92857c2c58 100644
--- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
+++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
@@ -134,7 +134,7 @@ SmallVector DpasEncodingAttr::getShapeC() const {
 SmallVector DpasEncodingAttr::getSizePerThread() const {
 size_t rank = getWarpsPerCTA().size();
 SmallVector res(rank, 1);
- unsigned threadsPerWarp = getSubGroupSize();
+ unsigned threadsPerWarp = getThreadsPerWarp_();
 auto shapeC = getDPASInstShapeC();
 unsigned elemsNum = product(shapeC);
 unsigned elemsPerThread = elemsNum / threadsPerWarp;
@@ -260,7 +260,7 @@ unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperand(
 ArrayRef shape, mlir::Type eltTy, int kWidth, int opIdx) const {
 auto shapePerCTA = getShapePerCTA(*this, shape);
 auto rep = getDPASRepetitions(shapePerCTA, opIdx);
- auto threadsPerWar = getSubGroupSize();
+ auto threadsPerWar = getThreadsPerWarp_();
 size_t rank = shape.size();
 if (opIdx == 0) {
 auto shapeA = getShapeA();
@@ -296,7 +296,7 @@ SmallVector DpasEncodingAttr::getThreadsPerWarp() const {
 size_t rank = getWarpsPerCTA().size();
 SmallVector res(rank, 1);
 auto executionSize = getExecutionSize();
- auto subGroupSize = getSubGroupSize();
+ auto subGroupSize = getThreadsPerWarp_();
 if (subGroupSize < executionSize) {
 llvm::report_fatal_error("DpasEncodingAttr sub-group size could not be "
 "smaller than the execution size");
@@ -340,7 +340,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const {
 assert((rank == 2 || rank == 3) && "unexpected rank number for Dpas layout");
 if (opIdx == 0) {
 SmallVector shapeA = getDPASInstShapeA();
- unsigned subGroupSize = getSubGroupSize();
+ unsigned subGroupSize = getThreadsPerWarp_();
 unsigned opsPerChannel = getOpsPerChannel();
 // pack the value to i16 for scalar bit width <=16.
@@ -359,7 +359,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, unsigned opIdx) const { if (opIdx == 1) { auto shapeB = getShapeB(); - auto subGroupSize = getSubGroupSize(); + auto subGroupSize = getThreadsPerWarp_(); auto executionSize = getExecutionSize(); if (subGroupSize < executionSize) { llvm::report_fatal_error("DpasEncodingAttr sub-group size could not " @@ -394,7 +394,7 @@ SmallVector DpasEncodingAttr::getContigPerThread() { assert(rank == 2 || rank == 3); SmallVector contigPerThread(rank, 1); - unsigned threadsPerWarp = getSubGroupSize(); + unsigned threadsPerWarp = getThreadsPerWarp_(); auto instShapeC = getDPASInstShapeC(); // The software vectorization vectorized the value as C array: int a[N] -> int // a[N][threadsPerWarp] @@ -506,7 +506,7 @@ void DpasEncodingAttr::print(AsmPrinter &printer) const { << "systolicDepth = " << getSystolicDepth() << ", " << "executionSize = " << getExecutionSize() << ", " << "opsPerChan = " << getOpsPerChannel() << ", " - << "threadsPerWarp = " << getSubGroupSize() << ", " + << "threadsPerWarp = " << getThreadsPerWarp_() << ", " << "warpsPerCTA = [" << llvm::ArrayRef(warpsPerCTA) << "], " << "repCluster = [" << repCluster << "], " << "A = [" << rA << "], " diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 72d5f7e291..15fd5b11ea 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -334,7 +334,7 @@ struct ConvertLayoutOpConversion size_t totalElems = elems.size(); auto numElemsPerOperand = product(dpasLayout.getDPASInstShapeC()) / - dpasLayout.getSubGroupSize(); + product(dpasLayout.getThreadsPerWarp()); Type elemTy = this->getTypeConverter()->convertType(srcType.getElementType()); VectorType dotOpTy = vec_ty(elemTy, numElemsPerOperand); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp index ecd8eb1140..b86d14f368 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp @@ -36,7 +36,7 @@ class DotOpDPASConversionHelper { Type i16Ty = type::i16Ty(ctx); Type s32Ty = IntegerType::get(ctx, 32, IntegerType::Signed); - unsigned threadsPerWarp = layout.getSubGroupSize(); + unsigned threadsPerWarp = product(layout.getThreadsPerWarp()); unsigned opsPerChannel = layout.getOpsPerChannel(); SmallVector shapeC = layout.getDPASInstShapeC(); unsigned elemNumC = product(shapeC) / threadsPerWarp; diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h index 2160b8f17d..af53157be4 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h @@ -168,7 +168,8 @@ emitOffsetForDpasLayoutPerCTA(const DpasEncodingAttr &dpasLayout, sizePerThreads[rank - 2] / repCluster[rank - 2], sizePerThreads[rank - 1] / repCluster[rank - 1]}; - unsigned rowsPerElem = dpasLayout.getSubGroupSize() / instShapeC[1]; + unsigned rowsPerElem = + product(dpasLayout.getThreadsPerWarp()) / instShapeC[1]; unsigned colsPerElem = 1; unsigned repNumber = product(repCluster); diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp index 914e851b70..d2cf09213b 
100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp @@ -9,6 +9,7 @@ //===----------------------------------------------------------------------===// #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h" +#include "triton/Dialect/Triton/IR/Utility.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -175,11 +176,12 @@ struct DpasOperandPattern final : OpRewritePattern { // We want to transpose matrices of N*threads_per_warpxthreads_per_warp // shape. + unsigned threadsPerWarp = product(encoding.getThreadsPerWarp()); if ( // X axis condition - encoding.getExecutionSize() != encoding.getSubGroupSize() || + encoding.getExecutionSize() != threadsPerWarp || // Y axis conditions (encoding.getRepeatCount() * encoding.getRepCluster()[0]) % - encoding.getSubGroupSize() != + threadsPerWarp != 0) return failure(); From ecfe00609961254bd42636232294f4be3f405c3b Mon Sep 17 00:00:00 2001 From: "Lu, Chengjun" Date: Wed, 8 Jan 2025 09:10:08 +0800 Subject: [PATCH 3/8] Update third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td Co-authored-by: Whitney Tsang --- .../include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td index 9208769237..c13eba9d63 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td @@ -47,7 +47,7 @@ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | -t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | M = 8 (repeat count) +t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | M = 8 (M = repeat count) t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v From 17899d3855656fed8bbe1ef746c50b6d679f70ed Mon Sep 17 00:00:00 2001 From: "Lu, Chengjun" Date: Wed, 8 Jan 2025 09:10:18 +0800 Subject: [PATCH 4/8] Update third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td Co-authored-by: Whitney Tsang --- .../include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td index c13eba9d63..f1b966606d 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td @@ -80,7 +80,7 @@ t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | -t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | M = 8 (M=repeat count) +t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | M = 8 (M = repeat count) t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 
t14 t15 | t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 v From 24aec77a4ac2591d8c24eb251988dd183ae16605 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Tue, 7 Jan 2025 20:16:32 -0500 Subject: [PATCH 5/8] Update third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td --- .../include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td index f1b966606d..902fff7aec 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td @@ -115,7 +115,7 @@ t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 | t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 | t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 | -t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 | M = 8 (repeat count) +t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 | M = 8 (M = repeat count) t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 | t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 | t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15 v From be2a97b70d23cb84c8d349ec5059ff11c29fb7d2 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Tue, 7 Jan 2025 20:16:41 -0500 Subject: [PATCH 6/8] Update third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td --- .../include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td index 902fff7aec..0ee3f09567 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td @@ -96,7 +96,7 @@ t0 t1 t2 t3 t4 t5 t6 t7 ^ t8 t9 t10 t11 t12 t13 t14 t15 | t0 t1 t2 t3 t4 t5 t6 t7 | t8 t9 t10 t11 t12 t13 t14 t15 | -t0 t1 t2 t3 t4 t5 t6 t7 | M = 8 (repeat count) +t0 t1 t2 t3 t4 t5 t6 t7 | M = 8 (M = repeat count) t8 t9 t10 t11 t12 t13 t14 t15 | t0 t1 t2 t3 t4 t5 t6 t7 | t8 t9 t10 t11 t12 t13 t14 t15 v From 7f20255a197b7b60c1bd9620f12addf2cbbe6d27 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Wed, 8 Jan 2025 01:30:23 +0000 Subject: [PATCH 7/8] Fix merge --- lib/Dialect/TritonGPU/IR/Ops.cpp | 2 +- .../intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index dc60f87595..4e62c4d58c 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -427,7 +427,7 @@ LogicalResult UpcastMXFPOp::inferReturnTypes( dpasEncoding.getSystolicDepth(), dpasEncoding.getExecutionSize(), 
intel::DpasEncodingAttr::getOpsPerChannel(elemType), dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(), - dpasEncoding.getSubGroupSize()); + dpasEncoding.getThreadsPerWarp_()); newVEncoding = DotOperandEncodingAttr::get( ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel()); } else { diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp index 85b9dce7bb..125b01f9d1 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp @@ -304,7 +304,7 @@ class DecomposeScaledBlocked : public OpRewritePattern { auto opEncoding = ttg::intel::DpasEncodingAttr::get( ctx, dpasEnc.getRepeatCount(), dpasEnc.getSystolicDepth(), dpasEnc.getExecutionSize(), opsPerChannel, dpasEnc.getWarpsPerCTA(), - dpasEnc.getRepCluster(), dpasEnc.getSubGroupSize()); + dpasEnc.getRepCluster(), dpasEnc.getThreadsPerWarp_()); auto newOpEncoding = ttg::DotOperandEncodingAttr::get( ctx, unsigned(opIdx), opEncoding, opEncoding.getOpsPerChannel()); @@ -362,7 +362,7 @@ class DecomposeScaledBlocked : public OpRewritePattern { auto retDpasEncoding = ttg::intel::DpasEncodingAttr::get( ctx, dpasEnc.getRepeatCount(), dpasEnc.getSystolicDepth(), dpasEnc.getExecutionSize(), opsPerChannel, dpasEnc.getWarpsPerCTA(), - dpasEnc.getRepCluster(), dpasEnc.getSubGroupSize()); + dpasEnc.getRepCluster(), dpasEnc.getThreadsPerWarp_()); auto retDotOpEncoding = ttg::DotOperandEncodingAttr::get(ctx, unsigned(opIdx), retDpasEncoding, retDpasEncoding.getOpsPerChannel()); From e6ebfe93695bfbd5b55863fe5bfe738fb41a9ac4 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Wed, 8 Jan 2025 03:20:43 +0000 Subject: [PATCH 8/8] address review comments --- lib/Dialect/TritonGPU/IR/Ops.cpp | 2 +- .../TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td | 4 ++-- .../lib/Dialect/TritonIntelGPU/IR/Dialect.cpp | 14 +++++++------- .../TritonIntelGPUTransforms/AccelerateMatmul.cpp | 6 ++++-- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 4e62c4d58c..eea9738c6c 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -427,7 +427,7 @@ LogicalResult UpcastMXFPOp::inferReturnTypes( dpasEncoding.getSystolicDepth(), dpasEncoding.getExecutionSize(), intel::DpasEncodingAttr::getOpsPerChannel(elemType), dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(), - dpasEncoding.getThreadsPerWarp_()); + product(dpasEncoding.getThreadsPerWarp())); newVEncoding = DotOperandEncodingAttr::get( ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel()); } else { diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td index e42f63f46b..26a31e5f99 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td @@ -28,7 +28,7 @@ The encoding is characterized by parameters: 1 for 32 bit scalar type of A/B operands of DPAS instruction. - `warpsPerCTA` indicates the distribution of the warps in the block. The order is [1, 0] for rank 2. - `repCluster` indicates the cluster size of the repetitions of the DPAS tile. 
- - `threadsPerWarp_` AKA threadsPerWarp, use the name threadsPerWarp_ to avoid conflicting
+ - `threadsPerWarp__` AKA threadsPerWarp, use the name threadsPerWarp__ to avoid conflicting
 with `getThreadsPerWarp` in the DistributedLayout interface. Currently only 16 is supported.
 
The values of the matrix are distributed across the threads in the subgroup in row-major order.
@@ -186,7 +186,7 @@ The semantics of this `tt.dot` imply the following GEMM tiling configuration:
 "unsigned":$opsPerChannel,
 ArrayRefParameter<"unsigned">:$warpsPerCTA__,
 ArrayRefParameter<"unsigned">:$repCluster,
- "unsigned":$threadsPerWarp_
+ "unsigned":$threadsPerWarp__
 );
 
 let extraClassDeclaration = extraDistributedDeclaration # [{
diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
index e316120e8c..de2f019fd8 100644
--- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
+++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
@@ -134,7 +134,7 @@ SmallVector DpasEncodingAttr::getShapeC() const {
 SmallVector DpasEncodingAttr::getSizePerThread() const {
 size_t rank = getWarpsPerCTA().size();
 SmallVector res(rank, 1);
- unsigned threadsPerWarp = getThreadsPerWarp_();
+ unsigned threadsPerWarp = getThreadsPerWarp__();
 SmallVector shapeC = getDPASInstShapeC();
 unsigned elemsNum = product(shapeC);
 unsigned elemsPerThread = elemsNum / threadsPerWarp;
@@ -263,7 +263,7 @@ unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperand(
 ArrayRef shape, mlir::Type eltTy, int kWidth, OpIdx opIdx) const {
 SmallVector shapePerCTA = getShapePerCTA(*this, shape);
 SmallVector rep = getDPASRepetitions(shapePerCTA, opIdx);
- unsigned threadsPerWar = getThreadsPerWarp_();
+ unsigned threadsPerWar = getThreadsPerWarp__();
 size_t rank = shape.size();
 
 switch (opIdx) {
@@ -302,7 +302,7 @@ SmallVector DpasEncodingAttr::getThreadsPerWarp() const {
 size_t rank = getWarpsPerCTA().size();
 SmallVector res(rank, 1);
 unsigned executionSize = getExecutionSize();
- unsigned subGroupSize = getThreadsPerWarp_();
+ unsigned subGroupSize = getThreadsPerWarp__();
 if (subGroupSize < executionSize) {
 llvm::report_fatal_error("DpasEncodingAttr sub-group size could not be "
 "smaller than the execution size");
@@ -321,7 +321,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, OpIdx opIdx) const {
 switch (opIdx) {
 case OpIdx::OperandA: {
 SmallVector shapeA = getDPASInstShapeA();
- unsigned subGroupSize = getThreadsPerWarp_();
+ unsigned subGroupSize = getThreadsPerWarp__();
 unsigned opsPerChannel = getOpsPerChannel();
 // pack the value to i16 for scalar bit width <=16.
@@ -339,7 +339,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, OpIdx opIdx) const { } break; case OpIdx::OperandB: { SmallVector shapeB = getShapeB(); - unsigned subGroupSize = getThreadsPerWarp_(); + unsigned subGroupSize = getThreadsPerWarp__(); unsigned executionSize = getExecutionSize(); if (subGroupSize < executionSize) { llvm::report_fatal_error("DpasEncodingAttr sub-group size could not " @@ -359,7 +359,7 @@ SmallVector DpasEncodingAttr::getContigPerThread() const { assert(rank == 2 || rank == 3); SmallVector contigPerThread(rank, 1); - unsigned threadsPerWarp = getThreadsPerWarp_(); + unsigned threadsPerWarp = getThreadsPerWarp__(); SmallVector instShapeC = getDPASInstShapeC(); // The software vectorization vectorized the value as C array: int a[N] -> // int a[N][threadsPerWarp] @@ -494,7 +494,7 @@ void DpasEncodingAttr::print(AsmPrinter &printer) const { << "systolicDepth = " << getSystolicDepth() << ", " << "executionSize = " << getExecutionSize() << ", " << "opsPerChan = " << getOpsPerChannel() << ", " - << "threadsPerWarp = " << getThreadsPerWarp_() << ", " + << "threadsPerWarp = " << getThreadsPerWarp__() << ", " << "warpsPerCTA = [" << llvm::ArrayRef(warpsPerCTA) << "], " << "repCluster = [" << repCluster << "], " << "A = [" << rA << "], " << "B = [" << rB << "], " << "C = [" << rC << "]" << "}>"; diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp index 125b01f9d1..43b40d6313 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp @@ -304,7 +304,8 @@ class DecomposeScaledBlocked : public OpRewritePattern { auto opEncoding = ttg::intel::DpasEncodingAttr::get( ctx, dpasEnc.getRepeatCount(), dpasEnc.getSystolicDepth(), dpasEnc.getExecutionSize(), opsPerChannel, dpasEnc.getWarpsPerCTA(), - dpasEnc.getRepCluster(), dpasEnc.getThreadsPerWarp_()); + dpasEnc.getRepCluster(), + product(dpasEnc.getThreadsPerWarp())); auto newOpEncoding = ttg::DotOperandEncodingAttr::get( ctx, unsigned(opIdx), opEncoding, opEncoding.getOpsPerChannel()); @@ -362,7 +363,8 @@ class DecomposeScaledBlocked : public OpRewritePattern { auto retDpasEncoding = ttg::intel::DpasEncodingAttr::get( ctx, dpasEnc.getRepeatCount(), dpasEnc.getSystolicDepth(), dpasEnc.getExecutionSize(), opsPerChannel, dpasEnc.getWarpsPerCTA(), - dpasEnc.getRepCluster(), dpasEnc.getThreadsPerWarp_()); + dpasEnc.getRepCluster(), + product(dpasEnc.getThreadsPerWarp())); auto retDotOpEncoding = ttg::DotOperandEncodingAttr::get(ctx, unsigned(opIdx), retDpasEncoding, retDpasEncoding.getOpsPerChannel());
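
The net effect of the last two patches is that callers use either the renamed
accessor `getThreadsPerWarp__()` or `product(getThreadsPerWarp())`; the two agree
because the per-dimension thread vector of the DPAS layout multiplies out to the
sub-group size. A minimal Python sketch of that invariant (the
[subGroupSize / executionSize, executionSize] decomposition is an assumption
inferred from the `getThreadsPerWarp()` implementation above, not a documented
API):
```
from math import prod

def threads_per_warp_2d(sub_group_size, execution_size):
    # Per-dimension view of the sub-group, e.g. [1, 16] on PVC.
    assert sub_group_size >= execution_size
    return [sub_group_size // execution_size, execution_size]

assert prod(threads_per_warp_2d(16, 16)) == 16  # matches the old scalar getter
```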