diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp
index dc60f87595..eea9738c6c 100644
--- a/lib/Dialect/TritonGPU/IR/Ops.cpp
+++ b/lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -427,7 +427,7 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
           dpasEncoding.getSystolicDepth(), dpasEncoding.getExecutionSize(),
           intel::DpasEncodingAttr::getOpsPerChannel(elemType),
           dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
-          dpasEncoding.getSubGroupSize());
+          product(dpasEncoding.getThreadsPerWarp()));
       newVEncoding = DotOperandEncodingAttr::get(
           ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
     } else {
diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
index 0290106b0a..26a31e5f99 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td
@@ -14,7 +14,8 @@ def DpasEncodingAttr : DistributedEncoding<"DpasEncoding", "intel_dpas_encoding"
   let mnemonic = "dpas";
 
   let description = [{
-An encoding for the tensors distributed across the threads for the C and D operands of XMX tensor core operation.
+An encoding for the tensors distributed across the threads for the C and D operands of the XMX tensor core operation,
+and for its corresponding A and B operand layouts that use the DPAS encoding as their parent.
 The XMX tensor core operation is defined for matrix matmul as: D=A*B+C
 The shape of the of XMX tensor core operation is defined by systolic depth, repeat count, execution size and operations per channel.
@@ -22,44 +23,159 @@ The encoding is characterized by parameters:
  - `repeatCount` which shall be in the range [1, 8]
  - `systolicDepth` For PVC/ATSM, the size is 8.
  - `executionSize` For PVC, the size is 16. For ATSM, the size is 8.
- - `opsPerChannel` 4 for 8 bit scalar type, 2 for 16 bit scalar type, 1 for 32 bit scalar type.
- - `warpsPerCTA`
- - `sugGroupSize` valid sub group size is 8/16/32
-
-
-The layout example repeat_count=8, systolic_depth=8, execution_size=16 and operands_per_chan=2 for warp size 32.
-For A operand:
-   systolic depth = 8
-<------------------------------------------------------------------------------------------------->
-opsPerChan=2
-<--------->
-t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3 t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7   ^
-t8 ... t8 t9 ... t9 t10 ... t10 t11 ... t11 t12 ... t12 t13 ... t13 t14 ... t14 t15 ... t15   |
-t16 ... t16 t17 ... t17 t18 ... t18 t19 ... t19 t20 ... t20 t21 ... t21 t22 ... t22 t23 ... t23   |
-t24 ... t24 t25 ... t25 t26 ... t26 t27 ... t27 t28 ... t28 t29 ... t29 t30 ... t30 t31 ... t31   | repeat count <= 8
-t0 ... t0 t1 ... t1 t2 ... t2 t3 ... t3 t4 ... t4 t5 ... t5 t6 ... t6 t7 ... t7   |
-t8 ... t8 t9 ... t9 t10 ... t10 t11 ... t11 t12 ... t12 t13 ... t13 t14 ... t14 t15 ... t15   |
-t16 ... t16 t17 ... t17 t18 ... t18 t19 ... t19 t20 ... t20 t21 ... t21 t22 ... t22 t23 ... t23   |
-t24 ... t24 t25 ... t25 t26 ... t26 t27 ... t27 t28 ... t28 t29 ... t29 t30 ... t30 t31 ... t31   v
-
-For B operand:
-   execution size = 16
-<------------------------------------------------------------->
-t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   ^ ^
-. . . . . . . . . . . . . . . .                          | opsPerChan=2|
-t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   v |
-t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31   |
-. . . . . . . . . . . . . . . .                          |
-t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31   | systolic depth = 8
-t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
-. . . . . . . . . . . . . . . .                          |
-t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
-t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31   |
-. . . . . . . . . . . . . . . .                          |
-t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31   v
-
-This pattern repeats every warpsPerTile[0] (resp. warpsPerTile[1]) blocks
-along the row (resp. col) dimension.
+ - `opsPerChannel` 4 for 8-bit scalar types of the A/B operands of the DPAS instruction,
+                   2 for 16-bit scalar types of the A/B operands of the DPAS instruction,
+                   1 for 32-bit scalar types of the A/B operands of the DPAS instruction.
+ - `warpsPerCTA` indicates the distribution of the warps in the block. The order is [1, 0] for rank 2.
+ - `repCluster` indicates the cluster size of the repetitions of the DPAS tile.
+ - `threadsPerWarp__` AKA threadsPerWarp; the name threadsPerWarp__ is used to avoid conflicting
+                      with `getThreadsPerWarp` in the DistributedLayout interface. Currently only 16 is supported.
+
+The values of the matrix are distributed across the threads in the subgroup in row-major order.
+ - If the column size of the matrix is equal to the number of threads in the subgroup, one scalar represents one row of the matrix in a register.
+ - If the column size of the matrix is less than the number of threads in the subgroup, one scalar represents multiple rows of the matrix in a register.
+ - If the column size of the matrix is larger than the number of threads in the subgroup, one scalar represents a partial row of the matrix in a register.
+
+Example 1: the column size of the matrix is 16 and the number of threads in the subgroup is 16.
+The DPAS encoding has repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=2 and threadsPerWarp=16.
+
+The layout for the A operand:
+                              K = 16 (K = systolic depth * opsPerChan)
+<---------------------------------------------------------------------------->
+
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   ^
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |  M = 8 (M = repeat count)
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   v
+
+The layout for the B operand:
+                              N = 16 (N = execution size)
+<---------------------------------------------------------------------------->
+
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   ^
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |  K = 16 (K = systolic depth * opsPerChan)
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   v
+
+The layout for the C operand and result D:
+                              N = 16 (N = execution size)
+<---------------------------------------------------------------------------->
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   ^
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |  M = 8 (M = repeat count)
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15   v
+
+Example 2: the column size of the matrix is 8 and the number of threads in the subgroup is 16.
+The DPAS encoding has repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=1 and threadsPerWarp=16.
+
+The layout for the A operand:
+      K = 8 (K = systolic depth * opsPerChan)
+<---------------------------------------->
+
+t0 t1 t2 t3 t4 t5 t6 t7         ^
+t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7         |
+t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7         |  M = 8 (M = repeat count)
+t8 t9 t10 t11 t12 t13 t14 t15   |
+t0 t1 t2 t3 t4 t5 t6 t7         |
+t8 t9 t10 t11 t12 t13 t14 t15   v
+
+The layout for the B operand is the same as for opsPerChan=2, except that the K size is 8.
+The layouts for the C and D operands are the same as for opsPerChan=2.
+
+Example 3: the column size of the matrix is 32 and the number of threads in the subgroup is 16.
+The DPAS encoding has repeatCount=8, systolicDepth=8, executionSize=16, opsPerChannel=4 and threadsPerWarp=16.
+
+The layout for the A operand:
+                                              K = 32 (K = systolic depth * opsPerChan)
+<----------------------------------------------------------------------------------------------------------------------------------->
+
+t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15   ^
+t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15   |
+t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15   |
+t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15   |
+t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15   |  M = 8 (M = repeat count)
+t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15   |
+t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15   |
+t0 t0 t1 t1 t2 t2 t3 t3 t4 t4 t5 t5 t6 t6 t7 t7 t8 t8 t9 t9 t10 t10 t11 t11 t12 t12 t13 t13 t14 t14 t15 t15   v
+
+The layout for the B operand is the same as for opsPerChan=2, except that the K size is 32.
+The layouts for the C and D operands are the same as for opsPerChan=2.
+
+The patterns illustrated above repeat every warpsPerTile[0] (resp. warpsPerTile[1]) blocks
+along the row (resp. col) dimension, and the repetitions are clustered in groups of size repCluster to optimize memory access.
+
+Suppose we have a `tt.dot` operation with block shape [64, 128] = [64, 32] * [32, 128] for f16/bf16, and its input tensor layouts are defined as follows:
+```
+#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 2]}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#dpas, kWidth=2}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
+
+%d = tt.dot %a, %b, %c : tensor<64x32xf16, #dot_operand_a> * tensor<32x128xf16, #dot_operand_b> -> tensor<64x128xf32, #dpas>
+```
+The semantics of this `tt.dot` imply the following GEMM tiling configuration:
+
+                             warp[:0]  warp[:1]  warp[:0]  warp[:1]
+                            |----^----|----^----|----^----|----^----|
+                             repCluster[1]
+                            <--------->
+                            ┌────┬────┬────┬────┬────┬────┬────┬────┐
+                            │W0R0│W0R1│W1R0│W1R1│W0R4│W0R5│W1R4│W1R5│
+                            │W2R0│W2R1│W3R0│W3R1│W2R4│W2R5│W3R4│W3R5│
+  warpsPerCTA = [[W0, W1],  ├────┼────┼────┼────┼────┼────┼────┼────┤
+                 [W2, W3]]  │W0R2│W0R3│W1R2│W1R3│W0R6│W0R7│W1R6│W1R7│
+                            │W2R2│W2R3│W3R2│W3R3│W2R6│W2R7│W3R6│W3R7│
+                            └────┴────┴────┴────┴────┴────┴────┴────┘
+
+
+           -               ^ ┌────┬────┐     ┌────┬────┬────┬────┬────┬────┬────┬────┐
+           |               | │W0R0│W0R2│     │W0R0│W0R1│W1R0│W1R1│W0R4│W0R5│W1R4│W1R5│
+           |               | │W1R0│W1R2│     │    │    │    │    │    │    │    │    │
+  warp[0:] < repCluster[0] | ├────┼────┤     ├────┼────┼────┼────┼────┼────┼────┼────┤
+           |               | │W0R1│W0R3│     │W0R2│W0R3│W1R2│W1R3│W0R6│W0R7│W1R6│W1R7│
+           |               | │W1R1│W1R3│     │    │    │    │    │    │    │    │    │
+           -               v ├────┼────┤     ├────┼────┼────┼────┼────┼────┼────┼────┤
+           |                 │W2R0│W2R2│     │W2R0│W2R1│W3R0│W3R1│W2R4│W2R5│W3R4│W3R5│
+           |                 │W3R0│W3R2│     │    │    │    │    │    │    │    │    │
+  warp[1:] <                 ├────┼────┤     ├────┼────┼────┼────┼────┼────┼────┼────┤
+           |                 │W2R1│W2R3│     │W2R2│W2R3│W3R2│W3R3│W2R6│W2R7│W3R6│W3R7│
+           |                 │W3R1│W3R3│     │    │    │    │    │    │    │    │    │
+           -                 ├────┼────┤     ├────┼────┼────┼────┼────┼────┼────┼────┤
+           |                 │W0R4│W0R6│     │W0R8│W0R9│W1R8│W1R9│W0  │W0  │W1  │W1  │
+           |                 │W1R4│W1R6│     │    │    │    │    │R12 │R13 │R12 │R13 │
+  warp[0:] <                 ├────┼────┤     ├────┼────┼────┼────┼────┼────┼────┼────┤
+           |                 │W0R5│W0R7│     │W0  │W0  │W1  │W1  │W0  │W0  │W1  │W1  │
+           |                 │W1R5│W1R7│     │R10 │R11 │R10 │R11 │R14 │R15 │R14 │R15 │
+           -                 ├────┼────┤     ├────┼────┼────┼────┼────┼────┼────┼────┤
+           |                 │W2R4│W2R6│     │W2R8│W2R9│W3R8│W3R9│W2  │W2  │W3  │W3  │
+           |                 │W3R4│W3R6│     │    │    │    │    │R12 │R13 │R12 │R13 │
+  warp[1:] <                 ├────┼────┤     ├────┼────┼────┼────┼────┼────┼────┼────┤
+           |                 │W2R5│W2R7│     │W2  │W2  │W3  │W3  │W2  │W2  │W3  │W3  │
+           |                 │W3R5│W3R7│     │R10 │R11 │R10 │R11 │R14 │R15 │R14 │R15 │
+           -                 └────┴────┘     └────┴────┴────┴────┴────┴────┴────┴────┘
+
+
   }];
 
   let parameters = (
@@ -70,7 +186,7 @@ along the row (resp. col) dimension.
     "unsigned":$opsPerChannel,
     ArrayRefParameter<"unsigned">:$warpsPerCTA__,
     ArrayRefParameter<"unsigned">:$repCluster,
-    "unsigned":$subGroupSize
+    "unsigned":$threadsPerWarp__
   );
 
   let extraClassDeclaration = extraDistributedDeclaration # [{
diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
index 26591d8dab..de2f019fd8 100644
--- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
+++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
@@ -134,7 +134,7 @@ SmallVector<unsigned> DpasEncodingAttr::getShapeC() const {
 SmallVector<unsigned> DpasEncodingAttr::getSizePerThread() const {
   size_t rank = getWarpsPerCTA().size();
   SmallVector<unsigned> res(rank, 1);
-  unsigned threadsPerWarp = getSubGroupSize();
+  unsigned threadsPerWarp = getThreadsPerWarp__();
   SmallVector<unsigned> shapeC = getDPASInstShapeC();
   unsigned elemsNum = product(shapeC);
   unsigned elemsPerThread = elemsNum / threadsPerWarp;
@@ -263,7 +263,7 @@ unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperand(
     ArrayRef<int64_t> shape, mlir::Type eltTy, int kWidth, OpIdx opIdx) const {
   SmallVector<int64_t> shapePerCTA = getShapePerCTA(*this, shape);
   SmallVector<int64_t> rep = getDPASRepetitions(shapePerCTA, opIdx);
-  unsigned threadsPerWar = getSubGroupSize();
+  unsigned threadsPerWar = getThreadsPerWarp__();
   size_t rank = shape.size();
 
   switch (opIdx) {
@@ -302,7 +302,7 @@ SmallVector<unsigned> DpasEncodingAttr::getThreadsPerWarp() const {
   size_t rank = getWarpsPerCTA().size();
   SmallVector<unsigned> res(rank, 1);
   unsigned executionSize = getExecutionSize();
-  unsigned subGroupSize = getSubGroupSize();
+  unsigned subGroupSize = getThreadsPerWarp__();
   if (subGroupSize < executionSize) {
     llvm::report_fatal_error("DpasEncodingAttr sub-group size could not be "
                              "smaller than the execution size");
@@ -321,7 +321,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, OpIdx opIdx) const {
   switch (opIdx) {
   case OpIdx::OperandA: {
     SmallVector<unsigned> shapeA = getDPASInstShapeA();
-    unsigned subGroupSize = getSubGroupSize();
+    unsigned subGroupSize = getThreadsPerWarp__();
     unsigned opsPerChannel = getOpsPerChannel();
     // pack the value to i16 for scalar bit width <=16.
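Throughout these hunks the scalar sub-group size that used to come from `getSubGroupSize()` is either read directly as `getThreadsPerWarp__()` or recovered as `product(getThreadsPerWarp())`, and the per-thread element counts are derived from it. Below is a minimal, self-contained sketch of that arithmetic in plain C++ (not the MLIR API; the per-dimension lane split {1, 16} is an illustrative assumption), mirroring what `getSizePerThread()` computes for one DPAS C tile.
```
// Standalone sketch (plain C++, not the MLIR/Triton API) of the arithmetic
// used by getSizePerThread() above: the sub-group size is the product of the
// per-dimension thread distribution, and each lane owns
// product(DPASInstShapeC) / threadsPerWarp elements of one C tile.
#include <cassert>
#include <cstdio>
#include <numeric>
#include <vector>

// Stand-in for the product() helper used in the patch.
unsigned product(const std::vector<unsigned> &shape) {
  return std::accumulate(shape.begin(), shape.end(), 1u,
                         [](unsigned a, unsigned b) { return a * b; });
}

int main() {
  // Values from Example 1 in the description: repeatCount=8, executionSize=16,
  // threadsPerWarp=16.
  const unsigned repeatCount = 8, executionSize = 16, threadsPerWarp = 16;

  // Illustrative per-dimension decomposition of the 16 lanes for the C tile:
  // all lanes laid out along the N (executionSize) dimension.
  const std::vector<unsigned> threadsPerWarpPerDim = {1, executionSize};
  assert(product(threadsPerWarpPerDim) == threadsPerWarp);

  // Elements of one repeatCount x executionSize C tile owned by each lane.
  const unsigned elemsPerThread =
      (repeatCount * executionSize) / threadsPerWarp;
  std::printf("C elements per thread per DPAS tile: %u\n", elemsPerThread); // 8
  return 0;
}
```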
@@ -339,7 +339,7 @@ DpasEncodingAttr::getSizePerThreadForOperand(int kWidth, OpIdx opIdx) const {
   } break;
   case OpIdx::OperandB: {
     SmallVector<unsigned> shapeB = getShapeB();
-    unsigned subGroupSize = getSubGroupSize();
+    unsigned subGroupSize = getThreadsPerWarp__();
     unsigned executionSize = getExecutionSize();
     if (subGroupSize < executionSize) {
       llvm::report_fatal_error("DpasEncodingAttr sub-group size could not "
@@ -359,7 +359,7 @@ SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() const {
   assert(rank == 2 || rank == 3);
   SmallVector<unsigned> contigPerThread(rank, 1);
-  unsigned threadsPerWarp = getSubGroupSize();
+  unsigned threadsPerWarp = getThreadsPerWarp__();
   SmallVector<unsigned> instShapeC = getDPASInstShapeC();
   // The software vectorization vectorized the value as C array: int a[N] ->
   // int a[N][threadsPerWarp]
@@ -494,7 +494,7 @@ void DpasEncodingAttr::print(AsmPrinter &printer) const {
           << "systolicDepth = " << getSystolicDepth() << ", "
           << "executionSize = " << getExecutionSize() << ", "
           << "opsPerChan = " << getOpsPerChannel() << ", "
-          << "threadsPerWarp = " << getSubGroupSize() << ", "
+          << "threadsPerWarp = " << getThreadsPerWarp__() << ", "
           << "warpsPerCTA = [" << llvm::ArrayRef(warpsPerCTA) << "], "
           << "repCluster = [" << repCluster << "], " << "A = [" << rA << "], "
           << "B = [" << rB << "], " << "C = [" << rC << "]" << "}>";
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index c5cd7d950a..e76e59280a 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -334,7 +334,7 @@ struct ConvertLayoutOpConversion
       size_t totalElems = elems.size();
       auto numElemsPerOperand =
           product(dpasLayout.getDPASInstShapeC()) /
-          dpasLayout.getSubGroupSize();
+          product(dpasLayout.getThreadsPerWarp());
       Type elemTy =
           this->getTypeConverter()->convertType(srcType.getElementType());
       VectorType dotOpTy = vec_ty(elemTy, numElemsPerOperand);
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp
index 94c58615c6..e0ae325e55 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp
@@ -37,7 +37,7 @@ class DotOpDPASConversionHelper {
     Type i16Ty = type::i16Ty(ctx);
     Type s32Ty = IntegerType::get(ctx, 32, IntegerType::Signed);
 
-    unsigned threadsPerWarp = layout.getSubGroupSize();
+    unsigned threadsPerWarp = product(layout.getThreadsPerWarp());
     unsigned opsPerChannel = layout.getOpsPerChannel();
     SmallVector<unsigned> shapeC = layout.getDPASInstShapeC();
     unsigned elemNumC = product(shapeC) / threadsPerWarp;
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
index b2b1d0e9a4..f641f9fa97 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
@@ -120,7 +120,8 @@ emitOffsetForDpasLayoutPerCTA(const DpasEncodingAttr &dpasLayout,
       sizePerThreads[rank - 2] / repCluster[rank - 2],
       sizePerThreads[rank - 1] / repCluster[rank - 1]};
 
-  unsigned rowsPerElem = dpasLayout.getSubGroupSize() / instShapeC[1];
+  unsigned rowsPerElem =
+      product(dpasLayout.getThreadsPerWarp()) / instShapeC[1];
   unsigned colsPerElem = 1;
 
   unsigned repNumber = product(repCluster);
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
index 85b9dce7bb..43b40d6313 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
@@ -304,7 +304,8 @@ class DecomposeScaledBlocked : public OpRewritePattern {
     auto opEncoding = ttg::intel::DpasEncodingAttr::get(
         ctx, dpasEnc.getRepeatCount(), dpasEnc.getSystolicDepth(),
         dpasEnc.getExecutionSize(), opsPerChannel, dpasEnc.getWarpsPerCTA(),
-        dpasEnc.getRepCluster(), dpasEnc.getSubGroupSize());
+        dpasEnc.getRepCluster(),
+        product(dpasEnc.getThreadsPerWarp()));
     auto newOpEncoding = ttg::DotOperandEncodingAttr::get(
         ctx, unsigned(opIdx), opEncoding, opEncoding.getOpsPerChannel());
 
@@ -362,7 +363,8 @@ class DecomposeScaledBlocked : public OpRewritePattern {
     auto retDpasEncoding = ttg::intel::DpasEncodingAttr::get(
         ctx, dpasEnc.getRepeatCount(), dpasEnc.getSystolicDepth(),
         dpasEnc.getExecutionSize(), opsPerChannel, dpasEnc.getWarpsPerCTA(),
-        dpasEnc.getRepCluster(), dpasEnc.getSubGroupSize());
+        dpasEnc.getRepCluster(),
+        product(dpasEnc.getThreadsPerWarp()));
     auto retDotOpEncoding =
         ttg::DotOperandEncodingAttr::get(ctx, unsigned(opIdx), retDpasEncoding,
                                          retDpasEncoding.getOpsPerChannel());
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
index 4f5ce5a238..b7bd652a40 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp
@@ -9,6 +9,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
+#include "triton/Dialect/Triton/IR/Utility.h"
 
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -237,11 +238,12 @@ struct DpasOperandPattern final : OpRewritePattern {
 
     // We want to transpose matrices of N*threads_per_warpxthreads_per_warp
     // shape.
+    unsigned threadsPerWarp = product(encoding.getThreadsPerWarp());
     if ( // X axis condition
-        encoding.getExecutionSize() != encoding.getSubGroupSize() ||
+        encoding.getExecutionSize() != threadsPerWarp ||
         // Y axis conditions
         (encoding.getRepeatCount() * encoding.getRepCluster()[0]) %
-                encoding.getSubGroupSize() !=
+                threadsPerWarp !=
             0)
       return failure();
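As a cross-check of the GEMM tiling described in the new documentation (the [64, 128] = [64, 32] * [32, 128] `tt.dot` example with warpsPerCTA = [2, 2] and repCluster = [2, 2]), the repetition counts behind the R0..R15 indices in the diagrams can be reproduced with a short standalone computation. This is a plain C++ sketch of the arithmetic only, not code from the patch.
```
// Standalone sketch (plain C++) checking the tiling numbers of the
// documentation example: D is 64x128, the DPAS C tile is
// repeatCount x executionSize = 8x16, repCluster = [2, 2],
// warpsPerCTA = [2, 2], threadsPerWarp = 16.
#include <cstdio>

int main() {
  const unsigned M = 64, N = 128;
  const unsigned repeatCount = 8, executionSize = 16, threadsPerWarp = 16;
  const unsigned repCluster[2] = {2, 2};
  const unsigned warpsPerCTA[2] = {2, 2};

  // One warp covers a cluster of repCluster[0] x repCluster[1] DPAS C tiles.
  const unsigned warpTileM = repeatCount * repCluster[0];   // 16
  const unsigned warpTileN = executionSize * repCluster[1]; // 32

  // The warp grid tiles the block; what is left is covered by repetitions
  // (the R0..R15 indices in the diagrams).
  const unsigned repM = M / (warpTileM * warpsPerCTA[0]); // 2
  const unsigned repN = N / (warpTileN * warpsPerCTA[1]); // 2
  const unsigned dpasTilesPerWarp =
      repM * repN * repCluster[0] * repCluster[1]; // 16 -> R0..R15

  // Elements of D held by each thread of a warp.
  const unsigned elemsPerThread =
      dpasTilesPerWarp * repeatCount * executionSize / threadsPerWarp; // 128

  std::printf("DPAS tiles per warp: %u, D elements per thread: %u\n",
              dpasTilesPerWarp, elemsPerThread);
  return 0;
}
```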