diff --git a/python/test/unit/intel/test_block_load.py b/python/test/unit/intel/test_block_load.py index 45ce9e5d8c..59e7ecf7ff 100644 --- a/python/test/unit/intel/test_block_load.py +++ b/python/test/unit/intel/test_block_load.py @@ -8,6 +8,73 @@ from triton._internal_testing import is_xpu +def test_block_load_subgroup_layout(device, tmp_path: pathlib.Path): + M = 256 + N = 32 + A_width = 1 + B_width = 2 + transpose = False + ty = "f16" + block_io = "row_major" + dtype_str = "float16" + + layouts = """ + #dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2]}> + #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [32, 16], numBlocks = 2, isTransposed = false, order = [1, 0], kWidth = 1, threadsPerWarp = 16}> + #mma1 = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [32, 16], numBlocks = 2, isTransposed = false, order = [0, 1], kWidth = 2, threadsPerWarp = 16}> + """ + + ir = layouts + f""" + module attributes {{ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32}} {{ + tt.func public @block_load_dpas_layout(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16: i32}}, %arg3: !tt.ptr<{ty}> {{tt.divisibility = 16: i32}}) attributes {{noinline = false}} {{ + %0 = tt.get_program_id x : i32 + %M_i64 = arith.constant {M} : i64 + %N_i64 = arith.constant {N} : i64 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + + // A matrix + %1 = tt.make_tensor_ptr %arg0, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%0, %c0_i32] {{order = array}} : > + %2 = tt.load %1 {{boundaryCheck = array, ttig.block_io = "row_major"}} : !tt.ptr> + %20 = ttg.convert_layout %2 : tensor<{M}x{N}x{ty}, #mma> -> tensor<{M}x{N}x{ty}, #ttg.dot_op<{{opIdx = 0, parent = #dpas, kWidth = {A_width}}}>> + %3 = tt.make_tensor_ptr %arg1, [%M_i64, %N_i64], [%N_i64, %c1_i64], [%0, %c0_i32] {{order = array}} : >> + tt.store %3, %20 {{boundaryCheck = array}} : !tt.ptr>> + + // B matrix + %4 = tt.make_tensor_ptr %arg2, [%N_i64, %M_i64], {"[%c1_i64, %N_i64]" if transpose else "[%M_i64, %c1_i64]"}, [%c0_i32, %0] {{order = array}} : > + %5 = tt.load %4 {{boundaryCheck = array, ttig.block_io = "{block_io}" }} : !tt.ptr> + %50 = ttg.convert_layout %5 : tensor<{N}x{M}x{ty}, #mma1> -> tensor<{N}x{M}x{ty}, #ttg.dot_op<{{opIdx = 1, parent = #dpas, kWidth = {B_width}}}>> + %6 = tt.make_tensor_ptr %arg3, [%N_i64, %M_i64], {"[%c1_i64, %N_i64]" if transpose else "[%M_i64, %c1_i64]"}, [%c0_i32, %0] {{order = array}} : >> + tt.store %6, %50 {{boundaryCheck = array}} : !tt.ptr>> + + tt.return + }} + }} + """ + + torch_dtype = getattr(torch, dtype_str) + if torch_dtype.is_floating_point: + a = torch.arange(0, M * N, dtype=torch_dtype, device=device).reshape((M, N)) + b = torch.arange(0, M * N, dtype=torch_dtype, device=device).reshape((N, M)) + else: + a = torch.randint(low=-127, high=128, size=(M, N), dtype=torch_dtype, device=device) + b = torch.randint(low=-127, high=128, size=(N, M), dtype=torch_dtype, device=device) + + x = torch.empty_like(a) + y = torch.empty_like(b.T if transpose else b) + + temp_file = tmp_path / "test_block_load_dpas_layout.ttgir" + temp_file.write_text(ir) + kernel = triton.compile(str(temp_file)) + + kernel[(1, 1, 1)](a, x, b, y) + + print(a.int()) + print(x.int()) + assert torch.equal(a, x) + assert torch.equal(b.T if transpose else b, y) + + @pytest.mark.parametrize("M, N", [[256, 64], [256, 32], [128, 32], [128, 16], [128, 8], [64, 64], [64, 32], [32, 32], [16, 64]]) @pytest.mark.parametrize("dtype_str", ["float32", "float16", "int8"]) diff --git a/test/TritonIntelGPU/optimize-block-io-encoding.mlir b/test/TritonIntelGPU/optimize-block-io-encoding.mlir new file mode 100644 index 0000000000..01f88d4f57 --- /dev/null +++ b/test/TritonIntelGPU/optimize-block-io-encoding.mlir @@ -0,0 +1,197 @@ +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --tritonintelgpu-optimize-block-io-encoding | FileCheck %s + +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}> +// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [32, 16], numBlocks = 2, isTransposed = false, order = [1, 0], kWidth = 1, threadsPerWarp = 16} +// CHECK: #mma1 = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [32, 16], numBlocks = 2, isTransposed = false, order = [0, 1], kWidth = 2, threadsPerWarp = 16} +// CHECK: #mma2 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} { + tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr) attributes {noinline = false} { + %c4_i32 = arith.constant 4 : i32 + %c256_i32 = arith.constant 256 : i32 + %c1024_i64 = arith.constant 1024 : i64 + %c5120_i64 = arith.constant 5120 : i64 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %c4096_i64 = arith.constant 4096 : i64 + %c32_i32 = arith.constant 32 : i32 + %c64_i32 = arith.constant 64 : i32 + %c5120_i32 = arith.constant 5120 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked> + %0 = tt.get_program_id x : i32 + %1 = arith.divsi %0, %c64_i32 : i32 + %2 = arith.muli %1, %c4_i32 : i32 + %3 = arith.subi %c4_i32, %2 : i32 + %4 = arith.minsi %3, %c4_i32 : i32 + %5 = arith.remsi %0, %4 : i32 + %6 = arith.addi %2, %5 : i32 + %7 = arith.remsi %0, %c64_i32 : i32 + %8 = arith.divsi %7, %4 : i32 + %9 = arith.muli %6, %c256_i32 : i32 + // CHECK: tt.make_tensor_ptr {{.*}} : > + %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%9, %c0_i32] {order = array} : > + %11 = arith.muli %8, %c256_i32 : i32 + // CHECK: tt.make_tensor_ptr {{.*}} : > + %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %11] {order = array} : > + %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr>) : i32 { + %17 = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: %[[A_LOAD:.*]] = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: {{.*}} = ttg.convert_layout %[[A_LOAD]] : tensor<256x32xf16, #mma> -> tensor<256x32xf16, #blocked1> + %18 = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: %[[B_LOAD:.*]] = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: {{.*}} = ttg.convert_layout %[[B_LOAD]] : tensor<32x256xf16, #mma1> -> tensor<32x256xf16, #blocked2> + %19 = ttg.convert_layout %17 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %20 = ttg.convert_layout %18 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %21 = ttg.convert_layout %arg4 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma> + %22 = ttg.convert_layout %19 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %23 = ttg.convert_layout %20 : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 2}>> -> tensor<256x256xf32, #mma2> + %24 = tt.dot %22, %23, %21, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + %25 = ttg.convert_layout %24 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked> + // CHECK: tt.advance {{.*}} : > + %26 = tt.advance %arg5, [%c0_i32, %c32_i32] : > + // CHECK: tt.advance {{.*}} : > + %27 = tt.advance %arg6, [%c32_i32, %c0_i32] : > + scf.yield %25, %26, %27 : tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr> + } + %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array} : > + %15 = arith.truncf %13#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked> + %16 = ttg.convert_layout %15 : tensor<256x256xf16, #blocked> -> tensor<256x256xf16, #blocked2> + tt.store %14, %16 {boundaryCheck = array} : !tt.ptr> + tt.return + } +} + +// ----- + +// COM: Dot Operand B transpose is supported +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 4], warpsPerCTA = [1, 32], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}> +// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks = 2, isTransposed = false, order = [1, 0], kWidth = 1, threadsPerWarp = 16}> +// CHECK: #mma1 = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [16, 8], numBlocks = 1, isTransposed = true, order = [0, 1], kWidth = 2, threadsPerWarp = 16}> +// CHECK: #mma2 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} { + tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c4_i32 = arith.constant 4 : i32 + %c256_i32 = arith.constant 256 : i32 + %c1024_i64 = arith.constant 1024 : i64 + %c5120_i64 = arith.constant 5120 : i64 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %c4096_i64 = arith.constant 4096 : i64 + %c32_i32 = arith.constant 32 : i32 + %c64_i32 = arith.constant 64 : i32 + %c5120_i32 = arith.constant 5120 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked> + %0 = tt.get_program_id x : i32 + %1 = arith.divsi %0, %c64_i32 : i32 + %2 = arith.muli %1, %c4_i32 : i32 + %3 = arith.subi %c4_i32, %2 : i32 + %4 = arith.minsi %3, %c4_i32 : i32 + %5 = arith.remsi %0, %4 : i32 + %6 = arith.addi %2, %5 : i32 + %7 = arith.remsi %0, %c64_i32 : i32 + %8 = arith.divsi %7, %4 : i32 + %9 = arith.muli %6, %c256_i32 : i32 + // CHECK: tt.make_tensor_ptr {{.*}} : > + %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%9, %c0_i32] {order = array} : > + %11 = arith.muli %8, %c256_i32 : i32 + // CHECK: tt.make_tensor_ptr {{.*}} : > + %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c1_i64, %c5120_i64], [%c0_i32, %11] {order = array} : > + %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr>) : i32 { + // CHECK: %[[A_LOAD:.*]] = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: {{.*}} = ttg.convert_layout %[[A_LOAD]] : tensor<256x32xf16, #mma> -> tensor<256x32xf16, #blocked1> + %17 = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: %[[B_LOAD:.*]] = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + // CHECK: {{.*}} = ttg.convert_layout %[[B_LOAD]] : tensor<32x256xf16, #mma1> -> tensor<32x256xf16, #blocked2> + %18 = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + %19 = ttg.convert_layout %17 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %20 = ttg.convert_layout %18 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %21 = ttg.convert_layout %arg4 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma> + %22 = ttg.convert_layout %19 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %23 = ttg.convert_layout %20 : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma2, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma2, kWidth = 2}>> -> tensor<256x256xf32, #mma2> + %24 = tt.dot %22, %23, %21, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + %25 = ttg.convert_layout %24 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked> + // CHECK: tt.advance {{.*}} : > + %26 = tt.advance %arg5, [%c0_i32, %c32_i32] : > + // CHECK: tt.advance {{.*}} : > + %27 = tt.advance %arg6, [%c32_i32, %c0_i32] : > + scf.yield %25, %26, %27 : tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr> + } + %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array} : > + %15 = arith.truncf %13#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked> + %16 = ttg.convert_layout %15 : tensor<256x256xf16, #blocked> -> tensor<256x256xf16, #blocked3> + tt.store %14, %16 {boundaryCheck = array} : !tt.ptr> + tt.return + } +} + +// ----- + +// COM: Dot operand A transpose currently not supported by subgroup 2d block io encoding +#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 1], warpsPerCTA = [2, 16], order = [0, 1]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}> +// CHECK: #mma = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [32, 16], numBlocks = 2, isTransposed = false, order = [0, 1], kWidth = 2, threadsPerWarp = 16}> +// CHECK: #mma1 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +// CHECK-NOT: #mma2 +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} { + tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c4_i32 = arith.constant 4 : i32 + %c256_i32 = arith.constant 256 : i32 + %c1024_i64 = arith.constant 1024 : i64 + %c5120_i64 = arith.constant 5120 : i64 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %c4096_i64 = arith.constant 4096 : i64 + %c32_i32 = arith.constant 32 : i32 + %c64_i32 = arith.constant 64 : i32 + %c5120_i32 = arith.constant 5120 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked> + %0 = tt.get_program_id x : i32 + %1 = arith.divsi %0, %c64_i32 : i32 + %2 = arith.muli %1, %c4_i32 : i32 + %3 = arith.subi %c4_i32, %2 : i32 + %4 = arith.minsi %3, %c4_i32 : i32 + %5 = arith.remsi %0, %4 : i32 + %6 = arith.addi %2, %5 : i32 + %7 = arith.remsi %0, %c64_i32 : i32 + %8 = arith.divsi %7, %4 : i32 + %9 = arith.muli %6, %c256_i32 : i32 + %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c1_i64, %c1024_i64], [%9, %c0_i32] {order = array} : > + %11 = arith.muli %8, %c256_i32 : i32 + %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%c0_i32, %11] {order = array} : > + %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr>) : i32 { + // CHECK: {{.*}} = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + %17 = tt.load %arg5 {boundaryCheck = array, ttig.block_io = "column_major"} : !tt.ptr> + // CHECK: %[[B_LOAD:.*]] = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: {{.*}} = ttg.convert_layout %[[B_LOAD]] : tensor<32x256xf16, #mma> -> tensor<32x256xf16, #blocked2> + %18 = tt.load %arg6 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + %19 = ttg.convert_layout %17 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> + %20 = ttg.convert_layout %18 : tensor<32x256xf16, #blocked2> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> + %21 = ttg.convert_layout %arg4 : tensor<256x256xf32, #blocked> -> tensor<256x256xf32, #mma> + %22 = ttg.convert_layout %19 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> + %23 = ttg.convert_layout %20 : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + // CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma1, kWidth = 2}>> -> tensor<256x256xf32, #mma1> + %24 = tt.dot %22, %23, %21, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + %25 = ttg.convert_layout %24 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked> + // CHECK: tt.advance {{.*}} : > + %26 = tt.advance %arg5, [%c0_i32, %c32_i32] : > + // CHECK: tt.advance {{.*}} : > + %27 = tt.advance %arg6, [%c32_i32, %c0_i32] : > + scf.yield %25, %26, %27 : tensor<256x256xf32, #blocked>, !tt.ptr>, !tt.ptr> + } + %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array} : > + %15 = arith.truncf %13#0 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked> + %16 = ttg.convert_layout %15 : tensor<256x256xf16, #blocked> -> tensor<256x256xf16, #blocked2> + tt.store %14, %16 {boundaryCheck = array} : !tt.ptr> + tt.return + } +} diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index d225494f75..2c69390029 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -280,6 +280,7 @@ def make_ttgir(mod, metadata, opt, properties): intel.passes.ttgpuir.add_accelerate_matmul(pm) intel.passes.ttgpuir.add_materialize_block_pointer(pm) + intel.passes.ttgpuir.add_optimize_block_load_encoding(pm) intel.passes.ttgpuir.add_remove_layout_conversions(pm) intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, XPUBackend.get_split_barrier_scope(opt)) diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td index 254ea42b47..bb3198c3ce 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td @@ -297,6 +297,7 @@ def Subgroup2DBlockEncodingAttr : DistributedEncoding<"Subgroup2DBlockEncoding", For the layout, the following parameters are required: - `instrShape` : contains the (height, width) block parameters for the block io operation - `numBlocks` : the block count parameter allows a single load to load multiple blocks in row-major order (useful for increasing cache line utilization) + - `isTransposed` : indicates whether the data should be transposed post-load. The `instrShape` describes the shape of the data to load pre-transpose, i.e. if this is true then the output from the instruction (load + tranpose) will be the transposed `instrShape`. - `threadsPerWarp` : currently a scalar, this parameter allows us to support different subgroup / warp configurations. Because the 2d block io operation is a subgroup operation, the size of the subgroup is important in determining the ordering of the loaded tensor. - `warpsPerCTA` : the number of warps per block / subgroups per workgroup and their distribution - `order` : The order within the block, used to determine along which dimension to broadcast. @@ -310,6 +311,7 @@ def Subgroup2DBlockEncodingAttr : DistributedEncoding<"Subgroup2DBlockEncoding", "CTALayoutAttr":$CTALayout, ArrayRefParameter<"unsigned">:$instrShape, "unsigned":$numBlocks, + "bool":$isTransposed, ArrayRefParameter<"unsigned">:$order, "unsigned":$kWidth, "unsigned":$threadsPerWarp @@ -317,7 +319,7 @@ def Subgroup2DBlockEncodingAttr : DistributedEncoding<"Subgroup2DBlockEncoding", let extraClassDeclaration = extraDistributedDeclaration # [{ SmallVector getRepOrderForOperand(int opIdx) const; - static SmallVector getInstrShapeForLayout(DistributedEncodingTrait layout, ArrayRef shape, bool memoryRowMajor, unsigned kWidth, MLIRContext* context); + static SmallVector getInstrShapeForLayout(DistributedEncodingTrait layout, ArrayRef shape, bool memoryRowMajor, bool isTransposed, unsigned kWidth, MLIRContext* context); }]; let hasCustomAssemblyFormat = 1; diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td index c20224aaee..91625ba0be 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td @@ -409,4 +409,15 @@ def TritonIntelGPUReduceVariableLiveness "mlir::scf::SCFDialect", "mlir::arith::ArithDialect"]; } + +def TritonIntelGPUOptimizeBlockIOEncodingPass : Pass<"tritonintelgpu-optimize-block-io-encoding", "mlir::ModuleOp"> { + let summary = "Set encodings on candidates for Subgroup 2D Block IO ops"; + + let description = [{ + Set the Subgroup2DBlock encoding on tensor ptr types that are candidates for Subgroup 2D Block IO lowering. The goal is to change the tensor ptr type to use the new encoding so the LoadOp will use the new encoding, allowing the encoding to be an anchor layout during RemoveLayoutConversions. To avoid duplicating work in RemoveLayoutConversions, a ConvertLayout op to the existing encoding replaces the result of the LoadOp. + }]; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", "mlir::triton::gpu::intel::TritonIntelGPUDialect", "mlir::triton::TritonDialect"]; +} + #endif // TRITON_INTEL_GPU_PASSES diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h index 9ab7baaa71..c356f07f20 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h @@ -33,6 +33,15 @@ Attribute inferSrcEncoding(Operation *op, Attribute encoding); // Retuns true if the operation is an expensive load or store operation. bool isExpensiveLoadOrStore(Operation *op); +// Returns true if the conversion between tensor types should be a no-op. Will +// be removed once layout conversion for BlockIO types is lifted from +// LoadStoreOpToLLVM.cpp +bool isBlockIONoOpConversion(RankedTensorType srcType, + RankedTensorType dstType); + +// Returns true if the tensor type has a subgroup 2d block io encoding +bool hasSubgroup2DBlockEncoding(RankedTensorType tensorType); + // Returns true if the tensor type has a dot dpas encoding. bool hasDotDpasEncoding(RankedTensorType tensorType); diff --git a/third_party/intel/lib/Analysis/Allocation.cpp b/third_party/intel/lib/Analysis/Allocation.cpp index 8c9cfe5147..e6cd3df5d2 100644 --- a/third_party/intel/lib/Analysis/Allocation.cpp +++ b/third_party/intel/lib/Analysis/Allocation.cpp @@ -1,5 +1,6 @@ #include "intel/include/Analysis/Allocation.h" #include "intel/include/Analysis/Utility.h" +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h" // isBlockIONoOpConversion #include "triton/Dialect/Triton/IR/Utility.h" #include "llvm/ADT/TypeSwitch.h" @@ -11,6 +12,9 @@ constexpr unsigned invalidSize = -1; unsigned allocationAnalysisScratchSizeFn(gpu::ConvertLayoutOp convertLayout) { RankedTensorType srcTy = convertLayout.getSrc().getType(); RankedTensorType dstTy = convertLayout.getResult().getType(); + + if (gpu::intel::isBlockIONoOpConversion(srcTy, dstTy)) + return 0; if (gpu::intel::cvtIsSubGroupShuffle(srcTy, dstTy)) return 0; if (gpu::intel::cvtIsSubGroupTranspose(srcTy, dstTy)) { diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp index 138fccf6c0..8fe6d88427 100644 --- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp +++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp @@ -59,6 +59,17 @@ static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr, return success(); } +static LogicalResult parseBoolAttrValue(AsmParser &parser, Attribute attr, + bool &value, StringRef desc) { + auto boolAttr = mlir::dyn_cast(attr); + if (!boolAttr) { + parser.emitError(parser.getNameLoc(), "expected a bool type in ") << desc; + return failure(); + } + value = boolAttr.getValue(); + return success(); +} + // parse an array of integers static LogicalResult parseIntArrayAttr(AsmParser &parser, const NamedAttribute &attr, @@ -83,6 +94,11 @@ static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr, return parseIntAttrValue(parser, attr.getValue(), value, desc); }; +static LogicalResult parseBool(AsmParser &parser, const NamedAttribute &attr, + bool &value, StringRef desc) { + return parseBoolAttrValue(parser, attr.getValue(), value, desc); +}; + //===----------------------------------------------------------------------===// // Attribute methods //===----------------------------------------------------------------------===// @@ -531,8 +547,8 @@ void maybePrintCTALayout(mlir::MLIRContext *context, mlir::AsmPrinter &printer, LogicalResult Subgroup2DBlockEncodingAttr::verify( function_ref emitError, ArrayRef warpsPerCTA, CTALayoutAttr CTALayout, - ArrayRef instrShape, unsigned numBlocks, ArrayRef order, - unsigned kWidth, unsigned threadsPerWarp) { + ArrayRef instrShape, unsigned numBlocks, bool isTransposed, + ArrayRef order, unsigned kWidth, unsigned threadsPerWarp) { if (instrShape.size() != 2) { return emitError() << "instrShape must be rank 2 but was: " << instrShape.size(); @@ -569,6 +585,7 @@ Attribute Subgroup2DBlockEncodingAttr::parse(AsmParser &parser, Type type) { std::optional> CTAOrder; SmallVector instrShape; unsigned numBlocks = 0; + bool isTransposed = false; SmallVector order; unsigned kWidth = 0; unsigned threadsPerWarp = 0; @@ -601,6 +618,10 @@ Attribute Subgroup2DBlockEncodingAttr::parse(AsmParser &parser, Type type) { if (parseUInt(parser, attr, numBlocks, "numBlocks").failed()) return {}; } + if (attr.getName() == "isTransposed") { + if (parseBool(parser, attr, isTransposed, "isTransposed").failed()) + return {}; + } if (attr.getName() == "order") { if (parseIntArrayAttr(parser, attr, order, "order").failed()) return {}; @@ -622,7 +643,7 @@ Attribute Subgroup2DBlockEncodingAttr::parse(AsmParser &parser, Type type) { return parser.getChecked( parser.getContext(), warpsPerCTA, *CTALayout, instrShape, numBlocks, - order, kWidth, threadsPerWarp); + isTransposed, order, kWidth, threadsPerWarp); } SmallVector Subgroup2DBlockEncodingAttr::getRepOrder() const { @@ -652,9 +673,10 @@ void Subgroup2DBlockEncodingAttr::print(AsmPrinter &printer) const { maybePrintCTALayout(getContext(), printer, getCTALayout(), getRank()); printer << ", instrShape = [" << getInstrShape() - << "], numBlocks=" << getNumBlocks() << ", order=[" << getOrder() - << "], kWidth=" << getKWidth() - << ", threadsPerWarp=" << getThreadsPerWarp() << "}>"; + << "], numBlocks = " << getNumBlocks() + << ", isTransposed = " << getIsTransposed() << ", order = [" + << getOrder() << "], kWidth = " << getKWidth() + << ", threadsPerWarp = " << getThreadsPerWarp() << "}>"; } LinearLayout @@ -664,7 +686,8 @@ Subgroup2DBlockEncodingAttr::toLinearLayout(ArrayRef shape) const { SmallVector Subgroup2DBlockEncodingAttr::getInstrShapeForLayout( DistributedEncodingTrait layout, ArrayRef tensorShape, - bool memoryRowMajor, unsigned kWidth, MLIRContext *context) { + bool memoryRowMajor, bool isTransposed, unsigned kWidth, + MLIRContext *context) { const auto rank = tensorShape.size(); std::optional llEncoding = layout.toLinearLayout(tensorShape); @@ -672,13 +695,6 @@ SmallVector Subgroup2DBlockEncodingAttr::getInstrShapeForLayout( LinearEncodingAttr llAttr = LinearEncodingAttr::get(context, *llEncoding); SmallVector threadOrder = llAttr.getThreadOrder(); - const bool valueRowMajor = - (threadOrder[rank - 2] == 1 && threadOrder[rank - 1] == 0); - assert((valueRowMajor || - (threadOrder[rank - 2] == 0 && threadOrder[rank - 1] == 1)) && - "Only row_major or column_major is allowed"); - const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor; - auto dotEncodingAttr = dyn_cast(layout); const unsigned opIdx = dotEncodingAttr ? dotEncodingAttr.getOpIdx() : 2; @@ -725,7 +741,7 @@ SmallVector Subgroup2DBlockEncodingAttr::getInstrShapeForLayout( unsigned dpasOperandsPerTileY = isOperandA ? numReps[2] : repCluster[dimOuter]; - if (isTransposeRequired) { + if (isTransposed) { std::swap(tileWidth, tileHeight); const unsigned threadsPerWarp = dpasLayout.getThreadsPerWarp(); @@ -738,6 +754,11 @@ SmallVector Subgroup2DBlockEncodingAttr::getInstrShapeForLayout( dpasOperandsPerTileY = 1; } + // PVC 2D load supports 32 rows at most. Load multiple dot operands in by + // enlarging the tileHeight. + dpasOperandsPerTileX = std::min(dpasOperandsPerTileX, 32 / tileHeight); + tileHeight = tileHeight * dpasOperandsPerTileX; + // PVC 2D load supports 64 bytes per row at most. Load multiple dot operands // by enlarging the number of blocks. const unsigned totalBytesPerRowPerDPASOp = tileWidth * kWidth; diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp index 64cc423629..ddf46f5f0a 100644 --- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp +++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp @@ -602,11 +602,15 @@ subgroup2DBlockToLinearLayout(ArrayRef blockShape, assert(rank == layout.getRank() && "unexpected block shape rank, layout rank " "and block shape rank must be equal"); auto dimNames = standardOutDimNames(ctx, rank); - auto loadTileSize = layout.getInstrShape(); + auto loadTileSize = SmallVector(layout.getInstrShape()); + assert(loadTileSize.size() == 2); StringAttr kRegister = S("register"); StringAttr kLane = S("lane"); StringAttr kWarp = S("warp"); + if (layout.getIsTransposed()) + std::swap(loadTileSize[0], loadTileSize[1]); + // Start by creating register/lane bases corresponding to the desired load // tile size auto [regBases, laneBases] = createRegisterLaneBases( diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 0feb202b49..5e37c47c07 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -3,6 +3,7 @@ #include "llvm/ADT/TypeSwitch.h" #include "intel/include/Analysis/Utility.h" +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h" namespace mlir::triton::gpu { namespace { @@ -24,7 +25,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion ConversionPatternRewriter &rewriter) const override { MLIRContext *ctx = op.getContext(); - auto srcTy = op.getSrc().getType(); + RankedTensorType srcTy = op.getSrc().getType(); auto dstTy = op.getType(); LinearLayout conversion = minimalCvtLayout(srcTy, dstTy); @@ -38,6 +39,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion assert(to_vector(conversion.getInDimNames()) == to_vector(conversion.getOutDimNames())); auto dims = conversion.getInDimNames(); + llvm::errs() << "dims for conversion: \n"; + for (auto dim : dims) + llvm::errs() << dim << "\n"; if (llvm::is_contained(dims, kLane)) { // If the operation is a supported sub-group shuffle, perform via shuffle // operations. diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 32be8bd222..b055530f18 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -23,6 +23,21 @@ using namespace mlir::triton::gpu::intel; namespace { +Value llPrintf(StringRef msg, ValueRange args, ArrayRef isSigned, + ConversionPatternRewriter &rewriter, + const mlir::triton::intel::TargetInfo &targetInfo) { + assert(!msg.empty() && "printf with empty string not supported"); + llvm::SmallString<64> msgNewline(msg); + msgNewline.push_back('\n'); + msgNewline.push_back('\0'); + Value msgValue = targetInfo.getGlobalStringStart( + rewriter.getUnknownLoc(), rewriter, "printfFormat_", msgNewline, + /*addressSpace=*/TritonGEN::kUniformConstant); + targetInfo.printf(rewriter, msgValue, msgNewline.size_in_bytes(), args, + isSigned); + return msgValue; +} + Value maybeAnd(RewriterBase &rewriter, Location loc, Value a, Value b) { auto tb = TritonLLVMOpBuilder(loc, rewriter); if (a && b) { @@ -302,7 +317,8 @@ struct BlockIOConversionBase : public LoadStoreConversionBase { // Only lower loadOp with dpas layout encoding. auto tensorTy = cast(op.getType()); - return hasDpasEncoding(tensorTy) || hasDotDpasEncoding(tensorTy); + return hasDpasEncoding(tensorTy) || hasDotDpasEncoding(tensorTy) || + hasSubgroup2DBlockEncoding(tensorTy); } template < @@ -342,6 +358,15 @@ struct BlockIOConversionBase : public LoadStoreConversionBase { : getDotEncoding(tensorTy).value().getParent()); } + static RankedTensorType getDpasTypeFromCVTOp(Value opResult) { + for (OpOperand user : opResult.getUsers()) { + if (auto cvt = dyn_cast(user.getOwner())) { + return cast(cvt.getResult().getType()); + } + } + llvm_unreachable("expected to find a cvt op with dpas layout"); + } + // Returns the pitch (stride in bytes) of \p ptr. Value getPitch(ConversionPatternRewriter &rewriter, Value ptr, const std::map, Value> &ptrs, @@ -1398,6 +1423,319 @@ struct LoadOpConversion oneMatrixPerLoadForBT(oneMatrixPerLoadForBT), useTileLoadLinearLayout(useTileLoadLinearLayout) {} + LogicalResult rewriteSubgroup2DBlockEncodingLoad( + triton::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value ptr = op.getPtr(); + assert(isTensorPointerType(ptr.getType()) && + "Expecting tensor pointer type"); + + Type resultType = op.getType(); + auto tensorType = cast(resultType); + assert(hasSubgroup2DBlockEncoding(tensorType) && + "load op passed to subgroup 2d block encoding load codegen must " + "have subgroup 2d block encoding"); + + LLVM_DEBUG(llvm::dbgs() + << "Lowering load op with Subgroup 2D Block Encoding: " << op + << "\n"); + + Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + Value mask = op.getMask(); + Value other = op.getOther(); + + const bool memoryRowMajor = isMemoryRowMajor(op); + + auto encoding = cast(tensorType.getEncoding()); + LinearLayout loadLayout = encoding.toLinearLayout(tensorType.getShape()); + llvm::errs() << "loadLayout: " << loadLayout << "\n"; + LinearEncodingAttr llAttr = + LinearEncodingAttr::get(rewriter.getContext(), loadLayout); + SmallVector threadOrder = llAttr.getThreadOrder(); + size_t rank = threadOrder.size(); + + if (rank != 2) { + op.emitWarning( + "Subgroup 2D Block Encoding only supports rank 2 tensors."); + return failure(); + } + + const bool isTransposeRequired = encoding.getIsTransposed(); + if (isTransposeRequired) { + // hrmm... + // loadLayout = + // loadLayout.transposeOuts(llvm::to_vector(llvm::reverse(loadLayout.getOutDimNames()))); + } + + auto instrShape = encoding.getInstrShape(); + const unsigned tileHeight = instrShape[0]; + const unsigned tileWidth = instrShape[1]; + const unsigned numBlocks = encoding.getNumBlocks(); + LLVM_DEBUG({ + llvm::dbgs() << "tileHeight = " << tileHeight << "\n"; + llvm::dbgs() << "tileWidth = " << tileWidth << "\n"; + llvm::dbgs() << "numBlocks = " << numBlocks << "\n"; + }); + + const ArrayRef tensorShape = tensorType.getShape(); + + auto warpsPerCTA = encoding.getWarpsPerCTA(); + LLVM_DEBUG({ + llvm::dbgs() << "warpsPerCTA: " << warpsPerCTA[0] << ", " + << warpsPerCTA[1] << "\n"; + }); + SmallVector dpasWarpsOrder = + getMatrixOrder(warpsPerCTA.size(), /*rowMajor*/ true); + + const unsigned threadsPerWarp = encoding.getThreadsPerWarp(); + LLVM_DEBUG(llvm::dbgs() << "threadsPerWarp = " << threadsPerWarp << "\n"); + auto warpOrder = llAttr.getWarpOrder(); + + unsigned dimOuter = warpOrder[0]; // TODO + + Value warpId = rewriter.create( + loc, i32_ty, + rewriter.create(loc, /*upperBound=*/nullptr)); + SmallVector multiDimWarpId = + delinearize(rewriter, loc, warpId, warpsPerCTA, dpasWarpsOrder); + unsigned outerDimWarpNum = + warpsPerCTA[dimOuter]; //, outerDimRequiredWarpNum); + LLVM_DEBUG(llvm::dbgs() << "outerDimWarpNum = " << outerDimWarpNum << "\n"); + Value outerDimWarpId = + b.urem(multiDimWarpId[dimOuter], b.i32_val(outerDimWarpNum)); + + Type eltTy = tensorType.getElementType(); + + LLVMTypeConverter *typeConverter = getTypeConverter(); + Type valueElemTy = typeConverter->convertType(eltTy); + LLVM_DEBUG(llvm::dbgs() << "valueElemTy: " << valueElemTy << "\n"); + + auto printVector = [](auto vector, auto name) { + llvm::errs() << name << "\n"; + for (auto i : vector) { + llvm::errs() << i << "\n"; + } + }; + + printVector(threadOrder, "threadOrder"); + + printVector(warpOrder, "warpOrder"); + + auto shapePerCTA = getShapePerCTA(tensorType); + printVector(shapePerCTA, "shapePerCTA"); + + unsigned innerBlockSize = shapePerCTA.back(); + llvm::errs() << "innerBlockSize = " << innerBlockSize << "\n"; + unsigned contigDimSize = tileWidth * numBlocks; // true? + llvm::errs() << "contigDimSize = " << contigDimSize << "\n"; + + unsigned numMessagesPerRow = ceil(innerBlockSize, contigDimSize); + llvm::errs() << "numMessagesPerRow = " << numMessagesPerRow << "\n"; + + auto ctaSplitNum = getCTASplitNum(llAttr); + printVector(ctaSplitNum, "ctaSplitNum"); + + auto ctasPerCGA = getCTAsPerCGA(llAttr); + printVector(ctasPerCGA, "ctasPerCGA"); + + // legacy, don't use! + // auto ctaShape = llAttr.getShapePerCTATile(); + // printVector(ctaShape, "shape per CTA tile"); + + MLIRContext *ctx = rewriter.getContext(); + auto kRegister = StringAttr::get(ctx, "register"); + auto kLane = StringAttr::get(ctx, "lane"); + auto kWarp = StringAttr::get(ctx, "warp"); + auto kBlock = StringAttr::get(ctx, "block"); + + auto basesPerReg = llAttr.basesPerDim(kRegister, /*skipBroadcast*/ true); + printVector(basesPerReg, "bases per reg"); + + auto basesPerRegNoBroadcast = + llAttr.basesPerDim(kRegister, /*skipBroadcast*/ false); + printVector(basesPerRegNoBroadcast, "bases per reg no broadcast"); + + auto basesPerLane = llAttr.basesPerDim(kLane); // threads per warp + printVector(basesPerLane, "bases per lane"); + + auto basesPerWarp = + llAttr.basesPerDim(kWarp, /*skipBrodcast*/ false); // warps per cta + printVector(basesPerWarp, "bases per warp"); + + auto threadsPerWarp2 = llAttr.getThreadsPerWarp(); + printVector(threadsPerWarp2, "threads per warp"); + + // auto warpsPerCTA = llAttr.getWarpsPerCTA(); + printVector(warpsPerCTA, "warps per cta"); + + auto contigPerThread = llAttr.getContigPerThread(); + printVector(contigPerThread, "contigPerThread"); + + auto contigPerWarp = llAttr.getContigPerWarp(); + printVector(contigPerWarp, "contigPerWarp"); + + // unsigned vec = getVectorSize(ptr); + // auto vecTensor = getVectorSize(tensorType); + // llvm::errs() << "vec tensor: " << vecTensor << "\n"; + unsigned elemsPerThread = getTotalElemsPerThread(ptr.getType()); + llvm::errs() << "elemsPerThread: " << elemsPerThread << "\n"; + unsigned elemsPerThreadTensorTy = getTotalElemsPerThread(tensorType); + llvm::errs() << "elemsPerThread tensor type: " << elemsPerThreadTensorTy + << "\n"; + + auto width = encoding.getKWidth(); + + unsigned vec = tileHeight * numBlocks; + llvm::errs() << "vec: " << vec << "\n"; + unsigned numValuesPerLoad = vec / width; + + unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth(); + Value elemSizeInBytes = b.i32_val(elemSizeInBits / 8); + llvm::errs() << "elemSizeInBits: " << elemSizeInBits << "\n"; + llvm::errs() << "elemSizeInBytes: " << elemSizeInBytes << "\n"; + + Type packedElemTy = valueElemTy; + if (width == 2) { + packedElemTy = i32_ty; // HACK + // elemSizeInBits = 32; + } + + Type load2DGenXType = LLVM::getVectorType(packedElemTy, numValuesPerLoad); + // Note: these end up being a mix of float/int vector types... + llvm::errs() << "load2DGenXType = " << load2DGenXType << "\n"; + + auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX, + offsetBaseY] = + getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter); + + Value pitch; + if (memoryRowMajor) { + pitch = b.trunc(i32_ty, rowStride); + } else { + // Column major memory. We need to swap the width and height because HW + // only support row major memory layout. + pitch = b.trunc(i32_ty, colStride); + std::swap(baseWidth, baseHeight); + } + baseWidth = b.trunc(i32_ty, baseWidth); + baseHeight = b.trunc(i32_ty, baseHeight); + + // Dispatch the load instructions from the perspective of a single lane. + unsigned numElems = elemsPerThreadTensorTy; + llvm::errs() << "numElems = " << numElems << "\n"; + + Value zero = b.i32_val(0); + auto baseOffsetForWarp = applyLinearLayout( + loc, rewriter, loadLayout, + {{kRegister, zero}, {kLane, zero}, {kWarp, warpId}, {kBlock, zero}}); + assert(baseOffsetForWarp.size() == 2); + + // probably always 0? just for fun look at warp 1 + auto baseOffset = + loadLayout.apply({{kRegister, 0}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}}); + assert(baseOffset.size() == 2); + LLVM_DEBUG({ + llvm::dbgs() << "base offset: " << baseOffset[0].second << ", " + << baseOffset[1].second << "\n"; + }); + + SmallVector loadedVals; + for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) { + llvm::errs() << "dispatch load " << vecStart << "\n"; + + auto offset = loadLayout.apply( + {{kRegister, vecStart}, {kLane, 0}, {kWarp, 0}, {kBlock, 0}}); + assert(offset.size() == 2); + LLVM_DEBUG({ + llvm::dbgs() << "offset: " << offset[0].second << ", " + << offset[1].second << "\n"; + llvm::dbgs() << "offset - base offset: " + << offset[0].second - baseOffset[0].second << ", " + << offset[1].second - baseOffset[1].second << "\n"; + }); + + // Lane ID doesn't matter. + Value regIdVal = b.i32_val(vecStart); + auto offsetValues = applyLinearLayout(loc, rewriter, loadLayout, + {{kRegister, regIdVal}, + {kLane, zero}, + {kWarp, warpId}, + {kBlock, zero}}); + assert(offsetValues.size() == 2); + + Value offsetX = + b.sub(offsetValues[0].second, baseOffsetForWarp[0].second); + Value offsetY = + b.sub(offsetValues[1].second, baseOffsetForWarp[1].second); + if (warpOrder[0]) { + llvm::errs() << "adding to Y offset\n"; + // b matrix + offsetY = b.add( + b.mul(outerDimWarpId, b.i32_val(/*warpOuterStride*/ 32)), offsetY); + } else { + llvm::errs() << "adding to X offset\n"; + // a matrix + offsetX = b.add( + b.mul(outerDimWarpId, b.i32_val(/*warpOuterStride*/ 32)), offsetX); + } + + std::swap(offsetX, offsetY); // TODO: remove? + + offsetX = b.add(offsetX, offsetBaseX); + offsetY = b.add(offsetY, offsetBaseY); + + auto load2dOp = rewriter.create( + loc, load2DGenXType, + /*ptr*/ base, + /*base_width*/ b.mul(baseWidth, elemSizeInBytes), + /*base_height*/ baseHeight, + /*base_pitch*/ b.mul(pitch, elemSizeInBytes), + /*x*/ b.trunc(i32_ty, offsetX), + /*y*/ b.trunc(i32_ty, offsetY), + /*elem_size_in_bits*/ elemSizeInBits, + /*tile_width*/ tileWidth, + /*tile_height*/ tileHeight, + /*v_blocks*/ numBlocks, + /*transpose*/ isTransposeRequired, + /*vnni_transform*/ width > 1 && !isTransposeRequired); + llvm::errs() << "Generated load2dOp: " << load2dOp << "\n"; + if (failed(load2dOp.verify())) { + // delete the op so that the verifier will not abort the pass + // pipeline later, as we can fail this path and try a different + // approach. + assert(false); + rewriter.eraseOp(load2dOp); + return failure(); + } + + llvm::errs() << "llvm type: " << load2DGenXType << "\n"; + Type llvmResultStructTy = typeConverter->convertType(op.getType()); + + // Extract and store return values + Value load2dVec = + b.bitcast(load2dOp, LLVM::getVectorType(valueElemTy, vec)); + llvm::errs() << "bitcasted load vec: " << load2dVec << "\n"; + llvm::errs() << "vec size: " << vec << "\n"; + for (size_t i = 0; i < vec; i++) { + Value vecIdx = createIndexAttrConstant( + rewriter, loc, typeConverter->getIndexType(), i); + Value loaded = b.extract_element(valueElemTy, load2dVec, vecIdx); + loadedVals.push_back(loaded); + } + + } // end vec + + llvm::errs() << "opType: " << op.getType() << "\n"; + Type llvmResultStructTy = typeConverter->convertType(op.getType()); + llvm::errs() << "result struct type: " << llvmResultStructTy << "\n"; + llvm::errs() << "number of load vals: " << loadedVals.size() << "\n"; + Value resultStruct = packLLElements(loc, typeConverter, loadedVals, + rewriter, llvmResultStructTy); + rewriter.replaceOp(op, {resultStruct}); + return success(); + } + LogicalResult rewriteTensorPointerLoad(triton::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -1408,28 +1746,38 @@ struct LoadOpConversion if (!isLoadCandidate(op)) return failure(); + Type resultType = op.getType(); + auto tensorType = cast(resultType); + if (hasSubgroup2DBlockEncoding(tensorType)) { + auto ret = rewriteSubgroup2DBlockEncodingLoad(op, adaptor, rewriter); + assert(!ret.failed()); + return ret; + } + Location loc = op.getLoc(); auto b = TritonLLVMOpBuilder(loc, rewriter); Value mask = op.getMask(); Value other = op.getOther(); - Type resultType = op.getType(); - auto tensorType = cast(resultType); const bool memoryRowMajor = isMemoryRowMajor(op); - DpasEncodingAttr::OpIdx opIdx = getOpIdx(tensorType); + + auto dpasTensorType = hasSubgroup2DBlockEncoding(tensorType) + ? getDpasTypeFromCVTOp(op.getResult()) + : tensorType; + DpasEncodingAttr dpasLayout = getDpasLayout(dpasTensorType); + + DpasEncodingAttr::OpIdx opIdx = getOpIdx(dpasTensorType); LLVM_DEBUG(llvm::dbgs() << "Tensor type for op " << int(opIdx) << ": " << tensorType << "\n"); - Attribute encoding = tensorType.getEncoding(); - std::optional llEncoding = - cast(encoding).toLinearLayout( - tensorType.getShape()); - assert(llEncoding.has_value() && "invalid dot layout to linear layout"); + auto encoding = cast(tensorType.getEncoding()); + LinearLayout llEncoding = encoding.toLinearLayout(tensorType.getShape()); LinearEncodingAttr llAttr = - LinearEncodingAttr::get(rewriter.getContext(), *llEncoding); + LinearEncodingAttr::get(rewriter.getContext(), llEncoding); SmallVector threadOrder = llAttr.getThreadOrder(); size_t rank = threadOrder.size(); + const bool valueRowMajor = (threadOrder[rank - 2] == 1 && threadOrder[rank - 1] == 0); assert((valueRowMajor || @@ -1440,14 +1788,27 @@ struct LoadOpConversion Type eltTy = tensorType.getElementType(); unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth(); - auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout( - cast(encoding), tensorType.getShape(), - memoryRowMajor, elemSizeInBits / 8, rewriter.getContext()); - unsigned tileHeight = tileParams[0]; - const unsigned tileWidth = tileParams[1]; - const unsigned vBlocks = tileParams[2]; + auto getTileParams = [&]() -> std::tuple { + if (hasSubgroup2DBlockEncoding(tensorType)) { + auto encoding = + cast(tensorType.getEncoding()); + auto shape = encoding.getInstrShape(); + return std::make_tuple(shape[0], shape[1], encoding.getNumBlocks()); + } else { + auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout( + cast(encoding), tensorType.getShape(), + memoryRowMajor, isTransposeRequired, elemSizeInBits / 8, + rewriter.getContext()); + return std::make_tuple(tileParams[0], tileParams[1], tileParams[2]); + } + }; + auto [tileHeight, tileWidth, vBlocks] = getTileParams(); + LLVM_DEBUG({ + llvm::dbgs() << "tileHeight = " << tileHeight << "\n"; + llvm::dbgs() << "tileWidth = " << tileWidth << "\n"; + llvm::dbgs() << "vBlocks = " << vBlocks << "\n"; + }); - DpasEncodingAttr dpasLayout = getDpasLayout(tensorType); const ArrayRef tensorShape = tensorType.getShape(); unsigned numElems = getTotalElemsPerThread(resultType); SmallVector numReps = @@ -1599,6 +1960,8 @@ struct LoadOpConversion Type unpackedDPASOperandType = LLVM::getVectorType( typeConverter->convertType(eltTy), elemsPerLanePerDPASInst); + const unsigned origTileHeight = elemsPerDPASInst[threadOrder[rank - 1]]; + // By default, use the unpacked type for the 2D load result type. Type loadResultElemType = typeConverter->convertType(eltTy); bool usePackedType = false; @@ -1657,8 +2020,7 @@ struct LoadOpConversion << outerDimRequiredWarpNum << "\n"); unsigned outerDimWarpNum = std::min(warpsPerCTA[dimOuter], outerDimRequiredWarpNum); - LLVM_DEBUG(llvm::dbgs() - << "outerDimWarpNum = " << outerDimRequiredWarpNum << "\n"); + LLVM_DEBUG(llvm::dbgs() << "outerDimWarpNum = " << outerDimWarpNum << "\n"); Value outerDimWarpId = b.urem(multiDimWarpId[dimOuter], b.i32_val(outerDimWarpNum)); @@ -1682,67 +2044,6 @@ struct LoadOpConversion LLVM_DEBUG(llvm::dbgs() << "dpasTileToPackedIndicesRatio = " << dpasTileToPackedIndicesRatio << "\n"); - // Create the linear layout for the load. - // First, we create a tile layout corresponding to a single invocation of - // the DPAS instruction across all threads/work-items in a sub-group. The - // layout will later be expanded to cover multiple DPAS invocations - // (iteration) and multiple loads (load). - StringAttr kOffset = S("offset"); - StringAttr kIteration = S("iteration"); - StringAttr kLoad = S("load"); - - auto createTileLayout = [&](const SmallVectorImpl &threadOrder, - SmallVector tileShape) { - auto outDimNames = standardOutDimNames(ctx, tensorShape.size()); - LinearLayout layout = LinearLayout::empty(); - SmallVector kOffsetDims; - unsigned totalOffsets = 1; - assert(tileShape.size() == 2); // only support 2D layouts for now - - if (isTransposeRequired && opIdx == DpasEncodingAttr::OpIdx::OperandB) { - const unsigned widthDim = threadOrder[rank - 2]; - const unsigned origTileWidth = tileShape[widthDim]; - tileShape[widthDim] = origTileWidth / (32 / elemSizeInBits); - } - - for (int i = 0; i < tileShape.size(); i++) { - int dim = threadOrder[i]; - StringAttr kOffset = S("offset" + std::to_string(dim)); - - kOffsetDims.push_back(kOffset); - - assert(llvm::isPowerOf2_32(tileShape[dim])); - // reduce the offset dimension size by the number of elements packed in - // a single slot for the row wise dimension - const unsigned offsetDimSize = - (!isTransposeRequired && dim == 0) - ? tileShape[dim] / dpasTileToPackedIndicesRatio - : tileShape[dim]; - layout *= - LinearLayout::identity1D(offsetDimSize, kOffset, outDimNames[dim]); - totalOffsets *= offsetDimSize; - } - SmallVector newDims; - newDims.append(kOffsetDims.begin(), kOffsetDims.end()); - auto ret = layout.transposeIns(newDims); - ret = ret.transposeOuts(outDimNames); - return ret.reshapeIns({{kOffset, totalOffsets}}); - }; - auto tileLayout = createTileLayout(threadOrder, elemsPerDPASInst); - - LLVM_DEBUG({ - llvm::dbgs() << "Block load tile layout: " << tileLayout << "\n"; - for (size_t i = 0; i < tileLayout.getOutDimSize(dimOuterStr) * - tileLayout.getOutDimSize(dimInnerStr); - i += tileLayout.getOutDimSize(S("dim1"))) { - auto tensorVals = tileLayout.apply({{kOffset, i}}); - assert(tensorVals.size() == 2); - llvm::dbgs() << i << " : " << tensorVals[0].second << ", " - << tensorVals[1].second << "\n"; - } - llvm::dbgs() << "tile layout done\n"; - }); - unsigned numOperandsOuterDimPerLoad = 1; unsigned numOperandsInnerDimPerLoad = 1; @@ -1777,11 +2078,9 @@ struct LoadOpConversion numOperandsPer2DloadN = 1; } - // TODO: move this logic to the instr shape computation - // PVC 2D load supports 32 rows at most. Load multiple dot operands in by - // enlarging the tileHeight. - numOperandsPer2DLoadM = std::min(numOperandsPer2DLoadM, 32 / tileHeight); - tileHeight = tileHeight * numOperandsPer2DLoadM; + numOperandsPer2DLoadM = + std::min(numOperandsPer2DLoadM, 32 / origTileHeight); + // tileHeight = tileHeight * numOperandsPer2DLoadM; // PVC 2D load supports 64 bytes per row at most. Load multiple dot operands // by enlarging the vBlocks. @@ -1802,33 +2101,6 @@ struct LoadOpConversion llvm::dbgs() << "vBlocks = " << vBlocks << "\n"; }); - tileLayout *= LinearLayout::identity1D(numOperandsOuterDimPerLoad, - kIteration, dimOuterStr); - tileLayout *= - LinearLayout::identity1D(isTransposeRequired && oneMatrixPerLoadForBT - ? 1 - : numOperandsInnerDimPerLoad, - kIteration, dimInnerStr); - - LLVM_DEBUG({ - llvm::dbgs() << "Block load tile layout after adding iterations: " - << tileLayout << "\n"; - - for (size_t itr = 0; itr < tileLayout.getInDimSize(kIteration); itr++) { - auto printTileLayoutVals = [&](const size_t offset) { - auto tensorVals = - tileLayout.apply({{kOffset, offset}, {kIteration, itr}}); - assert(tensorVals.size() == 2); - llvm::dbgs() << itr << ", " << offset << " : " << tensorVals[0].second - << ", " << tensorVals[1].second << "\n"; - }; - - printTileLayoutVals(0); - printTileLayoutVals(tileLayout.getInDimSize(kOffset) - 1); - } - llvm::dbgs() << "\n"; - }); - if (isTransposeRequired) std::swap(numOperandsOuterDimPerLoad, numOperandsInnerDimPerLoad); @@ -1865,87 +2137,6 @@ struct LoadOpConversion llvm::dbgs() << "numRepInner = " << numRepInner << "\n"; }); - // For the kLoad dimension we create the basis vector directly, which allows - // us to control the stride between loads and create a non-surjective - // layout. - auto bases = tileLayout.getBases(); - std::vector> newLoadBases; - - SmallVector> outDims; - for (auto [name, size] : - llvm::zip(tileLayout.getOutDimNames(), tileLayout.getOutDimSizes())) { - outDims.push_back(std::make_pair(name, size)); - } - assert(outDims[0].first == S("dim0")); - assert(outDims[1].first == S("dim1")); - - for (size_t i = 0; - i < llvm::Log2_32(numRepInner / numOperandsInnerDimPerLoad); i++) { - newLoadBases.push_back({0, static_cast((1 << i) * repKStride * - numOperandsInnerDimPerLoad)}); - outDims[1].second *= repKStride * numOperandsInnerDimPerLoad; - } - for (size_t i = 0; i < llvm::Log2_32(numLoadPerOutRepCluster); i++) { - newLoadBases.push_back({static_cast((1 << i) * repStride), 0}); - outDims[0].second *= repStride; - } - for (size_t i = 0; i < llvm::Log2_32(numRepOuter); i++) { - newLoadBases.push_back({static_cast((1 << i) * repOuterStride), 0}); - outDims[0].second *= repOuterStride; - } - - LLVM_DEBUG({ - llvm::dbgs() << "Created Load Bases:\n"; - for (auto &base : newLoadBases) { - assert(base.size() == 2); - llvm::dbgs() << base[0] << ", " << base[1] << "\n"; - } - }); - - LLVM_DEBUG({ - llvm::dbgs() << "New tile layout dimensions after adding load bases:\n"; - for (size_t i = 0; i < outDims.size(); i++) { - llvm::dbgs() << outDims[i].first << " = " << outDims[i].second << "\n"; - } - }); - - // Disable building the load layout if we are not going to use it. Building - // the layout manually can cause an error which would abort the pass - // pipeline and block us from getting debug info. - if (useTileLoadLinearLayout) { - // add the bases to the map and replace the tile layout with the new - // layout - bases[kLoad] = newLoadBases; - tileLayout = LinearLayout(bases, outDims, - /*requiredSurjective=*/false); - } else { - // when linear layouts are disabled generate a single load, so we can have - // some reference for linear layout output without generating a layout - // that could abort the pass pipeline - tileLayout *= LinearLayout::identity1D(1, kLoad, dimOuterStr); - } - - LLVM_DEBUG({ - llvm::dbgs() << "Block load tile layout after adding loads: " - << tileLayout << "\n"; - for (size_t load = 0; load < tileLayout.getInDimSize(kLoad); load++) { - for (size_t itr = 0; itr < tileLayout.getInDimSize(kIteration); itr++) { - auto printTileLayoutVals = [&](const size_t offset) { - auto tensorVals = tileLayout.apply( - {{kOffset, offset}, {kIteration, itr}, {kLoad, load}}); - assert(tensorVals.size() == 2); - llvm::dbgs() << load << ", " << itr << ", " << offset << " : " - << tensorVals[0].second << ", " << tensorVals[1].second - << "\n"; - }; - - printTileLayoutVals(0); - printTileLayoutVals(tileLayout.getInDimSize(kOffset) - 1); - } - llvm::dbgs() << "\n"; - } - }); - Value pitch; if (memoryRowMajor) { pitch = b.trunc(i32_ty, rowStride); @@ -1995,17 +2186,6 @@ struct LoadOpConversion k / numOperandsInnerDimPerLoad; LLVM_DEBUG(llvm::dbgs() << "loadIdx: " << loadIdx << "\n"); - const auto offset = tileLayout.apply( - {{kOffset, 0}, {kIteration, 0}, {kLoad, loadIdx}}); - assert(offset.size() == 2); - - const auto layoutOffsetX = offset[dimInner].second; - const auto layoutOffsetY = offset[dimOuter].second; - LLVM_DEBUG({ - llvm::dbgs() << "x offset ll: " << layoutOffsetX << "\n"; - llvm::dbgs() << "y offset ll: " << layoutOffsetY << "\n"; - }); - Value offsetX, offsetY; switch (opIdx) { case DpasEncodingAttr::OpIdx::OperandA: { @@ -2014,16 +2194,10 @@ struct LoadOpConversion llvm::dbgs() << "y offset: " << outer * repOuterStride + rep * repStride << "\n"; }); - if (useTileLoadLinearLayout) { - offsetY = b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)), - b.i32_val(layoutOffsetY)); - offsetX = b.i32_val(layoutOffsetX); - } else { - offsetY = - b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)), - b.i32_val(outer * repOuterStride + rep * repStride)); - offsetX = b.i32_val(k * repKStride); - } + offsetY = + b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)), + b.i32_val(outer * repOuterStride + rep * repStride)); + offsetX = b.i32_val(k * repKStride); } break; case DpasEncodingAttr::OpIdx::OperandB: { LLVM_DEBUG({ @@ -2031,16 +2205,10 @@ struct LoadOpConversion << outer * repOuterStride + rep * repStride << "\n"; llvm::dbgs() << "y offset: " << k * repKStride << "\n"; }); - if (useTileLoadLinearLayout) { - offsetX = b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)), - b.i32_val(layoutOffsetX)); - offsetY = b.i32_val(layoutOffsetY); - } else { - offsetX = - b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)), - b.i32_val(outer * repOuterStride + rep * repStride)); - offsetY = b.i32_val(k * repKStride); - } + offsetX = + b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)), + b.i32_val(outer * repOuterStride + rep * repStride)); + offsetY = b.i32_val(k * repKStride); } break; case DpasEncodingAttr::OpIdx::OperandC: { llvm_unreachable("unexpected OpIdx::OperandC"); @@ -2186,7 +2354,9 @@ struct LoadOpConversion } } - Type llvmResultStructTy = typeConverter->convertType(op.getType()); + Type llvmResultStructTy = typeConverter->convertType(dpasTensorType); + LLVM_DEBUG(llvm::dbgs() << "Packing load result in struct " + << llvmResultStructTy << "\n"); Value resultStruct = packLLElements(loc, typeConverter, unpackedLoadedVals, rewriter, llvmResultStructTy); rewriter.replaceOp(op, {resultStruct}); @@ -2209,11 +2379,22 @@ struct LoadOpConversion Value mask = op.getMask(); Value llMask = adaptor.getMask(); + auto opType = op.getType(); + // TODO: Override the OpType since conversion is still happening during Load + // lowering. Once we materialize ConvertLayoutOp this can be removed. + auto tensorTy = dyn_cast(opType); + if (tensorTy && hasSubgroup2DBlockEncoding(tensorTy)) + opType = getDpasTypeFromCVTOp(op.getResult()); + // Determine the vectorization size - Type valueElemTy = - typeConverter->convertType(getElementTypeOrSelf(op.getType())); - unsigned numElems = getTotalElemsPerThread(op.getType()); + Type valueElemTy = typeConverter->convertType(getElementTypeOrSelf(opType)); + unsigned numElems = getTotalElemsPerThread(opType); unsigned vec = getVectorSize(ptr); + LLVM_DEBUG({ + llvm::dbgs() << "Vectorization for gather load:\n"; + llvm::dbgs() << "\t" << valueElemTy << " [" << numElems << "]\n"; + llvm::dbgs() << "\tvector size = " << vec << " for " << ptr << "\n"; + }); if (llMask) vec = std::min(vec, getMaskAlignment(mask)); @@ -2223,9 +2404,11 @@ struct LoadOpConversion if (isTensorPointerType(ptr.getType())) { // fallback to gather load. - auto tensorType = cast(op.getType()); + // make sure we use the modified opType from above, "seeing through" any + // post-subgroup 2d block encoding CVT. + auto blockPtrTensorType = cast(opType); std::tie(ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr( - loc, adaptor.getPtr(), tensorType, valueElemTy, rewriter, + loc, adaptor.getPtr(), blockPtrTensorType, valueElemTy, rewriter, op.getBoundaryCheck(), op.getPadding()); } else { Value other = op.getOther(); @@ -2370,7 +2553,7 @@ struct LoadOpConversion } } // end vec - Type llvmResultStructTy = typeConverter->convertType(op.getType()); + Type llvmResultStructTy = typeConverter->convertType(opType); Value resultStruct = packLLElements(loc, typeConverter, loadedVals, rewriter, llvmResultStructTy); rewriter.replaceOp(op, {resultStruct}); diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt index b8cb96cfa0..bb32041127 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt +++ b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt @@ -5,6 +5,7 @@ add_triton_library(TritonIntelGPUTransforms DistributeToWarps.cpp MatchTargetSize.cpp MaterializeBlockPointer.cpp + OptimizeBlockIOEncoding.cpp OptimizeDotOperands.cpp OptimizeReductionLocality.cpp Pipeliner/MatmulLoopPipeline.cpp diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp new file mode 100644 index 0000000000..effd387df6 --- /dev/null +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp @@ -0,0 +1,344 @@ +#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "llvm/ADT/PriorityWorklist.h" + +namespace ttg = mlir::triton::gpu; +namespace ttgi = mlir::triton::gpu::intel; + +namespace mlir { +namespace triton { +namespace gpu::intel { + +#define DEBUG_TYPE "tritongpu-optimize-block-encoding" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace { + +SmallVector getTiedArgs(Operation *op, int resultIdx) { + if (auto forOp = dyn_cast(op)) { + auto iterArg = forOp.getRegionIterArg(resultIdx); + auto result = forOp.getResult(resultIdx); + auto yieldVal = forOp.getBody()->getTerminator()->getOperand(resultIdx); + auto initVal = forOp.getInitArgs()[resultIdx]; + return {iterArg, result, yieldVal, initVal}; + } else if (auto whileOp = dyn_cast(op)) { + auto iterArg = whileOp.getBeforeArguments()[resultIdx]; + auto result = whileOp.getResults()[resultIdx]; + auto yieldVal = + whileOp.getBeforeBody()->getTerminator()->getOperand(resultIdx); + auto initVal = whileOp.getOperands()[resultIdx]; + return {iterArg, result, iterArg, initVal}; + } else if (auto ifOp = dyn_cast(op)) { + SmallVector values; + for (auto &block : ifOp.getThenRegion().getBlocks()) { + auto terminator = block.getTerminator(); + if (isa(terminator)) + values.push_back(terminator->getOperands()[resultIdx]); + } + for (auto &block : ifOp.getElseRegion().getBlocks()) { + auto terminator = block.getTerminator(); + if (isa(terminator)) + values.push_back(terminator->getOperands()[resultIdx]); + } + values.push_back(ifOp->getResults()[resultIdx]); + return values; + } + return {}; +} + +Type getNewType(Type type, Attribute encoding) { + RankedTensorType tensorType = cast(type); + return RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); +} + +Type getNewPointerType(Type type, Attribute encoding) { + assert(isa(type) && "expected a ptr type!"); + auto oldPointerType = cast(type); + return PointerType::get(getNewType(oldPointerType.getPointeeType(), encoding), + oldPointerType.getAddressSpace()); +} + +struct EncodingInfo { + Attribute desiredEncoding; + bool requiresConvert = false; + + bool operator==(const EncodingInfo &other) const { + return desiredEncoding == other.desiredEncoding && + requiresConvert == other.requiresConvert; + } +}; + +/** + * The algorithm here takes inspiration from + * TritonNVIDIAGPU::OptimizeDescriptorEncoding. The idea is to iterate the + * def-use chain in both directions starting from the Load Op. We store the + * values that need to be updated along with the new encoding in the + * `valueToEncodingInfo` MapVector. After all value/encoding pairs have been + * determined, we update the encoding for each value, adding aa conversion to + * the existing Load Op result layout for users of the load. + */ +void rewriteTensorLayoutsForOp(Attribute encoding, Operation *op) { + auto loadOp = cast(op); + auto loadPtrType = cast(loadOp->getOperand(0).getType()); + auto addressSpace = loadPtrType.getAddressSpace(); + + llvm::MapVector, EncodingInfo> valueToEncodingInfo; + llvm::PriorityWorklist> worklist; + + auto updateEncoding = [&](ArrayRef ptrValues, EncodingInfo info) { + for (auto value : ptrValues) { + bool requiresConvert = llvm::any_of( + value.getUsers(), [](auto user) { return isa(user); }); + info.requiresConvert = requiresConvert; + + auto typedVal = cast>(value); + auto itr = valueToEncodingInfo.find(typedVal); + if (itr == valueToEncodingInfo.end()) { + LLVM_DEBUG(DBGS() << "Add encoding " << info.desiredEncoding + << " for value " << typedVal << "\n"); + valueToEncodingInfo[typedVal] = info; + worklist.insert(typedVal); + } else { + LLVM_DEBUG(DBGS() << "Found existing encoding info " + << itr->second.desiredEncoding << " for value " + << typedVal << ". Ensure new encoding " + << info.desiredEncoding << " matches.\n"); + assert(itr->second == info && "already visited encoding info for " + "value, expected them to be equal!"); + continue; + } + } + }; + + worklist.insert(cast>(loadOp->getOperand(0))); + + // 1. Starting from the Load Op, propagate encoding info up and down the + // def-use chain. + while (!worklist.empty()) { + auto crtValue = worklist.pop_back_val(); + + // Propagate to users + for (OpOperand &use : crtValue.getUses()) { + auto op = use.getOwner(); + if (isa(op)) { + auto offset = 3 * isa(op); + auto vals = getTiedArgs(op, use.getOperandNumber() - offset); + updateEncoding(vals, EncodingInfo{encoding}); + } else if (isa(op)) { + auto vals = getTiedArgs(op->getParentOp(), use.getOperandNumber()); + updateEncoding(vals, EncodingInfo{encoding}); + } + } + + // Propagate to defining ops + if (auto opResult = dyn_cast(crtValue)) { + auto definingOp = opResult.getOwner(); + if (isa(definingOp)) { + auto vals = getTiedArgs(definingOp, opResult.getResultNumber()); + updateEncoding(vals, EncodingInfo{encoding}); + } + } else if (auto blockArg = dyn_cast(crtValue)) { + auto parentOp = blockArg.getOwner()->getParentOp(); + if (isa(parentOp)) { + auto offset = isa(parentOp); + auto vals = getTiedArgs(parentOp, blockArg.getArgNumber() - offset); + updateEncoding(vals, EncodingInfo{encoding}); + } + } + } + + // 2. Update the type for each value in-place. Add a ConvertLayout Op after + // any loads which require conversion to the existing layout for the loaded + // value. + for (auto &[val, einfo] : valueToEncodingInfo) { + Attribute newEncoding = einfo.desiredEncoding; + LLVM_DEBUG(DBGS() << "Rewrite encoding to " << newEncoding << " for value " + << val << "\n"); + + PointerType oldType = val.getType(); + auto oldTensorTy = cast(oldType.getPointeeType()); + auto newTensorTy = RankedTensorType::get( + oldTensorTy.getShape(), oldTensorTy.getElementType(), newEncoding); + + val.setType(PointerType::get(newTensorTy, oldType.getAddressSpace())); + if (einfo.requiresConvert) { + for (auto user : val.getUsers()) { + if (auto loadOp = dyn_cast(user)) { + + OpBuilder builder(loadOp); + auto oldLoadType = loadOp.getType(); + Value result = loadOp.getResult(); + + builder.setInsertionPointAfter(loadOp); + auto cvt = builder.create(loadOp.getLoc(), + result.getType(), result); + LLVM_DEBUG(DBGS() << "Added convert Op:\n" + << cvt << " after Load Op:\n" + << loadOp << "\n"); + result.setType(newTensorTy); + + result.replaceAllUsesExcept(cvt.getResult(), cvt.getOperation()); + } + } + } + } +} + +} // namespace + +#define GEN_PASS_DEF_TRITONINTELGPUOPTIMIZEBLOCKIOENCODINGPASS +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc" + +class TritonIntelGPUOptimizeBlockIOEncodingPass + : public impl::TritonIntelGPUOptimizeBlockIOEncodingPassBase< + TritonIntelGPUOptimizeBlockIOEncodingPass> { + + void getSubgroup2DBlockLayoutForOperand( + Value operand, DpasEncodingAttr dpasLayout, + llvm::MapVector &layoutMap) { + auto isCandidateLoad = [](Value v) -> LoadOp { + // Peel out the original cvt dot_op<..., #blocked> + // and any other potential cvt/trans ops + while (true) { + if (auto cvtOp = v.getDefiningOp()) { + v = cvtOp.getSrc(); + continue; + } + if (auto transOp = v.getDefiningOp()) { + v = transOp.getSrc(); + continue; + } + break; + } + return isa(v.getDefiningOp()) ? cast(v.getDefiningOp()) + : nullptr; + }; + + LoadOp loadOp = isCandidateLoad(operand); + if (!loadOp) + return; + + auto dotOperandType = cast(operand.getType()); + auto layout = ttg::toLinearEncoding(dotOperandType); + auto order = layout.getThreadOrder(); + auto rank = order.size(); + if (rank != 2) { + loadOp.emitWarning( + "Subgroup 2D Block Encoding layouts only support rank 2 operands."); + return; + } + + auto dotOperandEncoding = + cast(dotOperandType.getEncoding()); + // layout width is determined by the DPAS operand encoding width + const int kWidth = dotOperandEncoding.getKWidth(); + + Attribute blockIOAttr = + loadOp->getAttr(TritonIntelGPUDialect::getBlockIOAttrName()); + if (!blockIOAttr) + return; + + const bool valueRowMajor = + getOrderForDotOperand(0, rank, /*kContig=*/true) == order; + const bool memoryRowMajor = + blockIOAttr == StringAttr::get(&getContext(), "row_major"); + const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor; + LLVM_DEBUG({ + DBGS() << "Original layout: " << dotOperandEncoding << "\n"; + DBGS() << "\tvalueRowMajor = " << valueRowMajor << "\n"; + DBGS() << "\tmemoryRowMajor = " << memoryRowMajor << "\n"; + DBGS() << "\tisTransposeRequired = " << isTransposeRequired << "\n"; + }); + if (dotOperandEncoding.getOpIdx() == 0 && isTransposeRequired) { + LLVM_DEBUG(DBGS() << "Transposed 'A' operand does not yet support " + "Subgroup 2D Block Encoding layout.\n"); + return; + } + + // get the MakeTensorPtr Op for the load + Value ptr = loadOp.getPtr(); + if (!isTensorPointerType(ptr.getType())) { + // TODO: support tensor of pointer loads + LLVM_DEBUG(DBGS() << "Ptr\n" + << ptr << " for Load Op:\n" + << loadOp + << "\nincompatible with Subgroup 2D Block Layout.\n"); + return; + } + MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(ptr); + assert(makeTensorPtrOp && + "expecting a tensor pointer parent to block io load " + "with tensor pointer type"); + + auto oldTensorPtrType = cast(makeTensorPtrOp.getType()); + auto oldTensorType = + cast(oldTensorPtrType.getPointeeType()); + // Note: we need the old layout to get the order for the load, but it is not + // clear the layout will always be Blocked. Is there a better way to get + // this info? + auto oldLayout = cast(oldTensorType.getEncoding()); + + auto CTALayout = getCTALayout(dpasLayout); + const unsigned elemSizeInBits = + oldTensorType.getElementType().getIntOrFloatBitWidth(); + + auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout( + cast(dotOperandEncoding), + oldTensorType.getShape(), memoryRowMajor, isTransposeRequired, + elemSizeInBits / 8, &getContext()); + SmallVector instrShape{tileParams[0], tileParams[1]}; + const unsigned vBlocks = tileParams[2]; + + auto subgroup2DBlockEncoding = Subgroup2DBlockEncodingAttr::get( + &getContext(), dpasLayout.getWarpsPerCTA(), CTALayout, instrShape, + tileParams[2], isTransposeRequired, + getOrderForDotOperand(dotOperandEncoding.getOpIdx(), /*rank*/ rank, + /*kContig*/ true), + kWidth, dpasLayout.getThreadsPerWarp()); + + LLVM_DEBUG(DBGS() << "Generated new encoding: " << subgroup2DBlockEncoding + << " for op : " << loadOp << "\n"); + + layoutMap[loadOp] = subgroup2DBlockEncoding; + } + +public: + void runOnOperation() override { + ModuleOp m = getOperation(); + + // Step 1. Find all loads which are candidates for conversion to Subgroup 2D + // Block Encoding. To be a candidate load, a load must be consumed by a Dot + // Op and the load operand must be a block ptr (produced by a MakeTensorPtr + // Op). Currently we look for loads with the "block_io" attribute but we + // could consider moving that logic to this pass later. We place the load + // and the candidate encoding into the layout map for propagation in step 2 + llvm::MapVector layoutMap; + m.walk([&](DotOp dotOp) { + auto dotOpType = cast(dotOp.getResult().getType()); + auto dpasLayout = dyn_cast(dotOpType.getEncoding()); + if (!dpasLayout) + return; + + getSubgroup2DBlockLayoutForOperand(dotOp.getA(), dpasLayout, layoutMap); + getSubgroup2DBlockLayoutForOperand(dotOp.getB(), dpasLayout, layoutMap); + }); + + // Step 2. Rewrite MakeTensorPtr to use the new layout and propagate the + // change through the def-use chain, terminating at the Load Op. We add a + // ConvertLayout Op after the Load Op to convert back to the original + // layout. Subgroup2DBlockEncoding layouts will be chosen as anchor layouts + // in RemoveLayoutConversions, and a subsequent run of + // RemoveLayoutConversions after this pass cleans up intermediate layout + // conversions and removes the original Load Op encoding. + for (auto &kv : layoutMap) { + rewriteTensorLayoutsForOp(kv.second, kv.first); + } + } +}; + +} // namespace gpu::intel +} // namespace triton +} // namespace mlir diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/ReduceDataDuplication.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/ReduceDataDuplication.cpp index ee218d76b1..bcc8cbc0b2 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/ReduceDataDuplication.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/ReduceDataDuplication.cpp @@ -30,6 +30,8 @@ class TritonIntelGPUReduceDataDuplicationPass auto srcEncoding = srcType.getEncoding(); if (isa(srcEncoding)) return; + if (isa(srcEncoding)) + return; auto dstDotOp = dyn_cast(dstType.getEncoding()); if (!dstDotOp) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp index 785920770a..f2aa2e6d7a 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp @@ -120,9 +120,22 @@ bool isExpensiveLoadOrStore(Operation *op) { if (isSingleValue(base)) return false; - // Loads that use a block pointer are expensive if they cannot be lowered to - // 2D block read operations. Temporarily leverage the - // "ttig.block_io" attribute to filter out inexpensive loads. + if (auto loadOp = dyn_cast(op)) { + // Subgroup2DBlockEncodingAttr loads are expensive, but loads without this + // encoding may still be expensive so we only return true if the encodng + // exists + if (auto tensorTy = dyn_cast(loadOp.getType())) + if (isa(tensorTy.getEncoding())) + return true; + } + + // The block ptr attribute identifies loads that are candidates for subgroup + // 2d block io operations. Loads with these attributes (and without the new + // subgroup 2d block encoding above) should have their layouts replaced with + // the layout from the expensive op (usually a dot op with DPAS encoding). The + // load result is convert to the expensive op layout during LLVM lowering. + // Note: the long term plan is to replace this path with the above subgroup 2d + // block encoding layout. Attribute blockIOAttr = op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName()); if (blockIOAttr) @@ -140,6 +153,18 @@ bool isExpensiveLoadOrStore(Operation *op) { return false; } +bool isBlockIONoOpConversion(RankedTensorType srcType, + RankedTensorType dstType) { + return hasSubgroup2DBlockEncoding(srcType) && hasDotDpasEncoding(dstType); +} + +bool hasSubgroup2DBlockEncoding(RankedTensorType tensorType) { + if (!tensorType.getEncoding()) + return false; + + return isa(tensorType.getEncoding()); +} + bool hasDotDpasEncoding(RankedTensorType tensorType) { if (!tensorType.getEncoding()) return false; diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc index ae485a2c7b..1abab23024 100644 --- a/third_party/intel/triton_xpu.cc +++ b/third_party/intel/triton_xpu.cc @@ -115,6 +115,9 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) { gpu::intel::createTritonIntelGPUReduceDataDuplication); ADD_PASS_WRAPPER_0("add_materialize_block_pointer", gpu::intel::createTritonIntelGPUMaterializeBlockPointer); + ADD_PASS_WRAPPER_0( + "add_optimize_block_load_encoding", + gpu::intel::createTritonIntelGPUOptimizeBlockIOEncodingPass); ADD_PASS_WRAPPER_0("add_optimize_reduction_locality", gpu::intel::createTritonIntelGPUOptimizeReductionLocality); ADD_PASS_WRAPPER_0("add_reduce_variable_liveness", diff --git a/third_party/intel/unittest/Dialect/TritonIntelGPU/LinearLayoutConversionsTest.cpp b/third_party/intel/unittest/Dialect/TritonIntelGPU/LinearLayoutConversionsTest.cpp index 2e1eee2adf..70ecacf335 100644 --- a/third_party/intel/unittest/Dialect/TritonIntelGPU/LinearLayoutConversionsTest.cpp +++ b/third_party/intel/unittest/Dialect/TritonIntelGPU/LinearLayoutConversionsTest.cpp @@ -21,9 +21,10 @@ class LinearLayoutConversionsTest : public ::testing::Test { // Create a Subgroup2DBlockEncoding layout based on a DPAS layout Subgroup2DBlockEncodingAttr - sdb(ArrayRef instrShape, unsigned numBlocks, unsigned kWidth, - ArrayRef warpsPerCTA, ArrayRef repCluster, - ArrayRef blockShape, unsigned opsPerChannel, unsigned opIdx) { + sdb(ArrayRef instrShape, unsigned numBlocks, bool isTransposed, + unsigned kWidth, ArrayRef warpsPerCTA, + ArrayRef repCluster, ArrayRef blockShape, + unsigned opsPerChannel, unsigned opIdx) { auto dpasLayout = DpasEncodingAttr::get( &ctx, /*repeatCount=*/8, /*systolicDepth=*/8, /*executionSize=*/16, opsPerChannel, warpsPerCTA, repCluster, @@ -35,7 +36,7 @@ class LinearLayoutConversionsTest : public ::testing::Test { CTALayoutAttr::get( &ctx, dpasLayout.getCTAsPerCGA(), // TODO: add to DpasLayout? dpasLayout.getCTASplitNum(), dpasLayout.getCTAOrder()), - instrShape, numBlocks, + instrShape, numBlocks, isTransposed, getOrderForDotOperand(opIdx, /*rank*/ 2, /*kContig*/ true), kWidth, dpasLayout.getThreadsPerWarp()); return layout; @@ -51,7 +52,8 @@ TEST_F(LinearLayoutConversionsTest, FP32_32x8x2_M256_N128_K32_A) { EXPECT_EQ( subgroup2DBlockToLinearLayout( /*blockShape*/ {256, 32}, - sdb(/*instrShape*/ {32, 8}, /*numBlocks*/ 2, /*kWidth*/ 4, + sdb(/*instrShape*/ {32, 8}, /*numBlocks*/ 2, /*isTransposed*/ false, + /*kWidth*/ 4, /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 1}, /*blockShape*/ {256, 32}, /*opsPerChannel*/ 1, /*opIdx*/ 0), /*kWidth*/ 4), @@ -67,7 +69,8 @@ TEST_F(LinearLayoutConversionsTest, FP32_32x16x1_M256_N128_K32_B) { EXPECT_EQ( subgroup2DBlockToLinearLayout( /*blockShape*/ {32, 128}, - sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 1, /*kWidth*/ 4, + sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 1, /*isTransposed*/ false, + /*kWidth*/ 4, /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 1}, /*blockShape*/ {32, 128}, /*opsPerChannel*/ 1, /*opIdx*/ 1), /*kWidth*/ 4), @@ -83,7 +86,8 @@ TEST_F(LinearLayoutConversionsTest, FP16_32x32x1_M256_N32_K32_A) { EXPECT_EQ( subgroup2DBlockToLinearLayout( /*blockShape*/ {256, 32}, - sdb(/*instrShape*/ {32, 32}, /*numBlocks*/ 1, /*kWidth*/ 2, + sdb(/*instrShape*/ {32, 32}, /*numBlocks*/ 1, /*isTransposed*/ false, + /*kWidth*/ 2, /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 2}, /*blockShape*/ {256, 32}, /*opsPerChannel*/ 2, /*opIdx*/ 0), /*kWidth*/ 2), @@ -99,7 +103,8 @@ TEST_F(LinearLayoutConversionsTest, FP16_32x16x2_M256_N32_K32_A) { EXPECT_EQ( subgroup2DBlockToLinearLayout( /*blockShape*/ {256, 32}, - sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 2, /*kWidth*/ 2, + sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 2, /*isTransposed*/ false, + /*kWidth*/ 2, /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 2}, /*blockShape*/ {256, 32}, /*opsPerChannel*/ 2, /*opIdx*/ 0), /*kWidth*/ 2), @@ -114,7 +119,8 @@ TEST_F(LinearLayoutConversionsTest, FP16_32x16x2_M256_N32_K32_A) { TEST_F(LinearLayoutConversionsTest, FP16_32x16x2_M256_N32_K32_B) { EXPECT_EQ(subgroup2DBlockToLinearLayout( /*shape*/ {32, 256}, - sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 2, /*kWidth*/ 2, + sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 2, + /*isTransposed*/ false, /*kWidth*/ 2, /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 2}, /*blockShape*/ {32, 256}, /*opsPerChannel*/ 2, /*opIdx*/ 1), @@ -131,7 +137,8 @@ TEST_F(LinearLayoutConversionsTest, FP16_32x16x2_M256_N32_K32_B) { TEST_F(LinearLayoutConversionsTest, FP16_16x16x2_M256_N32_K32_B) { EXPECT_EQ(subgroup2DBlockToLinearLayout( /*shape*/ {32, 256}, - sdb(/*instrShape*/ {16, 16}, /*numBlocks*/ 2, /*kWidth*/ 2, + sdb(/*instrShape*/ {16, 16}, /*numBlocks*/ 2, + /*isTransposed*/ false, /*kWidth*/ 2, /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 2}, /*blockShape*/ {32, 256}, /*opsPerChannel*/ 2, /*opIdx*/ 1), @@ -145,11 +152,32 @@ TEST_F(LinearLayoutConversionsTest, FP16_16x16x2_M256_N32_K32_B) { {S("dim0"), S("dim1")})); } +TEST_F(LinearLayoutConversionsTest, FP16_32x16x1_M256_N32_K32_TRANSPOSE_B) { + // Note that the instrShape is pre-transpose + EXPECT_EQ( + subgroup2DBlockToLinearLayout( + /*shape*/ {32, 256}, + sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 1, /*isTransposed*/ true, + /*kWidth*/ 2, + /*warpsPerCTA*/ {8, 4}, /*repCluster*/ {4, 2}, + /*blockShape*/ {256, 32}, /*opsPerChannel*/ 2, + /*opIdx*/ 1), + /*kWidth*/ 2), + LinearLayout( + {{S("register"), + {{0, 1}, {1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 128}}}, + {S("lane"), {{0, 2}, {0, 4}, {0, 8}, {0, 16}}}, + {S("warp"), {{0, 32}, {0, 64}, {0, 0}, {0, 0}, {0, 0}}}, + {S("block"), {}}}, + {S("dim0"), S("dim1")})); +} + TEST_F(LinearLayoutConversionsTest, I8_16x32x1_M64_N128_K32_A) { EXPECT_EQ( subgroup2DBlockToLinearLayout( /*shape*/ {64, 32}, - sdb(/*instrShape*/ {16, 32}, /*numBlocks*/ 1, /*kWidth*/ 1, + sdb(/*instrShape*/ {16, 32}, /*numBlocks*/ 1, /*isTransposed*/ false, + /*kWidth*/ 1, /*warpsPerCTA*/ {4, 8}, /*repCluster*/ {2, 1}, /*blockShape*/ {64, 32}, /*opsPerChannel*/ 4, /*opIdx*/ 0), @@ -165,7 +193,8 @@ TEST_F(LinearLayoutConversionsTest, I8_32x32x1_M64_N128_K32_B) { EXPECT_EQ( subgroup2DBlockToLinearLayout( /*shape*/ {32, 128}, - sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 1, /*kWidth*/ 1, + sdb(/*instrShape*/ {32, 16}, /*numBlocks*/ 1, /*isTransposed*/ false, + /*kWidth*/ 1, /*warpsPerCTA*/ {4, 8}, /*repCluster*/ {2, 1}, /*blockShape*/ {32, 128}, /*opsPerChannel*/ 4, /*opIdx*/ 1),