diff --git a/test/TritonIntelGPU/optimize-block-io-encoding.mlir b/test/TritonIntelGPU/optimize-block-io-encoding.mlir
new file mode 100644
index 0000000000..beb421548e
--- /dev/null
+++ b/test/TritonIntelGPU/optimize-block-io-encoding.mlir
@@ -0,0 +1,181 @@
+// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect --tritonintelgpu-optimize-block-io-encoding | FileCheck %s
+
+// COM: test complete example
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
+// CHECK-DAG: #[[$SUBGROUP_BLOCK_A:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}>
+// CHECK-DAG: #[[$SUBGROUP_BLOCK_B:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}>
+// CHECK-DAG: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_dpas, ttig.support_sg_2d_block} {
+  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) {
+    %c4_i32 = arith.constant 4 : i32
+    %c256_i32 = arith.constant 256 : i32
+    %c1024_i64 = arith.constant 1024 : i64
+    %c5120_i64 = arith.constant 5120 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %c4096_i64 = arith.constant 4096 : i64
+    %c32_i32 = arith.constant 32 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %c5120_i32 = arith.constant 5120 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked>
+
+    // CHECK: %[[MAKE_TENSOR_PTR_A:.*]] = tt.make_tensor_ptr {{.*}} : >
+    %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%c256_i32, %c0_i32] {order = array} : >
+    // CHECK: %[[MAKE_TENSOR_PTR_B:.*]] = tt.make_tensor_ptr {{.*}} : >
+    %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %c256_i32] {order = array} : >
+    // CHECK: %[[RES:.*]]:3 = scf.for {{.*}} iter_args({{.*}} = {{.*}}, %[[ARG5:.*]] = %[[MAKE_TENSOR_PTR_A]], %[[ARG6:.*]] = %[[MAKE_TENSOR_PTR_B]])
+    %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr>) : i32 {
+      %17 = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>
+      // CHECK: %[[A_LOAD:.*]] = tt.load %[[ARG5]] {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>
+      // CHECK: {{.*}} = ttg.convert_layout %[[A_LOAD]] : tensor<256x32xf16, #[[$SUBGROUP_BLOCK_A]]> -> tensor<256x32xf16, #blocked1>
+      %18 = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>
+      // CHECK: %[[B_LOAD:.*]] = tt.load %[[ARG6]] {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>
+      // CHECK: {{.*}} = ttg.convert_layout %[[B_LOAD]] : tensor<32x256xf16, #[[$SUBGROUP_BLOCK_B]]> -> tensor<32x256xf16, #blocked2>
+      %19 = ttg.convert_layout %17 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
+      %20 = ttg.convert_layout %18 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
+      %21 = ttg.convert_layout %arg4 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma>
+      %22 = ttg.convert_layout %19 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
+      %23 = ttg.convert_layout %20 : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[$DPAS]]>
+      %24 = tt.dot %22, %23, %21, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
+      %25 = ttg.convert_layout %24 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked>
+      // CHECK: %[[ADVANCE_A:.*]] = tt.advance {{.*}} : >
+      %26 = tt.advance %arg5, [%c0_i32, %c32_i32] : >
+      // CHECK: %[[ADVANCE_B:.*]] = tt.advance {{.*}} : >
+      %27 = tt.advance %arg6, [%c32_i32, %c0_i32] : >
+      // CHECK: scf.yield {{.*}}, %[[ADVANCE_A]], %[[ADVANCE_B]]
+      scf.yield %25, %26, %27 : tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr>
+    }
+    %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %c256_i32] {order = array} : >
+    // CHECK: arith.truncf %[[RES]]#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked>
+    %15 = arith.truncf %13#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked>
+    %16 = ttg.convert_layout %15 : tensor<256x256xf16, #blocked> -> tensor<256x256xf16, #blocked2>
+    tt.store %14, %16 {boundaryCheck = array} : !tt.ptr>
+    tt.return
+  }
+}
+
+// -----
+
+// COM: Test while loop / nested tt.advance
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
+#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+// CHECK-DAG: #[[$BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
+// CHECK-DAG: #[[$SUBGROUP_2D_BLOCK:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}>
+// CHECK-DAG: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_dpas, ttig.support_sg_2d_block} {
+  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr) {
+    %c1024_i64 = arith.constant 1024 : i64
+    %c5120_i64 = arith.constant 5120 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c256_i32 = arith.constant 256 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c32_i32 = arith.constant 32 : i32
+
+    // CHECK: %[[A_PTR:.*]] = tt.make_tensor_ptr %arg0, {{.*}} : >
+    %a_ptr = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%c256_i32, %c0_i32] {order = array} : >
+
+    // CHECK: scf.while {{.*}} : (!tt.ptr>) -> !tt.ptr>
+    %1 = scf.while (%a_ptr_crt = %a_ptr) : (!tt.ptr>) -> (!tt.ptr>) {
+      %2 = "dummy.evaluate_condition"() : () -> i1
+      // CHECK: scf.condition({{.*}}) {{.*}} : !tt.ptr>
+      scf.condition(%2) %a_ptr_crt : !tt.ptr>
+    } do {
+    ^bb0(%a_ptr_crt: !tt.ptr>):
+      // CHECK: ^bb0({{.*}}: !tt.ptr>):
+
+      // CHECK: %[[A_LOAD:.*]] = tt.load {{.*}} : !tt.ptr>
+      %3 = tt.load %a_ptr_crt {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>
+      // CHECK: ttg.convert_layout %[[A_LOAD]] : tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]> -> tensor<256x32xf16, #[[$BLOCKED]]>
+      // CHECK: ttg.convert_layout {{.*}} : tensor<256x32xf16, #[[$BLOCKED]]> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>
+      %4 = ttg.convert_layout %3 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
+
+      %cstB = arith.constant dense<0.000000e+00> : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
+
+      // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[$DPAS]]>
+      %5 = tt.dot %4, %cstB, %cst, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
+      %6 = ttg.convert_layout %5 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked1>
+      // CHECK: tt.advance {{.*}} : >
+      %7 = tt.advance %a_ptr_crt, [%c0_i32, %c32_i32] : >
+
+      // CHECK: scf.yield {{.*}} : !tt.ptr>
+      scf.yield %a_ptr_crt : !tt.ptr>
+    }
+    tt.return
+  }
+}
+
+// -----
+
+// COM: test complex control flow
+// COM: Note that instead of using tt.advance we make a new tensor ptr each time. This is nice because it lets us test that we can find the MakeTensorPtr op inside the scf.if.
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
+// CHECK-DAG: #[[$SUBGROUP_BLOCK_A:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}>
+// CHECK-DAG: #[[$SUBGROUP_BLOCK_B:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 16], numBlocks=2, order=[0, 1], kWidth=2, threadsPerWarp=16}>
+// CHECK-DAG: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_dpas, ttig.support_sg_2d_block} {
+// CHECK-LABEL: @matmul_change_block_ptr_in_prologue
+tt.func @matmul_change_block_ptr_in_prologue(%a_base: !tt.ptr,
+                                             %b_base: !tt.ptr) {
+  %c0_i64 = arith.constant 0 : i64
+  %c1_i64 = arith.constant 1 : i64
+  %k_tiles = arith.constant 32 : i64
+  %true = arith.constant true
+  %false = arith.constant false
+
+  %zero = arith.constant dense<0.0> : tensor<128x128xf32, #blocked>
+
+  // CHECK: %[[A_UNDEF:.*]] = ub.poison : !tt.ptr>
+  // CHECK: %[[B_UNDEF:.*]] = ub.poison : !tt.ptr>
+  %a_ptr_undef = ub.poison : !tt.ptr>
+  %b_ptr_undef = ub.poison : !tt.ptr>
+  // CHECK: scf.for {{.*}} iter_args({{.*}} = {{.*}}, %[[A_PTR:.*]] = %[[A_UNDEF]], %[[B_PTR:.*]] = %[[B_UNDEF]])
+  scf.for %k = %c0_i64 to %k_tiles step %c1_i64 iter_args(%acc = %zero, %flag = %true, %a_ptr = %a_ptr_undef, %b_ptr = %b_ptr_undef) -> (tensor<128x128xf32, #blocked>, i1, !tt.ptr>, !tt.ptr>) : i64 {
+    %do_prologue = "prologue_cond"(%k) : (i64) -> i1
+    // CHECK: %[[PTRS:.*]]:2 = scf.if {{.*}} -> (!tt.ptr>, !tt.ptr>)
+    %cur_a_ptr, %cur_b_ptr = scf.if %do_prologue -> (!tt.ptr>, !tt.ptr>) {
+      %off_m, %off_n, %off_k = "get_offsets"(%k) : (i64) -> (i32, i32, i32)
+      // CHECK: tt.make_tensor_ptr {{.*}} : >
+      %next_a_ptr = tt.make_tensor_ptr %a_base, [%k, %k], [%c1_i64, %c1_i64], [%off_m, %off_k] {order = array} : >
+      // CHECK: tt.make_tensor_ptr {{.*}} : >
+      %next_b_ptr = tt.make_tensor_ptr %b_base, [%k, %k], [%c1_i64, %c1_i64], [%off_n, %off_k] {order = array} : >
+      // CHECK: scf.yield {{.*}} : !tt.ptr>, !tt.ptr>
+      scf.yield %next_a_ptr, %next_b_ptr : !tt.ptr>, !tt.ptr>
+    } else {
+      // CHECK: scf.yield {{.*}} : !tt.ptr>, !tt.ptr>
+      scf.yield %a_ptr, %b_ptr : !tt.ptr>, !tt.ptr>
+    }
+
+    // CHECK: %[[A:.*]] = tt.load %[[PTRS]]#0 {{.*}} : !tt.ptr>
+    %a = tt.load %cur_a_ptr {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>
+    // CHECK: ttg.convert_layout %[[A]] : tensor<128x64xf16, #[[$SUBGROUP_BLOCK_A]]> -> tensor<128x64xf16, #blocked1>
+    // CHECK: %[[B:.*]] = tt.load %[[PTRS]]#1 {{.*}} : !tt.ptr>
+    %b = tt.load %cur_b_ptr {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>
+    // CHECK: {{.*}} = ttg.convert_layout %[[B]] : tensor<64x128xf16, #[[$SUBGROUP_BLOCK_B]]> -> tensor<64x128xf16, #blocked2>
+    %a_dot = ttg.convert_layout %a : tensor<128x64xf16, #blocked1> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
+    %b_dot = ttg.convert_layout %b : tensor<64x128xf16, #blocked2> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
+    %a_dot_dpas = ttg.convert_layout %a_dot : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
+    %b_dot_dpas = ttg.convert_layout %b_dot : tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+    %accum = ttg.convert_layout %acc : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma>
+    %c = tt.dot %a_dot_dpas, %b_dot_dpas, %accum, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<128x128xf32, #mma>
+    %c_out = ttg.convert_layout %c : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked>
+
+    %do_epilogue = arith.cmpi eq, %k, %c0_i64 : i64
+    %use_acc = arith.select %do_epilogue, %false, %true : i1
+    scf.if %do_epilogue {
+      "acc_user"(%c_out) : (tensor<128x128xf32, #blocked>) -> ()
+    }
+    // CHECK: scf.yield {{.*}} : {{.*}}, i1, !tt.ptr>, !tt.ptr>
+    scf.yield %c_out, %use_acc, %cur_a_ptr, %cur_b_ptr : tensor<128x128xf32, #blocked>, i1, !tt.ptr>, !tt.ptr>
+  }
+
+  tt.return
+}
+}
diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py
index 3ceff03a17..91dd358217 100644
--- a/third_party/intel/backend/compiler.py
+++ b/third_party/intel/backend/compiler.py
@@ -280,6 +280,7 @@ def make_ttgir(mod, metadata, opt, properties):
         intel.passes.ttgpuir.add_accelerate_matmul(pm)
         intel.passes.ttgpuir.add_materialize_block_pointer(pm)
+        intel.passes.ttgpuir.add_optimize_block_load_encoding(pm)
         intel.passes.ttgpuir.add_remove_layout_conversions(pm)
         intel.passes.ttgpuir.add_optimize_dot_operands(pm)
         intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, XPUBackend.get_split_barrier_scope(opt))
diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
index c20224aaee..c350c13bc7 100644
--- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
@@ -409,4 +409,22 @@ def TritonIntelGPUReduceVariableLiveness
                            "mlir::scf::SCFDialect",
                            "mlir::arith::ArithDialect"];
 }
+
+def TritonIntelGPUOptimizeBlockIOEncodingPass
+    : Pass<"tritonintelgpu-optimize-block-io-encoding", "mlir::ModuleOp"> {
+  let summary = "Set encodings on candidates for Subgroup 2D Block IO ops";
+
+  let description = [{
+    Set the Subgroup2DBlock encoding on tensor ptr types that are candidates
+    for Subgroup 2D Block IO lowering.
+
+    The goal is to change the tensor ptr type to the new encoding so that the
+    LoadOp result adopts it, allowing the encoding to act as an anchor layout
+    during RemoveLayoutConversions. To avoid duplicating work in
+    RemoveLayoutConversions, a ConvertLayout op back to the existing encoding
+    replaces the result of the LoadOp.
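+
+    As a rough sketch (simplified, illustrative IR rather than the exact
+    output of the pass), a candidate load through a block pointer
+
+      // before: the load result uses the original #blocked encoding
+      %ptr = tt.make_tensor_ptr ... : <tensor<256x32xf16, #blocked>>
+      %val = tt.load %ptr {ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #blocked>>
+
+    becomes
+
+      // after: the block ptr and the load carry the Subgroup2DBlock encoding,
+      // and a conversion restores the original layout for existing users
+      %ptr = tt.make_tensor_ptr ... : <tensor<256x32xf16, #subgroup_2d_block>>
+      %val = tt.load %ptr {ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #subgroup_2d_block>>
+      %cvt = ttg.convert_layout %val : tensor<256x32xf16, #subgroup_2d_block> -> tensor<256x32xf16, #blocked>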
+  }];
+
+  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
+                           "mlir::triton::gpu::intel::TritonIntelGPUDialect",
+                           "mlir::triton::TritonDialect"];
+}
+
 #endif // TRITON_INTEL_GPU_PASSES
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt
index b8cb96cfa0..bb32041127 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_triton_library(TritonIntelGPUTransforms
   DistributeToWarps.cpp
   MatchTargetSize.cpp
   MaterializeBlockPointer.cpp
+  OptimizeBlockIOEncoding.cpp
   OptimizeDotOperands.cpp
   OptimizeReductionLocality.cpp
   Pipeliner/MatmulLoopPipeline.cpp
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp
new file mode 100644
index 0000000000..29c48b00b9
--- /dev/null
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp
@@ -0,0 +1,268 @@
+#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
+#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
+#include "triton/Dialect/Triton/IR/Utility.h"
+#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
+#include "llvm/ADT/PriorityWorklist.h"
+
+namespace ttg = mlir::triton::gpu;
+namespace ttgi = mlir::triton::gpu::intel;
+
+namespace mlir {
+namespace triton {
+namespace gpu::intel {
+
+#define DEBUG_TYPE "tritongpu-optimize-block-encoding"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+namespace {
+
+struct EncodingInfo {
+  Attribute desiredEncoding;
+
+  bool operator==(const EncodingInfo &other) const {
+    return desiredEncoding == other.desiredEncoding;
+  }
+};
+
+/**
+ * The algorithm here takes inspiration from
+ * TritonNVIDIAGPU::OptimizeDescriptorEncoding. The idea is to iterate the
+ * def-use chain in both directions starting from the Load Op. We store the
+ * values that need to be updated along with the new encoding in the
+ * `valueToEncodingInfo` MapVector. After all value/encoding pairs have been
+ * determined, we update the encoding for each value, adding a conversion to
+ * the existing Load Op result layout for users of the load.
+ */
+void rewriteTensorLayoutsForOp(Attribute encoding, Operation *op) {
+  auto loadOp = cast(op);
+  auto loadPtrType = cast(loadOp->getOperand(0).getType());
+  auto addressSpace = loadPtrType.getAddressSpace();
+
+  llvm::MapVector, EncodingInfo> valueToEncodingInfo;
+  llvm::PriorityWorklist> worklist;
+
+  auto updateEncoding = [&](ArrayRef ptrValues, EncodingInfo info) {
+    for (auto value : ptrValues) {
+      auto typedVal = cast>(value);
+      auto itr = valueToEncodingInfo.find(typedVal);
+      if (itr == valueToEncodingInfo.end()) {
+        LLVM_DEBUG(DBGS() << "Add encoding " << info.desiredEncoding
+                          << " for value " << typedVal << "\n");
+        valueToEncodingInfo[typedVal] = info;
+        worklist.insert(typedVal);
+      } else {
+        LLVM_DEBUG(DBGS() << "Found existing encoding info "
+                          << itr->second.desiredEncoding << " for value "
+                          << typedVal << ". Ensure new encoding "
+                          << info.desiredEncoding << " matches.\n");
+        assert(itr->second == info && "already visited encoding info for "
+                                      "value, expected them to be equal!");
+        continue;
+      }
+    }
+  };
+
+  worklist.insert(cast>(loadOp->getOperand(0)));
+
+  // 1. Starting from the Load Op, propagate encoding info up and down the
+  // def-use chain.
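+  // For example, with a loop-carried block pointer the chain typically looks
+  // like: tt.make_tensor_ptr -> scf.for iter_arg -> tt.load / tt.advance ->
+  // scf.yield (see the tests for this pass); every pointer value along that
+  // chain must end up with the same encoding.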
+  while (!worklist.empty()) {
+    auto crtValue = worklist.pop_back_val();
+
+    // Propagate to users
+    for (OpOperand &use : crtValue.getUses()) {
+      auto op = use.getOwner();
+      if (isa(op)) {
+        auto offset = 3 * isa(op);
+        auto vals = getTiedArgs(op, use.getOperandNumber() - offset);
+        updateEncoding(vals, EncodingInfo{encoding});
+      } else if (isa(op)) {
+        auto vals = getTiedArgs(op->getParentOp(), use.getOperandNumber());
+        updateEncoding(vals, EncodingInfo{encoding});
+      } else if (isa(op)) {
+        // The operand will be updated when the MakeTensorPtr op result is
+        // updated. Make sure the result type matches.
+        for (auto result : op->getResults())
+          if (auto desc = dyn_cast>(result))
+            updateEncoding(desc, EncodingInfo{encoding});
+      }
+    }
+
+    // Propagate to defining ops
+    if (auto opResult = dyn_cast(crtValue)) {
+      auto definingOp = opResult.getOwner();
+      if (isa(definingOp)) {
+        auto vals = getTiedArgs(definingOp, opResult.getResultNumber());
+        updateEncoding(vals, EncodingInfo{encoding});
+      }
+    } else if (auto blockArg = dyn_cast(crtValue)) {
+      auto parentOp = blockArg.getOwner()->getParentOp();
+      if (isa(parentOp)) {
+        auto offset = isa(parentOp);
+        auto vals = getTiedArgs(parentOp, blockArg.getArgNumber() - offset);
+        updateEncoding(vals, EncodingInfo{encoding});
+      }
+    }
+  }
+
+  // 2. Update the type for each value in-place. Add a ConvertLayout Op after
+  // any loads which require conversion to the existing layout for the loaded
+  // value.
+  for (auto &[val, einfo] : valueToEncodingInfo) {
+    Attribute newEncoding = einfo.desiredEncoding;
+    LLVM_DEBUG(DBGS() << "Rewrite encoding to " << newEncoding << " for value "
+                      << val << "\n");
+
+    PointerType oldType = val.getType();
+    auto oldTensorTy = cast(oldType.getPointeeType());
+    auto newTensorTy = RankedTensorType::get(
+        oldTensorTy.getShape(), oldTensorTy.getElementType(), newEncoding);
+
+    val.setType(PointerType::get(newTensorTy, oldType.getAddressSpace()));
+    for (auto user : val.getUsers()) {
+      if (auto loadOp = dyn_cast(user)) {
+        OpBuilder builder(loadOp);
+        auto oldLoadType = loadOp.getType();
+        Value result = loadOp.getResult();
+
+        builder.setInsertionPointAfter(loadOp);
+        auto cvt = builder.create(loadOp.getLoc(), result.getType(), result);
+        LLVM_DEBUG(DBGS() << "Added convert Op:\n"
+                          << cvt << " after Load Op:\n"
+                          << loadOp << "\n");
+        result.setType(newTensorTy);
+
+        result.replaceAllUsesExcept(cvt.getResult(), cvt.getOperation());
+      }
+    }
+  }
+}
+
+} // namespace
+
+#define GEN_PASS_DEF_TRITONINTELGPUOPTIMIZEBLOCKIOENCODINGPASS
+#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"
+
+class TritonIntelGPUOptimizeBlockIOEncodingPass
+    : public impl::TritonIntelGPUOptimizeBlockIOEncodingPassBase<
+          TritonIntelGPUOptimizeBlockIOEncodingPass> {
+
+  void getSubgroup2DBlockLayoutForOperand(
+      Value operand, DpasEncodingAttr dpasLayout,
+      llvm::MapVector &layoutMap) {
+    auto isCandidateLoad = [](Value v) -> LoadOp {
+      // Peel off the original cvt to dot_op<..., #blocked>
+      // and any other potential cvt/trans ops.
+      while (true) {
+        if (auto cvtOp = v.getDefiningOp()) {
+          v = cvtOp.getSrc();
+          continue;
+        }
+        if (auto transOp = v.getDefiningOp()) {
+          v = transOp.getSrc();
+          continue;
+        }
+        break;
+      }
+      return dyn_cast(v.getDefiningOp());
+    };
+
+    LoadOp loadOp = isCandidateLoad(operand);
+    if (!loadOp)
+      return;
+
+    auto dotOperandType = cast(operand.getType());
+    auto dotOperandEncoding =
+        cast(dotOperandType.getEncoding());
+    // The layout width is determined by the DPAS operand encoding width.
+    const int kWidth = dotOperandEncoding.getKWidth();
+
+    Attribute blockIOAttr =
+        loadOp->getAttr(TritonIntelGPUDialect::getBlockIOAttrName());
+    if (!blockIOAttr)
+      return;
+
+    // Get the MakeTensorPtr Op for the load.
+    Value ptr = loadOp.getPtr();
+    if (!isTensorPointerType(ptr.getType())) {
+      // TODO: support tensor of pointer loads
+      LLVM_DEBUG(DBGS() << "Ptr\n"
+                        << ptr << " for Load Op:\n"
+                        << loadOp
+                        << "\nincompatible with Subgroup 2D Block Layout.\n");
+      return;
+    }
+    LLVM_DEBUG(DBGS() << "Retrieving tensor ptr op for ptr " << ptr << "\n");
+    MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(ptr);
+    LLVM_DEBUG(DBGS() << "Rewrite encoding for block ptr op "
+                      << makeTensorPtrOp << "\n");
+
+    auto oldTensorPtrType = cast(makeTensorPtrOp.getType());
+    auto oldTensorType =
+        cast(oldTensorPtrType.getPointeeType());
+
+    auto CTALayout = getCTALayout(dpasLayout);
+    const unsigned elemSizeInBits =
+        oldTensorType.getElementType().getIntOrFloatBitWidth();
+
+    auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout(
+        cast(dotOperandEncoding), oldTensorType.getShape(),
+        blockIOAttr == StringAttr::get(&getContext(), "row_major"),
+        elemSizeInBits / 8, &getContext());
+    SmallVector instrShape{tileParams[0], tileParams[1]};
+    const unsigned vBlocks = tileParams[2];
+
+    auto subgroup2DBlockEncoding = Subgroup2DBlockEncodingAttr::get(
+        &getContext(), dpasLayout.getWarpsPerCTA(), CTALayout, instrShape,
+        vBlocks,
+        getOrderForDotOperand(dotOperandEncoding.getOpIdx(), /*rank*/ 2,
+                              /*kContig*/ true),
+        kWidth, dpasLayout.getThreadsPerWarp());
+
+    LLVM_DEBUG(DBGS() << "Generated new encoding: " << subgroup2DBlockEncoding
+                      << " for op: " << loadOp << "\n");
+
+    layoutMap[loadOp] = subgroup2DBlockEncoding;
+  }
+
+public:
+  void runOnOperation() override {
+    ModuleOp m = getOperation();
+
+    // Step 1. Find all loads which are candidates for conversion to Subgroup
+    // 2D Block Encoding. To be a candidate load, a load must be consumed by a
+    // Dot Op and the load operand must be a block ptr (produced by a
+    // MakeTensorPtr Op). Currently we look for loads with the "block_io"
+    // attribute, but we could consider moving that logic to this pass later.
+    // We place the load and the candidate encoding into the layout map for
+    // propagation in step 2.
+    llvm::MapVector layoutMap;
+    m.walk([&](DotOp dotOp) {
+      auto dotOpType = cast(dotOp.getResult().getType());
+      auto dpasLayout = dyn_cast(dotOpType.getEncoding());
+      if (!dpasLayout)
+        return;
+
+      getSubgroup2DBlockLayoutForOperand(dotOp.getA(), dpasLayout, layoutMap);
+      getSubgroup2DBlockLayoutForOperand(dotOp.getB(), dpasLayout, layoutMap);
+    });
+
+    // Step 2. Rewrite MakeTensorPtr to use the new layout and propagate the
+    // change through the def-use chain, terminating at the Load Op. We add a
+    // ConvertLayout Op after the Load Op to convert back to the original
+    // layout. Subgroup2DBlockEncoding layouts will be chosen as anchor layouts
+    // in RemoveLayoutConversions, and a subsequent run of
+    // RemoveLayoutConversions after this pass cleans up intermediate layout
+    // conversions and removes the original Load Op encoding.
+    for (auto &kv : layoutMap) {
+      rewriteTensorLayoutsForOp(kv.second, kv.first);
+    }
+  }
+};
+
+} // namespace gpu::intel
+} // namespace triton
+} // namespace mlir
diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc
index 1aeae8f4d3..b26c757072 100644
--- a/third_party/intel/triton_xpu.cc
+++ b/third_party/intel/triton_xpu.cc
@@ -119,6 +119,9 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) {
                      gpu::intel::createTritonIntelGPUReduceDataDuplication);
   ADD_PASS_WRAPPER_0("add_materialize_block_pointer",
                      gpu::intel::createTritonIntelGPUMaterializeBlockPointer);
+  ADD_PASS_WRAPPER_0(
+      "add_optimize_block_load_encoding",
+      gpu::intel::createTritonIntelGPUOptimizeBlockIOEncodingPass);
   ADD_PASS_WRAPPER_0("add_optimize_reduction_locality",
                      gpu::intel::createTritonIntelGPUOptimizeReductionLocality);
   ADD_PASS_WRAPPER_0("add_reduce_variable_liveness",