diff --git a/test/TritonIntelGPU/optimize-block-io-encoding.mlir b/test/TritonIntelGPU/optimize-block-io-encoding.mlir new file mode 100644 index 0000000000..8a9b6f184a --- /dev/null +++ b/test/TritonIntelGPU/optimize-block-io-encoding.mlir @@ -0,0 +1,129 @@ +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --tritonintelgpu-optimize-block-io-encoding | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}> +// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}> +// CHECK: #mma1 = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}> +// CHECK: #mma2 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} { + tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) attributes {noinline = false} { + %c4_i32 = arith.constant 4 : i32 + %c256_i32 = arith.constant 256 : i32 + %c1024_i64 = arith.constant 1024 : i64 + %c5120_i64 = arith.constant 5120 : i64 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %c4096_i64 = arith.constant 4096 : i64 + %c32_i32 = arith.constant 32 : i32 + %c64_i32 = arith.constant 64 : i32 + %c5120_i32 = arith.constant 5120 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked> + %0 = tt.get_program_id x : i32 + %1 = arith.divsi %0, %c64_i32 : i32 + %2 = arith.muli %1, %c4_i32 : i32 + %3 = arith.subi %c4_i32, %2 : i32 + %4 = arith.minsi %3, %c4_i32 : i32 + %5 = arith.remsi %0, %4 : i32 + %6 = arith.addi %2, %5 : i32 + %7 = arith.remsi %0, %c64_i32 : i32 + %8 = arith.divsi %7, %4 : i32 + %9 = arith.muli %6, %c256_i32 : i32 + // CHECK: tt.make_tensor_ptr {{.*}} : > + %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%9, %c0_i32] {order = array} : > + %11 = arith.muli %8, %c256_i32 : i32 + // CHECK: tt.make_tensor_ptr {{.*}} : > + %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %11] {order = array} : > + %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr>) : i32 { + %17 = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: %[[A_LOAD:.*]] = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: {{.*}} = ttg.convert_layout %[[A_LOAD]] : tensor<256x32xf16, #mma> -> tensor<256x32xf16, #blocked1> + %18 = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "row_major"} : 
!tt.ptr> + // CHECK: %[[B_LOAD:.*]] = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: {{.*}} = ttg.convert_layout %[[B_LOAD]] : tensor<32x256xf16, #mma1> -> tensor<32x256xf16, #blocked2> + %19 = ttg.convert_layout %17 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %20 = ttg.convert_layout %18 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %21 = ttg.convert_layout %arg4 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma> + %22 = ttg.convert_layout %19 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %23 = ttg.convert_layout %20 : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 2}>> -> tensor<256x256xf32, #mma2> + %24 = tt.dot %22, %23, %21, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + %25 = ttg.convert_layout %24 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked> + // CHECK: tt.advance {{.*}} : > + %26 = tt.advance %arg5, [%c0_i32, %c32_i32] : > + // CHECK: tt.advance {{.*}} : > + %27 = tt.advance %arg6, [%c32_i32, %c0_i32] : > + scf.yield %25, %26, %27 : tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr> + } + %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array} : > + %15 = arith.truncf %13#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked> + %16 = ttg.convert_layout %15 : tensor<256x256xf16, #blocked> -> tensor<256x256xf16, #blocked2> + tt.store %14, %16 {boundaryCheck = array} : !tt.ptr> + tt.return + } +} + +// ----- + +// COM: Dot operand A transpose currently not supported by subgroup 2d block io encoding +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 1], warpsPerCTA = [2, 16], order = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}> +// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}> +// CHECK: #mma1 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +// CHECK-NOT: #mma2 +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} { + tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : 
i32}) attributes {noinline = false} { + %c4_i32 = arith.constant 4 : i32 + %c256_i32 = arith.constant 256 : i32 + %c1024_i64 = arith.constant 1024 : i64 + %c5120_i64 = arith.constant 5120 : i64 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %c4096_i64 = arith.constant 4096 : i64 + %c32_i32 = arith.constant 32 : i32 + %c64_i32 = arith.constant 64 : i32 + %c5120_i32 = arith.constant 5120 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked> + %0 = tt.get_program_id x : i32 + %1 = arith.divsi %0, %c64_i32 : i32 + %2 = arith.muli %1, %c4_i32 : i32 + %3 = arith.subi %c4_i32, %2 : i32 + %4 = arith.minsi %3, %c4_i32 : i32 + %5 = arith.remsi %0, %4 : i32 + %6 = arith.addi %2, %5 : i32 + %7 = arith.remsi %0, %c64_i32 : i32 + %8 = arith.divsi %7, %4 : i32 + %9 = arith.muli %6, %c256_i32 : i32 + %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c1_i64, %c1024_i64], [%9, %c0_i32] {order = array} : > + %11 = arith.muli %8, %c256_i32 : i32 + %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %11] {order = array} : > + %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr>) : i32 { + // CHECK: {{.*}} = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + %17 = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + // CHECK: %[[B_LOAD:.*]] = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: {{.*}} = ttg.convert_layout %[[B_LOAD]] : tensor<32x256xf16, #mma> -> tensor<32x256xf16, #blocked2> + %18 = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + %19 = ttg.convert_layout %17 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %20 = ttg.convert_layout %18 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %21 = ttg.convert_layout %arg4 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma> + %22 = ttg.convert_layout %19 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %23 = ttg.convert_layout %20 : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 2}>> -> tensor<256x256xf32, #mma1> + %24 = tt.dot %22, %23, %21, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + %25 = ttg.convert_layout %24 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked> + // CHECK: tt.advance {{.*}} : > + %26 = tt.advance %arg5, [%c0_i32, %c32_i32] : > + // CHECK: tt.advance {{.*}} : > + %27 = tt.advance %arg6, [%c32_i32, %c0_i32] : > + scf.yield %25, %26, %27 : tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr> + } + %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array} : > + %15 = arith.truncf %13#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked> + %16 = ttg.convert_layout %15 : tensor<256x256xf16, #blocked> -> 
tensor<256x256xf16, #blocked2> + tt.store %14, %16 {boundaryCheck = array} : !tt.ptr> + tt.return + } +} diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index d225494f75..2c69390029 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -280,6 +280,7 @@ def make_ttgir(mod, metadata, opt, properties): intel.passes.ttgpuir.add_accelerate_matmul(pm) intel.passes.ttgpuir.add_materialize_block_pointer(pm) + intel.passes.ttgpuir.add_optimize_block_load_encoding(pm) intel.passes.ttgpuir.add_remove_layout_conversions(pm) intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, XPUBackend.get_split_barrier_scope(opt)) diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td index c20224aaee..91625ba0be 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td @@ -409,4 +409,15 @@ def TritonIntelGPUReduceVariableLiveness "mlir::scf::SCFDialect", "mlir::arith::ArithDialect"]; } + +def TritonIntelGPUOptimizeBlockIOEncodingPass : Pass<"tritonintelgpu-optimize-block-io-encoding", "mlir::ModuleOp"> { + let summary = "Set encodings on candidates for Subgroup 2D Block IO ops"; + + let description = [{ + Set the Subgroup2DBlock encoding on tensor ptr types that are candidates for Subgroup 2D Block IO lowering. The goal is to change the tensor ptr type to use the new encoding so the LoadOp will use the new encoding, allowing the encoding to be an anchor layout during RemoveLayoutConversions. To avoid duplicating work in RemoveLayoutConversions, a ConvertLayout op to the existing encoding replaces the result of the LoadOp. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", "mlir::triton::gpu::intel::TritonIntelGPUDialect", "mlir::triton::TritonDialect"]; +} + #endif // TRITON_INTEL_GPU_PASSES diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h index 9ab7baaa71..c356f07f20 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h @@ -33,6 +33,15 @@ Attribute inferSrcEncoding(Operation *op, Attribute encoding); // Retuns true if the operation is an expensive load or store operation. bool isExpensiveLoadOrStore(Operation *op); +// Returns true if the conversion between tensor types should be a no-op. Will +// be removed once layout conversion for BlockIO types is lifted from +// LoadStoreOpToLLVM.cpp +bool isBlockIONoOpConversion(RankedTensorType srcType, + RankedTensorType dstType); + +// Returns true if the tensor type has a subgroup 2d block io encoding +bool hasSubgroup2DBlockEncoding(RankedTensorType tensorType); + // Returns true if the tensor type has a dot dpas encoding. 
bool hasDotDpasEncoding(RankedTensorType tensorType); diff --git a/third_party/intel/lib/Analysis/Allocation.cpp b/third_party/intel/lib/Analysis/Allocation.cpp index 8c9cfe5147..e6cd3df5d2 100644 --- a/third_party/intel/lib/Analysis/Allocation.cpp +++ b/third_party/intel/lib/Analysis/Allocation.cpp @@ -1,5 +1,6 @@ #include "intel/include/Analysis/Allocation.h" #include "intel/include/Analysis/Utility.h" +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h" // isBlockIONoOpConversion #include "triton/Dialect/Triton/IR/Utility.h" #include "llvm/ADT/TypeSwitch.h" @@ -11,6 +12,9 @@ constexpr unsigned invalidSize = -1; unsigned allocationAnalysisScratchSizeFn(gpu::ConvertLayoutOp convertLayout) { RankedTensorType srcTy = convertLayout.getSrc().getType(); RankedTensorType dstTy = convertLayout.getResult().getType(); + + if (gpu::intel::isBlockIONoOpConversion(srcTy, dstTy)) + return 0; if (gpu::intel::cvtIsSubGroupShuffle(srcTy, dstTy)) return 0; if (gpu::intel::cvtIsSubGroupTranspose(srcTy, dstTy)) { diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 0feb202b49..173cbdb29a 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -3,6 +3,7 @@ #include "llvm/ADT/TypeSwitch.h" #include "intel/include/Analysis/Utility.h" +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h" namespace mlir::triton::gpu { namespace { @@ -24,9 +25,18 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion ConversionPatternRewriter &rewriter) const override { MLIRContext *ctx = op.getContext(); - auto srcTy = op.getSrc().getType(); + RankedTensorType srcTy = op.getSrc().getType(); auto dstTy = op.getType(); + if (auto dstTensorTy = cast(dstTy)) { + if (intel::isBlockIONoOpConversion(srcTy, dstTensorTy)) { + // TODO: replace this with proper conversion once conversion is removed + // from LoadStoreOpToLLVM. + rewriter.replaceOp(op, op.getSrc()); + return success(); + } + } + LinearLayout conversion = minimalCvtLayout(srcTy, dstTy); LinearLayout srcLayout = toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 32be8bd222..7fcc5eb253 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -302,7 +302,8 @@ struct BlockIOConversionBase : public LoadStoreConversionBase { // Only lower loadOp with dpas layout encoding. auto tensorTy = cast(op.getType()); - return hasDpasEncoding(tensorTy) || hasDotDpasEncoding(tensorTy); + return hasDpasEncoding(tensorTy) || hasDotDpasEncoding(tensorTy) || + hasSubgroup2DBlockEncoding(tensorTy); } template < @@ -342,6 +343,15 @@ struct BlockIOConversionBase : public LoadStoreConversionBase { : getDotEncoding(tensorTy).value().getParent()); } + static RankedTensorType getDpasTypeFromCVTOp(Value opResult) { + for (OpOperand user : opResult.getUsers()) { + if (auto cvt = dyn_cast(user.getOwner())) { + return cast(cvt.getResult().getType()); + } + } + llvm_unreachable("expected to find a cvt op with dpas layout"); + } + // Returns the pitch (stride in bytes) of \p ptr. 
Value getPitch(ConversionPatternRewriter &rewriter, Value ptr, const std::map, Value> &ptrs, @@ -1416,12 +1426,20 @@ struct LoadOpConversion auto tensorType = cast(resultType); const bool memoryRowMajor = isMemoryRowMajor(op); - DpasEncodingAttr::OpIdx opIdx = getOpIdx(tensorType); + + auto dpasTensorType = hasSubgroup2DBlockEncoding(tensorType) + ? getDpasTypeFromCVTOp(op.getResult()) + : tensorType; + DpasEncodingAttr dpasLayout = getDpasLayout(dpasTensorType); + + DpasEncodingAttr::OpIdx opIdx = getOpIdx(dpasTensorType); LLVM_DEBUG(llvm::dbgs() << "Tensor type for op " << int(opIdx) << ": " << tensorType << "\n"); Attribute encoding = tensorType.getEncoding(); + // TODO: this gives us the linear layour corresponding + // to the subgroup 2d block encoding, not the dpas encoding... std::optional llEncoding = cast(encoding).toLinearLayout( tensorType.getShape()); @@ -1440,14 +1458,21 @@ struct LoadOpConversion Type eltTy = tensorType.getElementType(); unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth(); - auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout( - cast(encoding), tensorType.getShape(), - memoryRowMajor, elemSizeInBits / 8, rewriter.getContext()); - unsigned tileHeight = tileParams[0]; - const unsigned tileWidth = tileParams[1]; - const unsigned vBlocks = tileParams[2]; + auto getTileParams = [&]() -> std::tuple { + if (hasSubgroup2DBlockEncoding(tensorType)) { + auto encoding = + cast(tensorType.getEncoding()); + auto shape = encoding.getInstrShape(); + return std::make_tuple(shape[0], shape[1], encoding.getNumBlocks()); + } else { + auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout( + cast(encoding), tensorType.getShape(), + memoryRowMajor, elemSizeInBits / 8, rewriter.getContext()); + return std::make_tuple(tileParams[0], tileParams[1], tileParams[2]); + } + }; + auto [tileHeight, tileWidth, vBlocks] = getTileParams(); - DpasEncodingAttr dpasLayout = getDpasLayout(tensorType); const ArrayRef tensorShape = tensorType.getShape(); unsigned numElems = getTotalElemsPerThread(resultType); SmallVector numReps = @@ -2186,7 +2211,9 @@ struct LoadOpConversion } } - Type llvmResultStructTy = typeConverter->convertType(op.getType()); + Type llvmResultStructTy = typeConverter->convertType(dpasTensorType); + LLVM_DEBUG(llvm::dbgs() << "Packing load result in struct " + << llvmResultStructTy << "\n"); Value resultStruct = packLLElements(loc, typeConverter, unpackedLoadedVals, rewriter, llvmResultStructTy); rewriter.replaceOp(op, {resultStruct}); @@ -2209,11 +2236,22 @@ struct LoadOpConversion Value mask = op.getMask(); Value llMask = adaptor.getMask(); + auto opType = op.getType(); + // TODO: Override the OpType since conversion is still happening during Load + // lowering. Once we materialize ConvertLayoutOp this can be removed. 
+ auto tensorTy = dyn_cast(opType); + if (tensorTy && hasSubgroup2DBlockEncoding(tensorTy)) + opType = getDpasTypeFromCVTOp(op.getResult()); + // Determine the vectorization size - Type valueElemTy = - typeConverter->convertType(getElementTypeOrSelf(op.getType())); - unsigned numElems = getTotalElemsPerThread(op.getType()); + Type valueElemTy = typeConverter->convertType(getElementTypeOrSelf(opType)); + unsigned numElems = getTotalElemsPerThread(opType); unsigned vec = getVectorSize(ptr); + LLVM_DEBUG({ + llvm::dbgs() << "Vectorization for gather load:\n"; + llvm::dbgs() << "\t" << valueElemTy << " [" << numElems << "]\n"; + llvm::dbgs() << "\tvector size = " << vec << " for " << ptr << "\n"; + }); if (llMask) vec = std::min(vec, getMaskAlignment(mask)); @@ -2223,9 +2261,11 @@ struct LoadOpConversion if (isTensorPointerType(ptr.getType())) { // fallback to gather load. - auto tensorType = cast(op.getType()); + // make sure we use the modified opType from above, "seeing through" any + // post-subgroup 2d block encoding CVT. + auto blockPtrTensorType = cast(opType); std::tie(ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr( - loc, adaptor.getPtr(), tensorType, valueElemTy, rewriter, + loc, adaptor.getPtr(), blockPtrTensorType, valueElemTy, rewriter, op.getBoundaryCheck(), op.getPadding()); } else { Value other = op.getOther(); @@ -2370,7 +2410,7 @@ struct LoadOpConversion } } // end vec - Type llvmResultStructTy = typeConverter->convertType(op.getType()); + Type llvmResultStructTy = typeConverter->convertType(opType); Value resultStruct = packLLElements(loc, typeConverter, loadedVals, rewriter, llvmResultStructTy); rewriter.replaceOp(op, {resultStruct}); diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt index b8cb96cfa0..bb32041127 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt +++ b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt @@ -5,6 +5,7 @@ add_triton_library(TritonIntelGPUTransforms DistributeToWarps.cpp MatchTargetSize.cpp MaterializeBlockPointer.cpp + OptimizeBlockIOEncoding.cpp OptimizeDotOperands.cpp OptimizeReductionLocality.cpp Pipeliner/MatmulLoopPipeline.cpp diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp new file mode 100644 index 0000000000..db912b34ba --- /dev/null +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp @@ -0,0 +1,344 @@ +#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "llvm/ADT/PriorityWorklist.h" + +namespace ttg = mlir::triton::gpu; +namespace ttgi = mlir::triton::gpu::intel; + +namespace mlir { +namespace triton { +namespace gpu::intel { + +#define DEBUG_TYPE "tritongpu-optimize-block-encoding" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace { + +SmallVector getTiedArgs(Operation *op, int resultIdx) { + if (auto forOp = dyn_cast(op)) { + auto iterArg = forOp.getRegionIterArg(resultIdx); + auto result = forOp.getResult(resultIdx); + auto yieldVal = forOp.getBody()->getTerminator()->getOperand(resultIdx); + auto initVal = forOp.getInitArgs()[resultIdx]; + return {iterArg, result, yieldVal, initVal}; + } else if (auto whileOp = dyn_cast(op)) { + auto iterArg = 
whileOp.getBeforeArguments()[resultIdx]; + auto result = whileOp.getResults()[resultIdx]; + auto yieldVal = + whileOp.getBeforeBody()->getTerminator()->getOperand(resultIdx); + auto initVal = whileOp.getOperands()[resultIdx]; + return {iterArg, result, yieldVal, initVal}; + } else if (auto ifOp = dyn_cast(op)) { + SmallVector values; + for (auto &block : ifOp.getThenRegion().getBlocks()) { + auto terminator = block.getTerminator(); + if (isa(terminator)) + values.push_back(terminator->getOperands()[resultIdx]); + } + for (auto &block : ifOp.getElseRegion().getBlocks()) { + auto terminator = block.getTerminator(); + if (isa(terminator)) + values.push_back(terminator->getOperands()[resultIdx]); + } + values.push_back(ifOp->getResults()[resultIdx]); + return values; + } + return {}; +} + +Type getNewType(Type type, Attribute encoding) { + RankedTensorType tensorType = cast(type); + return RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); +} + +Type getNewPointerType(Type type, Attribute encoding) { + assert(isa(type) && "expected a ptr type!"); + auto oldPointerType = cast(type); + return PointerType::get(getNewType(oldPointerType.getPointeeType(), encoding), + oldPointerType.getAddressSpace()); +} + +struct EncodingInfo { + Attribute desiredEncoding; + bool requiresConvert = false; + + bool operator==(const EncodingInfo &other) const { + return desiredEncoding == other.desiredEncoding && + requiresConvert == other.requiresConvert; + } +}; + +/** + * The algorithm here takes inspiration from + * TritonNVIDIAGPU::OptimizeDescriptorEncoding. The idea is to iterate the + * def-use chain in both directions starting from the Load Op. We store the + * values that need to be updated along with the new encoding in the + * `valueToEncodingInfo` MapVector. After all value/encoding pairs have been + * determined, we update the encoding for each value, adding a conversion to + * the existing Load Op result layout for users of the load. + */ +void rewriteTensorLayoutsForOp(Attribute encoding, Operation *op) { + auto loadOp = cast(op); + auto loadPtrType = cast(loadOp->getOperand(0).getType()); + auto addressSpace = loadPtrType.getAddressSpace(); + + llvm::MapVector, EncodingInfo> valueToEncodingInfo; + llvm::PriorityWorklist> worklist; + + auto updateEncoding = [&](ArrayRef ptrValues, EncodingInfo info) { + for (auto value : ptrValues) { + bool requiresConvert = llvm::any_of( + value.getUsers(), [](auto user) { return isa(user); }); + info.requiresConvert = requiresConvert; + + auto typedVal = cast>(value); + auto itr = valueToEncodingInfo.find(typedVal); + if (itr == valueToEncodingInfo.end()) { + LLVM_DEBUG(DBGS() << "Add encoding " << info.desiredEncoding + << " for value " << typedVal << "\n"); + valueToEncodingInfo[typedVal] = info; + worklist.insert(typedVal); + } else { + LLVM_DEBUG(DBGS() << "Found existing encoding info " + << itr->second.desiredEncoding << " for value " + << typedVal << ". Ensure new encoding " + << info.desiredEncoding << " matches.\n"); + assert(itr->second == info && "already visited encoding info for " + "value, expected them to be equal!"); + continue; + } + } + }; + + worklist.insert(cast>(loadOp->getOperand(0))); + + // 1. Starting from the Load Op, propagate encoding info up and down the + // def-use chain.
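+ // As an illustration (mirroring the first lit test in this patch), the values
+ // tied together for a loop-carried block pointer are expected to be:
+ //   %10 = tt.make_tensor_ptr ...               <- loop init value
+ //   scf.for ... iter_args(%arg5 = %10) -> ...  <- init arg / block arg / yield operand / result
+ //     %17 = tt.load %arg5 ...                  <- the anchor load
+ //     %26 = tt.advance %arg5, ...              <- yielded for the next iteration
+ // so every pointer value in the chain switches to the new pointee encoding together.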
+ while (!worklist.empty()) { + auto crtValue = worklist.pop_back_val(); + + // Propagate to users + for (OpOperand &use : crtValue.getUses()) { + auto op = use.getOwner(); + if (isa(op)) { + auto offset = 3 * isa(op); + auto vals = getTiedArgs(op, use.getOperandNumber() - offset); + updateEncoding(vals, EncodingInfo{encoding}); + } else if (isa(op)) { + auto vals = getTiedArgs(op->getParentOp(), use.getOperandNumber()); + updateEncoding(vals, EncodingInfo{encoding}); + } + } + + // Propagate to defining ops + if (auto opResult = dyn_cast(crtValue)) { + auto definingOp = opResult.getOwner(); + if (isa(definingOp)) { + auto vals = getTiedArgs(definingOp, opResult.getResultNumber()); + updateEncoding(vals, EncodingInfo{encoding}); + } + } else if (auto blockArg = dyn_cast(crtValue)) { + auto parentOp = blockArg.getOwner()->getParentOp(); + if (isa(parentOp)) { + auto offset = isa(parentOp); + auto vals = getTiedArgs(parentOp, blockArg.getArgNumber() - offset); + updateEncoding(vals, EncodingInfo{encoding}); + } + } + } + + // 2. Update the type for each value in-place. Add a ConvertLayout Op after + // any loads which require conversion to the existing layout for the loaded + // value. + for (auto &[val, einfo] : valueToEncodingInfo) { + Attribute newEncoding = einfo.desiredEncoding; + LLVM_DEBUG(DBGS() << "Rewrite encoding to " << newEncoding << " for value " + << val << "\n"); + + PointerType oldType = val.getType(); + auto oldTensorTy = cast(oldType.getPointeeType()); + auto newTensorTy = RankedTensorType::get( + oldTensorTy.getShape(), oldTensorTy.getElementType(), newEncoding); + + val.setType(PointerType::get(newTensorTy, oldType.getAddressSpace())); + if (einfo.requiresConvert) { + for (auto user : val.getUsers()) { + if (auto loadOp = dyn_cast(user)) { + + OpBuilder builder(loadOp); + auto oldLoadType = loadOp.getType(); + Value result = loadOp.getResult(); + + builder.setInsertionPointAfter(loadOp); + auto cvt = builder.create(loadOp.getLoc(), + result.getType(), result); + LLVM_DEBUG(DBGS() << "Added convert Op:\n" + << cvt << " after Load Op:\n" + << loadOp << "\n"); + result.setType(newTensorTy); + + result.replaceAllUsesExcept(cvt.getResult(), cvt.getOperation()); + } + } + } + } +} + +} // namespace + +#define GEN_PASS_DEF_TRITONINTELGPUOPTIMIZEBLOCKIOENCODINGPASS +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc" + +class TritonIntelGPUOptimizeBlockIOEncodingPass + : public impl::TritonIntelGPUOptimizeBlockIOEncodingPassBase< + TritonIntelGPUOptimizeBlockIOEncodingPass> { + + void getSubgroup2DBlockLayoutForOperand( + Value operand, DpasEncodingAttr dpasLayout, + llvm::MapVector &layoutMap) { + auto isCandidateLoad = [](Value v) -> LoadOp { + // Peel out the original cvt dot_op<..., #blocked> + // and any other potential cvt/trans ops + while (true) { + if (auto cvtOp = v.getDefiningOp()) { + v = cvtOp.getSrc(); + continue; + } + if (auto transOp = v.getDefiningOp()) { + v = transOp.getSrc(); + continue; + } + break; + } + return isa(v.getDefiningOp()) ? 
cast(v.getDefiningOp()) + : nullptr; + }; + + LoadOp loadOp = isCandidateLoad(operand); + if (!loadOp) + return; + + auto dotOperandType = cast(operand.getType()); + auto layout = ttg::toLinearEncoding(dotOperandType); + auto order = layout.getThreadOrder(); + auto rank = order.size(); + if (rank != 2) { + loadOp.emitWarning( + "Subgroup 2D Block Encoding layouts only support rank 2 operands."); + return; + } + + auto dotOperandEncoding = + cast(dotOperandType.getEncoding()); + // layout width is determined by the DPAS operand encoding width + const int kWidth = dotOperandEncoding.getKWidth(); + + Attribute blockIOAttr = + loadOp->getAttr(TritonIntelGPUDialect::getBlockIOAttrName()); + if (!blockIOAttr) + return; + + const bool valueRowMajor = + getOrderForDotOperand(0, rank, /*kContig=*/true) == order; + const bool memoryRowMajor = + blockIOAttr == StringAttr::get(&getContext(), "row_major"); + const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor; + LLVM_DEBUG({ + DBGS() << "Original layout: " << dotOperandEncoding << "\n"; + DBGS() << "\tvalueRowMajor = " << valueRowMajor << "\n"; + DBGS() << "\tmemoryRowMajor = " << memoryRowMajor << "\n"; + DBGS() << "\tisTransposeRequired = " << isTransposeRequired << "\n"; + }); + if (dotOperandEncoding.getOpIdx() == 0 && isTransposeRequired) { + LLVM_DEBUG(DBGS() << "Transposed 'A' operand does not yet support " + "Subgroup 2D Block Encoding layout.\n"); + return; + } + + // get the MakeTensorPtr Op for the load + Value ptr = loadOp.getPtr(); + if (!isTensorPointerType(ptr.getType())) { + // TODO: support tensor of pointer loads + LLVM_DEBUG(DBGS() << "Ptr\n" + << ptr << " for Load Op:\n" + << loadOp + << "\nincompatible with Subgroup 2D Block Layout.\n"); + return; + } + MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(ptr); + assert(makeTensorPtrOp && + "expecting a tensor pointer parent to block io load " + "with tensor pointer type"); + + auto oldTensorPtrType = cast(makeTensorPtrOp.getType()); + auto oldTensorType = + cast(oldTensorPtrType.getPointeeType()); + // Note: we need the old layout to get the order for the load, but it is not + // clear the layout will always be Blocked. Is there a better way to get + // this info? + auto oldLayout = cast(oldTensorType.getEncoding()); + + auto CTALayout = getCTALayout(dpasLayout); + const unsigned elemSizeInBits = + oldTensorType.getElementType().getIntOrFloatBitWidth(); + + auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout( + cast(dotOperandEncoding), + oldTensorType.getShape(), memoryRowMajor, elemSizeInBits / 8, + &getContext()); + SmallVector instrShape{tileParams[0], tileParams[1]}; + const unsigned vBlocks = tileParams[2]; + + auto subgroup2DBlockEncoding = Subgroup2DBlockEncodingAttr::get( + &getContext(), dpasLayout.getWarpsPerCTA(), CTALayout, instrShape, + tileParams[2], + getOrderForDotOperand(dotOperandEncoding.getOpIdx(), /*rank*/ rank, + /*kContig*/ true), + kWidth, dpasLayout.getThreadsPerWarp()); + + LLVM_DEBUG(DBGS() << "Generated new encoding: " << subgroup2DBlockEncoding + << " for op : " << loadOp << "\n"); + + layoutMap[loadOp] = subgroup2DBlockEncoding; + } + +public: + void runOnOperation() override { + ModuleOp m = getOperation(); + + // Step 1. Find all loads which are candidates for conversion to Subgroup 2D + // Block Encoding. To be a candidate load, a load must be consumed by a Dot + // Op and the load operand must be a block ptr (produced by a MakeTensorPtr + // Op). 
Currently we look for loads with the "block_io" attribute, but we + // could consider moving that logic to this pass later. We place the load + // and the candidate encoding into the layout map for propagation in step 2. + llvm::MapVector layoutMap; + m.walk([&](DotOp dotOp) { + auto dotOpType = cast(dotOp.getResult().getType()); + auto dpasLayout = dyn_cast(dotOpType.getEncoding()); + if (!dpasLayout) + return; + + getSubgroup2DBlockLayoutForOperand(dotOp.getA(), dpasLayout, layoutMap); + getSubgroup2DBlockLayoutForOperand(dotOp.getB(), dpasLayout, layoutMap); + }); + + // Step 2. Rewrite MakeTensorPtr to use the new layout and propagate the + // change through the def-use chain, terminating at the Load Op. We add a + // ConvertLayout Op after the Load Op to convert back to the original + // layout. Subgroup2DBlockEncoding layouts will be chosen as anchor layouts + // in RemoveLayoutConversions, and a subsequent run of + // RemoveLayoutConversions after this pass cleans up intermediate layout + // conversions and removes the original Load Op encoding. + for (auto &kv : layoutMap) { + rewriteTensorLayoutsForOp(kv.second, kv.first); + } + } +}; + +} // namespace gpu::intel +} // namespace triton +} // namespace mlir diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/ReduceDataDuplication.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/ReduceDataDuplication.cpp index ee218d76b1..bcc8cbc0b2 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/ReduceDataDuplication.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/ReduceDataDuplication.cpp @@ -30,6 +30,8 @@ class TritonIntelGPUReduceDataDuplicationPass auto srcEncoding = srcType.getEncoding(); if (isa(srcEncoding)) return; + if (isa(srcEncoding)) + return; auto dstDotOp = dyn_cast(dstType.getEncoding()); if (!dstDotOp) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp index 785920770a..f2aa2e6d7a 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp @@ -120,9 +120,22 @@ bool isExpensiveLoadOrStore(Operation *op) { if (isSingleValue(base)) return false; - // Loads that use a block pointer are expensive if they cannot be lowered to - // 2D block read operations. Temporarily leverage the - // "ttig.block_io" attribute to filter out inexpensive loads. + if (auto loadOp = dyn_cast(op)) { + // Subgroup2DBlockEncodingAttr loads are expensive, but loads without this + // encoding may still be expensive, so we only return true if the encoding + // exists. + if (auto tensorTy = dyn_cast(loadOp.getType())) + if (isa(tensorTy.getEncoding())) + return true; + } + + // The block ptr attribute identifies loads that are candidates for subgroup + // 2d block io operations. Loads with these attributes (and without the new + // subgroup 2d block encoding above) should have their layouts replaced with + // the layout from the expensive op (usually a dot op with DPAS encoding). The + // load result is converted to the expensive op layout during LLVM lowering. + // Note: the long term plan is to replace this path with the above subgroup 2d + // block encoding layout.
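+ // (For example, loads already rewritten by TritonIntelGPUOptimizeBlockIOEncodingPass should be caught by the encoding check above, while plain "ttig.block_io" loads without the new encoding fall through to the attribute check below.)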
Attribute blockIOAttr = op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName()); if (blockIOAttr) @@ -140,6 +153,18 @@ bool isExpensiveLoadOrStore(Operation *op) { return false; } +bool isBlockIONoOpConversion(RankedTensorType srcType, + RankedTensorType dstType) { + return hasSubgroup2DBlockEncoding(srcType) && hasDotDpasEncoding(dstType); +} + +bool hasSubgroup2DBlockEncoding(RankedTensorType tensorType) { + if (!tensorType.getEncoding()) + return false; + + return isa(tensorType.getEncoding()); +} + bool hasDotDpasEncoding(RankedTensorType tensorType) { if (!tensorType.getEncoding()) return false; diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc index ae485a2c7b..1abab23024 100644 --- a/third_party/intel/triton_xpu.cc +++ b/third_party/intel/triton_xpu.cc @@ -115,6 +115,9 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) { gpu::intel::createTritonIntelGPUReduceDataDuplication); ADD_PASS_WRAPPER_0("add_materialize_block_pointer", gpu::intel::createTritonIntelGPUMaterializeBlockPointer); + ADD_PASS_WRAPPER_0( + "add_optimize_block_load_encoding", + gpu::intel::createTritonIntelGPUOptimizeBlockIOEncodingPass); ADD_PASS_WRAPPER_0("add_optimize_reduction_locality", gpu::intel::createTritonIntelGPUOptimizeReductionLocality); ADD_PASS_WRAPPER_0("add_reduce_variable_liveness",