diff --git a/test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir b/test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir index eb951c0baa..63b0a3cb7a 100644 --- a/test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir +++ b/test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir @@ -69,6 +69,204 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt // ----- +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 4], repCluster = [4, 2]}> +module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 33280 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} { + tt.func public @matmul_tensor_pointer_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 4 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: !llvm.ptr<3>) attributes {noinline = false} { + %c63_i32 = arith.constant 63 : i32 + %c255_i32 = arith.constant 255 : i32 + %c127_i32 = arith.constant 127 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %cst_1 = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %cst_4 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.muli %4, %c8_i32 : i32 + %6 = arith.divsi %0, %5 : i32 + %7 = arith.muli %6, %c8_i32 : i32 + %8 = arith.subi %2, %7 : i32 + %9 = arith.minsi %8, %c8_i32 : i32 + %12 = arith.remsi %0, %5 : i32 + %13 = arith.divsi %12, %9 : i32 + %15 = arith.muli %13, %c256_i32 : i32 + %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %24 = tt.splat %15 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %26 = arith.addi %24, %22 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>%31 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %44 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %45 = tt.expand_dims %44 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst_2 = arith.constant dense<512> : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %47 = arith.muli %45, %cst_2 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %48 = tt.expand_dims %26 {axis 
= 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<1x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %49 = tt.broadcast %47 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %50 = tt.broadcast %48 : tensor<1x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %51 = arith.addi %49, %50 : tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %52 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %53 = tt.addptr %52, %51 : tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %54 = arith.addi %arg5, %c63_i32 : i32 + %55 = arith.divsi %54, %c64_i32 : i32 + %56 = arith.muli %arg7, %c64_i32 : i32 + %57 = tt.splat %56 : i32 -> tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + cf.br ^bb1(%c0_i32, %53 : i32, tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>) + ^bb1(%58: i32, %61: tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>): // 2 preds: ^bb0, ^bb2 + %62 = arith.cmpi slt, %58, %55 : i32 + cf.cond_br %62, ^bb2, ^bb3 + ^bb2: // pred: ^bb1 + %63 = arith.muli %58, %c64_i32 : i32 + %64 = arith.subi %arg5, %63 : i32 + %69 = tt.splat %64 : i32 -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %70 = arith.cmpi slt, %45, %69 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %71 = tt.broadcast %70 : tensor<64x1xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + // CHECK-COUNT-32: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 4, v_blocks = 2, transpose = false, vnni_transform = false, cache_control = Default} + %72 = tt.load %61, %71, %cst_4 {ttig.block_io = "row_major"} : tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %75 = tt.addptr %61, %57 : tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %76 = arith.addi %58, %c1_i32 : i32 + cf.br ^bb1(%76, %75 : i32, tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>) + ^bb3: // pred: ^bb1 + tt.return + } +} + +// ----- + +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 4], repCluster = [4, 1]}> +module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 33280 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} { + tt.func public @matmul_tensor_pointer_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: !llvm.ptr<3>) attributes {noinline = 
false} { + %c63_i32 = arith.constant 63 : i32 + %c255_i32 = arith.constant 255 : i32 + %c127_i32 = arith.constant 127 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %cst_1 = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %cst_4 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.muli %4, %c8_i32 : i32 + %6 = arith.divsi %0, %5 : i32 + %7 = arith.muli %6, %c8_i32 : i32 + %8 = arith.subi %2, %7 : i32 + %9 = arith.minsi %8, %c8_i32 : i32 + %12 = arith.remsi %0, %5 : i32 + %13 = arith.divsi %12, %9 : i32 + %15 = arith.muli %13, %c256_i32 : i32 + %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %24 = tt.splat %15 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %26 = arith.addi %24, %22 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>%31 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %44 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %45 = tt.expand_dims %44 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst_2 = arith.constant dense<512> : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %47 = arith.muli %45, %cst_2 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %48 = tt.expand_dims %26 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<1x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %49 = tt.broadcast %47 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %50 = tt.broadcast %48 : tensor<1x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %51 = arith.addi %49, %50 : tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %52 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %53 = tt.addptr %52, %51 : tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %54 = arith.addi %arg5, %c63_i32 : i32 + %55 = arith.divsi %54, %c64_i32 : i32 + %56 = arith.muli %arg7, %c64_i32 : i32 + %57 = tt.splat %56 : i32 -> tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + cf.br ^bb1(%c0_i32, %53 : i32, tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>) + ^bb1(%58: i32, %61: tensor<64x256x!tt.ptr, 
#ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>): // 2 preds: ^bb0, ^bb2 + %62 = arith.cmpi slt, %58, %55 : i32 + cf.cond_br %62, ^bb2, ^bb3 + ^bb2: // pred: ^bb1 + %63 = arith.muli %58, %c64_i32 : i32 + %64 = arith.subi %arg5, %63 : i32 + %69 = tt.splat %64 : i32 -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %70 = arith.cmpi slt, %45, %69 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %71 = tt.broadcast %70 : tensor<64x1xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + // CHECK-COUNT-16: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 1, transpose = false, vnni_transform = true, cache_control = Default} + %72 = tt.load %61, %71, %cst_4 {ttig.block_io = "row_major"} : tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %75 = tt.addptr %61, %57 : tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %76 = arith.addi %58, %c1_i32 : i32 + cf.br ^bb1(%76, %75 : i32, tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>) + ^bb3: // pred: ^bb1 + tt.return + } +} + +// ----- + +#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 4], repCluster = [4, 2]}> +module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.shared = 33280 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32} { + tt.func public @matmul_tensor_pointer_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: !llvm.ptr<3>) attributes {noinline = false} { + %c63_i32 = arith.constant 63 : i32 + %c255_i32 = arith.constant 255 : i32 + %c127_i32 = arith.constant 127 : i32 + %c1_i32 = arith.constant 1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %cst_1 = arith.constant dense<0> : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %cst_4 = arith.constant dense<0.000000e+00> : tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.muli %4, %c8_i32 : i32 + %6 = arith.divsi %0, %5 : i32 + %7 = arith.muli %6, %c8_i32 : i32 + %8 = arith.subi %2, %7 : i32 + %9 = arith.minsi %8, %c8_i32 : i32 + %12 = arith.remsi %0, %5 : i32 + %13 = arith.divsi %12, %9 : i32 + %15 = arith.muli %13, %c256_i32 : i32 + %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %24 = tt.splat %15 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent 
= #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %26 = arith.addi %24, %22 : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>>%31 = tt.splat %arg4 : i32 -> tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %44 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> + %45 = tt.expand_dims %44 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %cst_2 = arith.constant dense<512> : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %47 = arith.muli %45, %cst_2 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %48 = tt.expand_dims %26 {axis = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 0, parent = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>}>> -> tensor<1x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %49 = tt.broadcast %47 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %50 = tt.broadcast %48 : tensor<1x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %51 = arith.addi %49, %50 : tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %52 = tt.splat %arg1 : !tt.ptr -> tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %53 = tt.addptr %52, %51 : tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %54 = arith.addi %arg5, %c63_i32 : i32 + %55 = arith.divsi %54, %c64_i32 : i32 + %56 = arith.muli %arg7, %c64_i32 : i32 + %57 = tt.splat %56 : i32 -> tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + cf.br ^bb1(%c0_i32, %53 : i32, tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>) + ^bb1(%58: i32, %61: tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>): // 2 preds: ^bb0, ^bb2 + %62 = arith.cmpi slt, %58, %55 : i32 + cf.cond_br %62, ^bb2, ^bb3 + ^bb2: // pred: ^bb1 + %63 = arith.muli %58, %c64_i32 : i32 + %64 = arith.subi %arg5, %63 : i32 + %69 = tt.splat %64 : i32 -> tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %70 = arith.cmpi slt, %45, %69 : tensor<64x1xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %71 = tt.broadcast %70 : tensor<64x1xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x256xi1, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + // CHECK-COUNT-8: triton_gen.2Dblockload {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2, transpose = false, vnni_transform = true, cache_control = Default} + %72 = tt.load %61, %71, %cst_4 {ttig.block_io = "row_major"} : tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %75 = tt.addptr %61, %57 : tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>, tensor<64x256xi32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %76 = arith.addi %58, %c1_i32 : i32 + cf.br ^bb1(%76, %75 : i32, tensor<64x256x!tt.ptr, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>) + ^bb3: // pred: ^bb1 + tt.return + } 
+} + +// ----- + #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [2, 2]}> #mma_1 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1]}> #mma_2 = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 1], repCluster = [4, 2]}> @@ -193,41 +391,53 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr // CHECK: %[[BLOCK_SHAPE_Y:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK-COUNT-2: llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64 + // CHECK: %[[VAL_2886:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64 + // CHECK: %[[UNIFORM_PTR:.*]] = llvm.inttoptr %[[VAL_2886]] : i64 to !llvm.ptr<1> // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[TOP_LEFT_MASK_0:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_0]] : i1 to i8 - // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_0]], %[[CST0_1]]) + // CHECK: %[[CST0_2:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[TOP_LEFT_MASK:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_0]] : i1 to i8 + // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK]], %[[CST0_2]]) // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1 - // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32 + // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32 // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2 - // CHECK-COUNT-3: llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64 + // CHECK: %[[VAL_3046:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64 + // CHECK: %[[UNIFORM_PTR:.*]] = llvm.inttoptr %[[VAL_3046]] : i64 to !llvm.ptr<1> // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[TOP_LEFT_MASK_1:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_32]] : i1 to i8 - // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_1]], %[[CST0_1]]) + // CHECK: %[[CST0_2:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[TOP_LEFT_MASK:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_32]] : i1 to i8 + // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK]], %[[CST0_2]]) // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1 - // CHECK: %[[BASE_Y_1:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32 - // CHECK: %[[LOAD_1:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_1]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2 + // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32 + // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2 - // CHECK-COUNT-3: 
llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64 + // CHECK: %[[VAL_3046:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64 + // CHECK: %[[UNIFORM_PTR:.*]] = llvm.inttoptr %[[VAL_3046]] : i64 to !llvm.ptr<1> // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[TOP_LEFT_MASK_2:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_64]] : i1 to i8 - // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_2]], %[[CST0_1]]) + // CHECK: %[[CST0_2:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[TOP_LEFT_MASK:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_64]] : i1 to i8 + // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK]], %[[CST0_2]]) // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1 - // CHECK: %[[BASE_Y_2:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32 - // CHECK: %[[LOAD_2:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_2]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2 + // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32 + // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2 - // CHECK-COUNT-3: llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64 + // CHECK: %[[VAL_3046:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64 + // CHECK: %[[UNIFORM_PTR:.*]] = llvm.inttoptr %[[VAL_3046]] : i64 to !llvm.ptr<1> // CHECK: %[[CST0_0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[CST0_1:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[TOP_LEFT_MASK_3:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_96]] : i1 to i8 - // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK_3]], %[[CST0_1]]) + // CHECK: %[[CST0_2:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[TOP_LEFT_MASK:.*]] = llvm.zext %[[TOP_LEFT_MASK_BOOL_96]] : i1 to i8 + // CHECK: %[[PRED:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[TOP_LEFT_MASK]], %[[CST0_2]]) // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1 - // CHECK: %[[BASE_Y_3:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_0]], %[[BLOCK_SHAPE_Y]] : i1, i32 - // CHECK: %[[LOAD_3:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_3]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2 + // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32 + // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2 %11 = tt.load %10, %a_mask, %a_other {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr, #mma> tt.return diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index 4fc0178dae..fd05a6be5e 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -25,6 +25,17 @@ using namespace mlir::triton::gpu::intel; #define S(v) StringAttr::get(ctx, (v)) +#if defined(_MSC_VER) && 
!defined(__clang__)
+// from https://gist.github.com/pps83/3210a2f980fd02bb2ba2e5a1fc4a2ef0
+#include <intrin.h>
+
+static int __builtin_ctz(unsigned x) {
+  unsigned long r;
+  _BitScanForward(&r, x);
+  return static_cast<int>(r);
+}
+#endif
+
 namespace {
 
 Value maybeAnd(RewriterBase &rewriter, Location loc, Value a, Value b) {
@@ -2526,6 +2537,19 @@ struct LoadOpToBlockIOConversion
     auto b = TritonLLVMOpBuilder(loc, rewriter);
     MLIRContext *ctx = rewriter.getContext();
 
+    unsigned threadsPerWarp =
+        TritonGPUDialect::getThreadsPerWarp(op->getParentOfType<ModuleOp>());
+
+    StringAttr kRegister = S("register");
+    assert(regPackedBases.has_value() &&
+           "invalid register bases for packing elems.");
+    std::vector<std::vector<int32_t>> bases(regPackedBases->size());
+    llvm::transform(*regPackedBases, bases.begin(),
+                    [&](int base) { return std::vector<int32_t>{base}; });
+    LinearLayout regMapping({{kRegister, bases}},
+                            {{kRegister, llEncoding->getInDimSize(kRegister)}},
+                            /*requireSurjective=*/true);
+
     // Get the LLVM values for pointers
     Value llPtr = adaptor.getPtr();
     SmallVector<Value> ptrElems = unpackLLElements(loc, llPtr, rewriter);
@@ -2557,14 +2581,56 @@
       }
 
       // Check the constancy of the mask support to load the memory in 2D block.
-      if (!(maskConstancyHor >= (tileWidth * numPackedVals) &&
-            maskConstancyVer >= tileHeight))
-        return failure();
+      if (!(maskConstancyHor >= (tileWidth * numPackedVals)))
+        return failure(); // The tileWidth and numPackedVals are not changeable
+                          // for now.
 
       // Adjust vBlock to fit the constancy of mask.
       vBlocks = std::min(vBlocks, mlir::ceil<unsigned>(maskConstancyHor,
                                                        tileWidth * numPackedVals));
       assert(llvm::isPowerOf2_64(vBlocks) && "vBlocks has to be power of 2");
+
+      // Check the constancy of the mask support to load the memory in 2D block.
+      if (maskConstancyVer < tileHeight) {
+        unsigned minTileHeight =
+            mlir::ceil<unsigned>(threadsPerWarp, tileWidth);
+        if (maskConstancyVer < minTileHeight)
+          return failure();
+
+        unsigned numBasesForPackedVals = __builtin_ctz(numPackedVals);
+        unsigned numBasesForTileWidth =
+            __builtin_ctz(mlir::ceil<unsigned>(tileWidth, threadsPerWarp));
+        unsigned numBasesForNewTileHeight =
+            __builtin_ctz(maskConstancyVer / minTileHeight);
+        unsigned numBasesForOldTileHeight =
+            __builtin_ctz(tileHeight / minTileHeight);
+        unsigned numBasesForVBlocks = __builtin_ctz(vBlocks);
+
+        std::vector<std::vector<int32_t>> rearrangeMap;
+        for (int i = 0; i < __builtin_ctz(numElems); ++i) {
+          rearrangeMap.emplace_back().push_back(1 << i);
+        }
+
+        // Rotate the register bases of the adjusted part of tile height to the
+        // place after the vBlocks.
+        unsigned rotateStart = numBasesForPackedVals + numBasesForTileWidth +
+                               numBasesForNewTileHeight;
+        unsigned rotateNum =
+            numBasesForOldTileHeight - numBasesForNewTileHeight;
+        std::rotate(rearrangeMap.begin() + rotateStart,
+                    rearrangeMap.begin() + rotateStart + rotateNum,
+                    rearrangeMap.begin() + rotateStart + rotateNum +
+                        numBasesForVBlocks);
+
+        LinearLayout rearrangeMapLL(
+            {{kRegister, rearrangeMap}},
+            {{kRegister, regMapping.getInDimSize(kRegister)}},
+            /*requireSurjective=*/true);
+        tileHeight = maskConstancyVer;
+        assert(((tileHeight - 1) & tileHeight) == 0 &&
+               "the tileHeight has to be power of 2.");
+        regMapping = rearrangeMapLL.compose(regMapping);
+      }
     }
 
     // Get the LLVM values for `other`
@@ -2595,8 +2661,6 @@ struct LoadOpToBlockIOConversion
       }
     }
 
-    unsigned threadsPerWarp =
-        TritonGPUDialect::getThreadsPerWarp(op->getParentOfType<ModuleOp>());
     int64_t numElemsPerLoad = mlir::ceil<int>(
         tileHeight * tileWidth * numPackedVals * vBlocks, (int)threadsPerWarp);
     unsigned numValuesPerLoad = mlir::ceil<int>((int)numElemsPerLoad, numPackedVals);
@@ -2616,20 +2680,6 @@ struct LoadOpToBlockIOConversion
     Value baseWidth = b.i32_val(
         std::max(64u, vBlocks * tileWidth * (packedElemSizeInBits / 8)));
 
-    StringAttr kRegister = str_attr("register");
-    StringAttr kLane = str_attr("lane");
-    StringAttr kWarp = str_attr("warp");
-    StringAttr kBlock = str_attr("block");
-
-    assert(regPackedBases.has_value() &&
-           "invalid register bases for packing elems.");
-    std::vector<std::vector<int32_t>> bases(regPackedBases->size());
-    llvm::transform(*regPackedBases, bases.begin(),
-                    [](int base) { return std::vector<int32_t>{base}; });
-    LinearLayout regMapping({{kRegister, bases}},
-                            {{kRegister, llEncoding->getInDimSize(kRegister)}},
-                            /*requireSurjective=*/true);
-
     bool useVNNIFormat = false;
     Type packedDPASOperandType;
     if (hasDotDpasEncoding(tensorType)) {
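
For readers following the new mask-constancy path above, the standalone sketch below replays the register-basis rotation outside the pass. It is illustrative only, not part of the patch: the 16-element register file and the concrete per-dimension basis counts are assumptions chosen for the demo, whereas the real conversion derives them from the DPAS layout and the mask constancy and then wraps the permuted bases in a `LinearLayout` composed with `regMapping`.

```cpp
// Minimal sketch (assumed sizes): rotate the identity register bases so the
// bases that used to enumerate the dropped tile-height rows land after the
// vBlocks bases, mirroring the std::rotate call in the patch.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  unsigned numElems = 16;                  // assumed registers per thread
  unsigned numBasesForPackedVals = 1;      // assumed split of the 4 bases
  unsigned numBasesForTileWidth = 0;
  unsigned numBasesForNewTileHeight = 1;   // shrunk tile height
  unsigned numBasesForOldTileHeight = 2;   // original tile height
  unsigned numBasesForVBlocks = 1;

  // Identity mapping: basis i is the single power of two 1 << i.
  std::vector<std::vector<int32_t>> rearrangeMap;
  for (unsigned i = 0; (1u << i) < numElems; ++i)
    rearrangeMap.push_back({int32_t(1) << i});

  // Move the bases covering the removed tile-height rows past vBlocks.
  unsigned rotateStart =
      numBasesForPackedVals + numBasesForTileWidth + numBasesForNewTileHeight;
  unsigned rotateNum = numBasesForOldTileHeight - numBasesForNewTileHeight;
  std::rotate(rearrangeMap.begin() + rotateStart,
              rearrangeMap.begin() + rotateStart + rotateNum,
              rearrangeMap.begin() + rotateStart + rotateNum +
                  numBasesForVBlocks);

  for (auto &basis : rearrangeMap)
    std::cout << basis[0] << ' '; // prints: 1 2 8 4 (identity was 1 2 4 8)
  std::cout << '\n';
}
```

Running it prints `1 2 8 4`: the basis that previously enumerated the extra tile-height rows now follows the vBlocks basis, which is exactly the permutation the pass encodes in `rearrangeMapLL` before composing it with `regMapping`.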