diff --git a/test/TritonIntelGPU/combine.mlir b/test/TritonIntelGPU/combine.mlir
index 2bc26e4ecd..91795281f5 100644
--- a/test/TritonIntelGPU/combine.mlir
+++ b/test/TritonIntelGPU/combine.mlir
@@ -3488,3 +3488,321 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
     tt.return %12, %9#1 : tensor<4x1xi64, #blocked>, i32
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
+#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK: test_5251_1
+  tt.func public @test_5251_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg3: i32) {
+    %c4_i32 = arith.constant 4 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %cst = arith.constant dense<2147483647> : tensor<1x4xi32, #blocked>
+    %cst_0 = arith.constant dense<0x7F800000> : tensor<1x4xf32, #blocked>
+    %0 = tt.splat %arg3 : i32 -> tensor<1x4xi32, #blocked>
+    %1 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1x4x!tt.ptr<f32>, #blocked>
+    %2:2 = scf.for %arg4 = %c0_i32 to %arg3 step %c4_i32 iter_args(%arg5 = %cst_0, %arg6 = %cst) -> (tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked>) : i32 {
+      // CHECK: [[RES:%.*]] = scf.for
+      // CHECK-NOT: ttg.convert_layout
+      // CHECK: scf.yield
+      %12 = tt.splat %arg4 : i32 -> tensor<1x4xi32, #blocked>
+      %13 = arith.cmpi slt, %12, %0 : tensor<1x4xi32, #blocked>
+      %14 = ttg.convert_layout %1 : tensor<1x4x!tt.ptr<f32>, #blocked> -> tensor<1x4x!tt.ptr<f32>, #blocked>
+      %15 = tt.load %14 : tensor<1x4x!tt.ptr<f32>, #blocked>
+      %16 = arith.cmpi slt, %arg6, %12 : tensor<1x4xi32, #blocked>
+      %17 = arith.select %16, %arg5, %15 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked>
+      %18 = arith.select %13, %17, %arg5 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked>
+      scf.yield %18, %arg6 : tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked>
+    }
+    // CHECK: [[RED:%.*]]:2 = "tt.reduce"([[RES]], %cst)
+    %3:2 = "tt.reduce"(%2#0, %2#1) <{axis = 1 : i32}> ({
+    ^bb0(%arg4: f32, %arg5: i32, %arg6: f32, %arg7: i32):
+      // CHECK-NOT: ttg.convert_layout
+      // CHECK tt.reduce.return
+      %12 = arith.cmpf olt, %arg4, %arg6 : f32
+      %13 = arith.select %12, %arg4, %arg6 : f32
+      %14 = arith.select %12, %arg5, %arg7 : i32
+      tt.reduce.return %13, %14 : f32, i32
+    }) : (tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked>) -> (tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>>)
+    // CHECK-NOT: ttg.convert_layout
+    // CHECK: tt.expand_dims [[RED]]#1
+    %4 = ttg.convert_layout %3#1 : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xi32, #blocked1>
+    %5 = ttg.convert_layout %4 : tensor<1xi32, #blocked1> -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
+    %6 = tt.expand_dims %5 {axis = 1 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<1x1xi32, #blocked2>
+    %7 = ttg.convert_layout %6 : tensor<1x1xi32, #blocked2> -> tensor<1x1xi32, #blocked3>
+    %8 = tt.splat %arg2 : !tt.ptr<i64> -> tensor<1x1x!tt.ptr<i64>, #blocked3>
+    %9 = arith.extsi %7 : tensor<1x1xi32, #blocked3> to tensor<1x1xi64, #blocked3>
+    %10 = ttg.convert_layout %8 :
tensor<1x1x!tt.ptr, #blocked3> -> tensor<1x1x!tt.ptr, #blocked3> + %11 = ttg.convert_layout %9 : tensor<1x1xi64, #blocked3> -> tensor<1x1xi64, #blocked3> + tt.store %10, %11 : tensor<1x1x!tt.ptr, #blocked3> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 32 : i32} { + // CHECK: test_5251_2 + tt.func public @test_5251_2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32) { + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<2147483647> : tensor<1x4xi32, #blocked> + %cst_0 = arith.constant dense<0x7F800000> : tensor<1x4xf32, #blocked> + %0 = tt.splat %arg2 : i32 -> tensor<1x4xi32, #blocked> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<1x4x!tt.ptr, #blocked> + %2:2 = scf.for %arg3 = %c0_i32 to %arg2 step %c4_i32 iter_args(%arg4 = %cst_0, %arg5 = %cst) -> (tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked>) : i32 { + // CHECK: [[RES:%.*]] = scf.for + // CHECK-NOT: ttg.convert_layout + // CHECK: scf.yield + %14 = tt.splat %arg3 : i32 -> tensor<1x4xi32, #blocked> + %15 = arith.cmpi slt, %14, %0 : tensor<1x4xi32, #blocked> + %16 = ttg.convert_layout %1 : tensor<1x4x!tt.ptr, #blocked> -> tensor<1x4x!tt.ptr, #blocked> + %17 = ttg.convert_layout %16 : tensor<1x4x!tt.ptr, #blocked> -> tensor<1x4x!tt.ptr, #blocked> + %18 = tt.load %17 : tensor<1x4x!tt.ptr, #blocked> + %19 = arith.cmpi slt, %arg5, %14 : tensor<1x4xi32, #blocked> + %20 = arith.select %19, %arg4, %18 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + %21 = arith.select %15, %20, %arg4 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + scf.yield %21, %arg5 : tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked> + } + // CHECK: [[RED:%.*]]:2 = "tt.reduce"([[RES]], %cst) + %3:2 = "tt.reduce"(%2#0, %2#1) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32, %arg4: i32, %arg5: f32, %arg6: i32): + // CHECK-NOT: ttg.convert_layout + // CHECK tt.reduce.return + %14 = arith.cmpf olt, %arg3, %arg5 : f32 + %15 = arith.select %14, %arg3, %arg5 : f32 + %16 = arith.select %14, %arg4, %arg6 : i32 + tt.reduce.return %15, %16 : f32, i32 + }) : (tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked>) -> (tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>>) + // CHECK-NOT: ttg.convert_layout + // CHECK: tt.expand_dims [[RED]]#1 + %4 = ttg.convert_layout %3#1 : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xi32, #blocked1> + %5 = ttg.convert_layout %4 : tensor<1xi32, #blocked1> -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> + %6 = tt.expand_dims %5 {axis = 1 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<1x1xi32, #blocked2> + %7 = ttg.convert_layout %6 : tensor<1x1xi32, #blocked2> -> tensor<1x1xi32, #blocked3> + %8 = tt.splat %arg1 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked3> + %9 = arith.extsi %7 : tensor<1x1xi32, #blocked3> to tensor<1x1xi64, #blocked3> + %10 = ttg.convert_layout %8 : 
tensor<1x1x!tt.ptr, #blocked3> -> tensor<1x1x!tt.ptr, #blocked3> + %11 = ttg.convert_layout %9 : tensor<1x1xi64, #blocked3> -> tensor<1x1xi64, #blocked3> + %12 = ttg.convert_layout %10 : tensor<1x1x!tt.ptr, #blocked3> -> tensor<1x1x!tt.ptr, #blocked3> + %13 = ttg.convert_layout %11 : tensor<1x1xi64, #blocked3> -> tensor<1x1xi64, #blocked3> + tt.store %12, %13 : tensor<1x1x!tt.ptr, #blocked3> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [0, 1]}> +#blocked3 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blocked4 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block} { + // CHECK: test_5251_3 + tt.func public @test_5251_3(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: i32) attributes {noinline = false} { + // CHECK-NOT: ttg.convert_layout + %true = arith.constant true + %cst = arith.constant dense : tensor<1x4xi1, #blocked> + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_0 = arith.constant dense<2147483647> : tensor<1x4xi32, #blocked> + %cst_1 = arith.constant dense<0x7F800000> : tensor<1x4xf32, #blocked> + %cst_2 = arith.constant dense<0xFF800000> : tensor<1x4xf32, #blocked> + %cst_3 = arith.constant dense<0.000000e+00> : tensor<1x4xf32, #blocked> + %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked1> + %1 = ttg.convert_layout %0 : tensor<4xi32, #blocked1> -> tensor<4xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> + %2 = tt.expand_dims %1 {axis = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x4xi32, #blocked2> + %3 = ttg.convert_layout %2 : tensor<1x4xi32, #blocked2> -> tensor<1x4xi32, #blocked> + %4 = tt.splat %arg6 : i32 -> tensor<1x4xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<1x4x!tt.ptr, #blocked> + %6:7 = scf.for %arg7 = %c0_i32 to %arg6 step %c4_i32 iter_args(%arg8 = %cst_3, %arg9 = %cst_2, %arg10 = %cst_1, %arg11 = %cst_2, %arg12 = %cst_0, %arg13 = %cst_1, %arg14 = %cst_0) -> (tensor<1x4xf32, #blocked>, tensor<1x4xf32, #blocked>, tensor<1x4xf32, #blocked>, tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked>, tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked>) : i32 { + // CHECK: [[RES:%.*]]:7 = scf.for + // CHECK-NOT: ttg.convert_layout + // CHECK: scf.yield + %49 = tt.splat %arg7 : i32 -> tensor<1x4xi32, #blocked> + %50 = arith.addi %49, %3 : tensor<1x4xi32, #blocked> + %51 = arith.cmpi slt, %50, %4 : tensor<1x4xi32, #blocked> + %52 = tt.addptr %5, %50 : tensor<1x4x!tt.ptr, #blocked>, tensor<1x4xi32, #blocked> + %53 = ttg.convert_layout %52 : tensor<1x4x!tt.ptr, #blocked> -> tensor<1x4x!tt.ptr, #blocked> + %54 = ttg.convert_layout %51 : tensor<1x4xi1, #blocked> -> tensor<1x4xi1, #blocked> + %55 = ttg.convert_layout %cst_3 : 
tensor<1x4xf32, #blocked> -> tensor<1x4xf32, #blocked> + %56 = tt.load %53, %54, %55 evictionPolicy = evict_first : tensor<1x4x!tt.ptr, #blocked> + %57 = arith.addf %arg8, %56 : tensor<1x4xf32, #blocked> + %58 = arith.select %51, %57, %arg8 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + %59 = arith.cmpf ogt, %arg9, %56 : tensor<1x4xf32, #blocked> + %60 = arith.cmpf une, %arg9, %arg9 : tensor<1x4xf32, #blocked> + %61 = arith.ori %59, %60 : tensor<1x4xi1, #blocked> + %62 = arith.select %61, %arg9, %56 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + %63 = arith.select %51, %62, %arg9 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + %64 = arith.cmpf olt, %arg10, %56 : tensor<1x4xf32, #blocked> + %65 = arith.cmpf une, %arg10, %arg10 : tensor<1x4xf32, #blocked> + %66 = arith.ori %64, %65 : tensor<1x4xi1, #blocked> + %67 = arith.select %66, %arg10, %56 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + %68 = arith.select %51, %67, %arg10 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + %69 = arith.cmpf ogt, %arg11, %56 : tensor<1x4xf32, #blocked> + %70 = arith.cmpf oeq, %arg11, %56 : tensor<1x4xf32, #blocked> + %71 = arith.cmpf une, %arg11, %arg11 : tensor<1x4xf32, #blocked> + %72 = arith.cmpf une, %56, %56 : tensor<1x4xf32, #blocked> + %73 = arith.xori %72, %cst : tensor<1x4xi1, #blocked> + %74 = arith.andi %71, %73 : tensor<1x4xi1, #blocked> + %75 = arith.ori %69, %74 : tensor<1x4xi1, #blocked> + %76 = arith.andi %71, %72 : tensor<1x4xi1, #blocked> + %77 = arith.ori %70, %76 : tensor<1x4xi1, #blocked> + %78 = arith.cmpi slt, %arg12, %50 : tensor<1x4xi32, #blocked> + %79 = arith.andi %77, %78 : tensor<1x4xi1, #blocked> + %80 = arith.ori %75, %79 : tensor<1x4xi1, #blocked> + %81 = arith.select %80, %arg11, %56 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + %82 = arith.select %80, %arg12, %50 : tensor<1x4xi1, #blocked>, tensor<1x4xi32, #blocked> + %83 = arith.select %51, %81, %arg11 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + %84 = arith.select %51, %82, %arg12 : tensor<1x4xi1, #blocked>, tensor<1x4xi32, #blocked> + %85 = arith.cmpf olt, %arg13, %56 : tensor<1x4xf32, #blocked> + %86 = arith.cmpf oeq, %arg13, %56 : tensor<1x4xf32, #blocked> + %87 = arith.cmpf une, %arg13, %arg13 : tensor<1x4xf32, #blocked> + %88 = arith.andi %87, %73 : tensor<1x4xi1, #blocked> + %89 = arith.ori %85, %88 : tensor<1x4xi1, #blocked> + %90 = arith.andi %87, %72 : tensor<1x4xi1, #blocked> + %91 = arith.ori %86, %90 : tensor<1x4xi1, #blocked> + %92 = arith.cmpi slt, %arg14, %50 : tensor<1x4xi32, #blocked> + %93 = arith.andi %91, %92 : tensor<1x4xi1, #blocked> + %94 = arith.ori %89, %93 : tensor<1x4xi1, #blocked> + %95 = arith.select %94, %arg13, %56 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + %96 = arith.select %94, %arg14, %50 : tensor<1x4xi1, #blocked>, tensor<1x4xi32, #blocked> + %97 = arith.select %51, %95, %arg13 : tensor<1x4xi1, #blocked>, tensor<1x4xf32, #blocked> + %98 = arith.select %51, %96, %arg14 : tensor<1x4xi1, #blocked>, tensor<1x4xi32, #blocked> + scf.yield %58, %63, %68, %83, %84, %97, %98 : tensor<1x4xf32, #blocked>, tensor<1x4xf32, #blocked>, tensor<1x4xf32, #blocked>, tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked>, tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked> + } + // CHECK: [[RED0:%.*]] = "tt.reduce"([[RES]]#0) + %7 = "tt.reduce"(%6#0) <{axis = 1 : i32}> ({ + ^bb0(%arg7: f32, %arg8: f32): + // CHECK-NOT: ttg.convert_layout + // CHECK tt.reduce.return + %49 = arith.addf %arg7, %arg8 : f32 + tt.reduce.return %49 : f32 + 
}) : (tensor<1x4xf32, #blocked>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + // CHECK-NOT: ttg.convert_layout + // CHECK: tt.expand_dims [[RED0]] + %8 = ttg.convert_layout %7 : tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xf32, #blocked1> + %9 = ttg.convert_layout %8 : tensor<1xf32, #blocked1> -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<1x1xf32, #blocked3> + %11 = ttg.convert_layout %10 : tensor<1x1xf32, #blocked3> -> tensor<1x1xf32, #blocked4> + // CHECK: [[RED1:%.*]] = "tt.reduce"([[RES]]#1) + %12 = "tt.reduce"(%6#1) <{axis = 1 : i32}> ({ + ^bb0(%arg7: f32, %arg8: f32): + // CHECK-NOT: ttg.convert_layout + // CHECK tt.reduce.return + %49 = arith.cmpf ogt, %arg7, %arg8 : f32 + %50 = arith.cmpf une, %arg7, %arg7 : f32 + %51 = arith.ori %49, %50 : i1 + %52 = arith.select %51, %arg7, %arg8 : f32 + tt.reduce.return %52 : f32 + }) : (tensor<1x4xf32, #blocked>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + // CHECK-NOT: ttg.convert_layout + // CHECK: tt.expand_dims [[RED1]] + %13 = ttg.convert_layout %12 : tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xf32, #blocked1> + %14 = ttg.convert_layout %13 : tensor<1xf32, #blocked1> -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %15 = tt.expand_dims %14 {axis = 1 : i32} : tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<1x1xf32, #blocked3> + %16 = ttg.convert_layout %15 : tensor<1x1xf32, #blocked3> -> tensor<1x1xf32, #blocked4> + // CHECK: [[RED2:%.*]] = "tt.reduce"([[RES]]#2) + %17 = "tt.reduce"(%6#2) <{axis = 1 : i32}> ({ + ^bb0(%arg7: f32, %arg8: f32): + // CHECK-NOT: ttg.convert_layout + // CHECK tt.reduce.return + %49 = arith.cmpf olt, %arg7, %arg8 : f32 + %50 = arith.cmpf une, %arg7, %arg7 : f32 + %51 = arith.ori %49, %50 : i1 + %52 = arith.select %51, %arg7, %arg8 : f32 + tt.reduce.return %52 : f32 + }) : (tensor<1x4xf32, #blocked>) -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> + // CHECK-NOT: ttg.convert_layout + // CHECK: tt.expand_dims [[RED2]] + %18 = ttg.convert_layout %17 : tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xf32, #blocked1> + %19 = ttg.convert_layout %18 : tensor<1xf32, #blocked1> -> tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %20 = tt.expand_dims %19 {axis = 1 : i32} : tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<1x1xf32, #blocked3> + %21 = ttg.convert_layout %20 : tensor<1x1xf32, #blocked3> -> tensor<1x1xf32, #blocked4> + // CHECK: [[RED3:%.*]]:2 = "tt.reduce"([[RES]]#3, [[RES]]#4) + %22:2 = "tt.reduce"(%6#3, %6#4) <{axis = 1 : i32}> ({ + ^bb0(%arg7: f32, %arg8: i32, %arg9: f32, %arg10: i32): + // CHECK-NOT: ttg.convert_layout + // CHECK tt.reduce.return + %49 = arith.cmpf ogt, %arg7, %arg9 : f32 + %50 = arith.cmpf oeq, %arg7, %arg9 : f32 + %51 = arith.cmpf une, %arg7, %arg7 : f32 + %52 = arith.cmpf une, %arg9, %arg9 : f32 + %53 = arith.xori %52, %true : i1 + %54 = arith.andi %51, %53 : i1 + %55 = arith.ori %49, %54 : i1 + %56 = arith.andi %51, %52 : i1 + %57 = arith.ori %50, %56 : i1 + %58 = arith.cmpi slt, %arg8, %arg10 : i32 + %59 = arith.andi %57, %58 : i1 + %60 = arith.ori %55, %59 : i1 + %61 = arith.select %60, %arg7, %arg9 : f32 + %62 = arith.select %60, %arg8, %arg10 : i32 + tt.reduce.return %61, %62 : f32, i32 + }) : (tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked>) -> (tensor<1xf32, #ttg.slice<{dim = 1, parent 
= #blocked}>>, tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>>) + // CHECK-NOT: ttg.convert_layout + // CHECK: tt.expand_dims [[RED3]]#1 + %23 = ttg.convert_layout %22#1 : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xi32, #blocked1> + %24 = ttg.convert_layout %23 : tensor<1xi32, #blocked1> -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %25 = tt.expand_dims %24 {axis = 1 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<1x1xi32, #blocked3> + %26 = ttg.convert_layout %25 : tensor<1x1xi32, #blocked3> -> tensor<1x1xi32, #blocked4> + // CHECK: [[RED4:%.*]]:2 = "tt.reduce"([[RES]]#5, [[RES]]#6) + %27:2 = "tt.reduce"(%6#5, %6#6) <{axis = 1 : i32}> ({ + ^bb0(%arg7: f32, %arg8: i32, %arg9: f32, %arg10: i32): + // CHECK-NOT: ttg.convert_layout + // CHECK tt.reduce.return + %49 = arith.cmpf olt, %arg7, %arg9 : f32 + %50 = arith.cmpf oeq, %arg7, %arg9 : f32 + %51 = arith.cmpf une, %arg7, %arg7 : f32 + %52 = arith.cmpf une, %arg9, %arg9 : f32 + %53 = arith.xori %52, %true : i1 + %54 = arith.andi %51, %53 : i1 + %55 = arith.ori %49, %54 : i1 + %56 = arith.andi %51, %52 : i1 + %57 = arith.ori %50, %56 : i1 + %58 = arith.cmpi slt, %arg8, %arg10 : i32 + %59 = arith.andi %57, %58 : i1 + %60 = arith.ori %55, %59 : i1 + %61 = arith.select %60, %arg7, %arg9 : f32 + %62 = arith.select %60, %arg8, %arg10 : i32 + tt.reduce.return %61, %62 : f32, i32 + }) : (tensor<1x4xf32, #blocked>, tensor<1x4xi32, #blocked>) -> (tensor<1xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>>) + // CHECK-NOT: ttg.convert_layout + // CHECK: tt.expand_dims [[RED4]]#1 + %28 = ttg.convert_layout %27#1 : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<1xi32, #blocked1> + %29 = ttg.convert_layout %28 : tensor<1xi32, #blocked1> -> tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> + %30 = tt.expand_dims %29 {axis = 1 : i32} : tensor<1xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> -> tensor<1x1xi32, #blocked3> + %31 = ttg.convert_layout %30 : tensor<1x1xi32, #blocked3> -> tensor<1x1xi32, #blocked4> + %32 = tt.splat %arg1 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked4> + %33 = ttg.convert_layout %32 : tensor<1x1x!tt.ptr, #blocked4> -> tensor<1x1x!tt.ptr, #blocked4> + %34 = ttg.convert_layout %11 : tensor<1x1xf32, #blocked4> -> tensor<1x1xf32, #blocked4> + tt.store %33, %34 : tensor<1x1x!tt.ptr, #blocked4> + %35 = tt.splat %arg2 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked4> + %36 = ttg.convert_layout %35 : tensor<1x1x!tt.ptr, #blocked4> -> tensor<1x1x!tt.ptr, #blocked4> + %37 = ttg.convert_layout %16 : tensor<1x1xf32, #blocked4> -> tensor<1x1xf32, #blocked4> + tt.store %36, %37 : tensor<1x1x!tt.ptr, #blocked4> + %38 = tt.splat %arg3 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked4> + %39 = ttg.convert_layout %38 : tensor<1x1x!tt.ptr, #blocked4> -> tensor<1x1x!tt.ptr, #blocked4> + %40 = ttg.convert_layout %21 : tensor<1x1xf32, #blocked4> -> tensor<1x1xf32, #blocked4> + tt.store %39, %40 : tensor<1x1x!tt.ptr, #blocked4> + %41 = tt.splat %arg4 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked4> + %42 = arith.extsi %26 : tensor<1x1xi32, #blocked4> to tensor<1x1xi64, #blocked4> + %43 = ttg.convert_layout %41 : tensor<1x1x!tt.ptr, #blocked4> -> tensor<1x1x!tt.ptr, #blocked4> + %44 = ttg.convert_layout %42 : tensor<1x1xi64, #blocked4> -> tensor<1x1xi64, #blocked4> + tt.store %43, %44 : tensor<1x1x!tt.ptr, #blocked4> + %45 = tt.splat %arg5 : !tt.ptr -> tensor<1x1x!tt.ptr, #blocked4> + %46 = arith.extsi %31 : 
tensor<1x1xi32, #blocked4> to tensor<1x1xi64, #blocked4>
+    %47 = ttg.convert_layout %45 : tensor<1x1x!tt.ptr<i64>, #blocked4> -> tensor<1x1x!tt.ptr<i64>, #blocked4>
+    %48 = ttg.convert_layout %46 : tensor<1x1xi64, #blocked4> -> tensor<1x1xi64, #blocked4>
+    tt.store %47, %48 : tensor<1x1x!tt.ptr<i64>, #blocked4>
+    tt.return
+  }
+}
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
index 89cbbdc462..bb7a7afb33 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp
@@ -184,7 +184,9 @@ class LayoutRematerialization {
 
 void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
                                             Value newV) {
-  LDBG("addRematValue " << old << " encoding " << encoding << " " << newV);
+  LDBG("addRematValue for: " << old);
+  LDBG("  encoding: " << encoding);
+  LDBG("  new: " << newV);
   rematMapping[{old, encoding}] = newV;
   mappedValues[old] = encoding;
 }
@@ -1133,9 +1135,23 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
       }
       if (enableForLoopSupport)
         if (auto forOp = v.getDefiningOp<scf::ForOp>()) {
-          unsigned operandIdx = cast<OpResult>(v).getResultNumber();
+          unsigned opIdx = cast<OpResult>(v).getResultNumber();
           auto yieldOp = forOp.getBody()->getTerminator();
-          yieldOperandsMap[yieldOp].push_back(operandIdx);
+          if (!yieldOperandsMap.contains(yieldOp)) {
+            yieldOperandsMap[yieldOp].push_back(opIdx);
+            LLVM_DEBUG({
+              llvm::errs() << "1a. pushing " << opIdx
+                           << " in yieldOperandMap\n";
+            });
+          } else if (llvm::none_of(
+                         yieldOperandsMap[yieldOp],
+                         [&](unsigned idx) { return idx == opIdx; })) {
+            yieldOperandsMap[yieldOp].push_back(opIdx);
+            LLVM_DEBUG({
+              llvm::errs() << "1b. pushing " << opIdx
+                           << " in yieldOperandMap\n";
+            });
+          }
           opsToRewrite.insert(yieldOp);
         }
     } else {
@@ -1144,8 +1160,22 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
       if (auto loopOp = cast<LoopLikeOpInterface>(parentOp)) {
         opsToRewrite.insert(loopOp.getOperation());
         OpOperand *operand = loopOp.getTiedLoopYieldedValue(blockArg);
+        unsigned opIdx = operand->getOperandNumber();
         auto yieldOp = blockArg.getOwner()->getTerminator();
-        yieldOperandsMap[yieldOp].push_back(operand->getOperandNumber());
+        if (!yieldOperandsMap.contains(yieldOp)) {
+          yieldOperandsMap[yieldOp].push_back(opIdx);
+          LLVM_DEBUG({
+            llvm::errs() << "2a. pushing " << operand->getOperandNumber()
+                         << " in yieldOperandMap\n";
+          });
+        } else if (llvm::none_of(yieldOperandsMap[yieldOp],
+                                 [&](unsigned idx) { return idx == opIdx; })) {
+          yieldOperandsMap[yieldOp].push_back(opIdx);
+          LLVM_DEBUG({
+            llvm::errs() << "2b. pushing " << operand->getOperandNumber()
+                         << " in yieldOperandMap\n";
+          });
+        }
         opsToRewrite.insert(yieldOp);
       }
     }
@@ -1153,6 +1183,20 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
   slice.set_subtract(valuesWithExistingRemat);
   opsToRewrite = multiRootTopologicalSort(opsToRewrite);
 
+  LLVM_DEBUG({
+    llvm::errs() << "opsToRewrite:\n";
+    for (Operation *op : opsToRewrite) {
+      llvm::errs().indent(2) << "(" << op << "): ";
+      op->dumpPretty();
+    }
+    llvm::errs() << "yieldOperandsMap:\n";
+    for (auto entry : yieldOperandsMap) {
+      llvm::errs() << *entry.first << " -> \n";
+      for (int opx : entry.second)
+        llvm::errs().indent(2) << opx << "\n";
+    }
+  });
+
   // replaceAllUsesWith calls delayed until after initial rewrite.
   // This is required for slice.count(value) to work mid rewrite.
  SmallVector<std::tuple<Value, Value>> replacements;
@@ -1160,6 +1204,13 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
   SmallVector<Operation *> deadOps;
   IRRewriter builder(slice.begin()->getContext());
   for (Operation *op : opsToRewrite) {
+    LLVM_DEBUG({
+      llvm::errs() << "Processing:\n";
+      llvm::errs().indent(2) << "(" << op << "): ";
+      op->dumpPretty();
+      llvm::errs() << "\n";
+    });
+
     if (auto forOp = dyn_cast<scf::ForOp>(op)) {
       SmallVector<Value> newOperands;
       if (enableForLoopSupport) {
@@ -1171,8 +1222,13 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
         std::sort(operandsToRewrite.begin(), operandsToRewrite.end());
         for (int operandIdx : operandsToRewrite) {
           Value yieldOperand = yieldOp.getOperand(operandIdx);
-          if (mapping.contains(yieldOperand))
+          if (mapping.contains(yieldOperand)) {
             newOperands.push_back(mapping.lookup(yieldOperand));
+            LLVM_DEBUG({
+              llvm::errs() << "YieldOperand: " << yieldOperand
+                           << " is mapped, adding new init to for loop\n";
+            });
+          }
         }
       }
 
@@ -1185,18 +1241,47 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
                 forOp.getTiedLoopResult(&initVal).getResultNumber(),
                 forOp.getInitArgs().size() + newOperands.size()));
             newOperands.push_back(mapping.lookup(initVal.get()));
+            LLVM_DEBUG({
+              llvm::errs() << "initVal: " << initVal.get()
+                           << " is mapped, adding new init to for loop\n";
+            });
          }
         }
+
+      if (newOperands.empty())
+        continue;
+
       // Create a new for loop with the new operands.
       scf::ForOp newForOp = replaceForOpWithNewSignature(
           builder, forOp, newOperands, replacements);
 
       if (enableForLoopSupport) {
+        LLVM_DEBUG({
+          llvm::errs() << "at line: " << __LINE__ << "\n";
+          llvm::errs() << "newForOp (" << &newForOp << "): ";
+          newForOp->dumpPretty();
+          llvm::errs() << "\n";
+        });
+
         // Add rematerializations for loop results in the slice.
         unsigned oldIdx = 0;
         unsigned newIdx = forOp.getNumResults();
         for (auto res : forOp.getResults()) {
           if (slice.count(res)) {
+            if (newIdx >= newForOp.getNumResults())
+              break;
+
+            LLVM_DEBUG({
+              llvm::errs() << "oldIdx: " << oldIdx << "\n";
+              llvm::errs() << "newIdx: " << newIdx << "\n";
+            });
+
+            Value oldRes = forOp.getResult(oldIdx);
+            Value newRes = newForOp.getResult(newIdx);
+            LLVM_DEBUG({
+              llvm::errs() << "oldRes: " << oldRes << "\n";
+              llvm::errs() << "newRes: " << newRes << "\n";
+            });
+
             mapping.map(forOp.getResult(oldIdx), newForOp.getResult(newIdx));
             addRematValue(forOp.getResult(oldIdx), layout[res],
                           newForOp.getResult(newIdx));
@@ -1214,9 +1299,13 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
           mapping.map(loopBody.getArgument(m.first + numIndVars),
                       loopBody.getArgument(m.second + numIndVars));
           LLVM_DEBUG({
-            DBGS() << "mapping forOp "
-                   << loopBody.getArgument(m.first + numIndVars) << " to "
-                   << loopBody.getArgument(m.second + numIndVars) << '\n';
+            DBGS() << "mapping forOp ";
+            loopBody.getArgument(m.first + numIndVars)
+                .printAsOperand(llvm::dbgs(), {});
+            llvm::dbgs() << " to ";
+            loopBody.getArgument(m.second + numIndVars)
+                .printAsOperand(llvm::dbgs(), {});
+            llvm::dbgs() << '\n';
           });
           // The result is not in the layout/slice, the argument is.
          Value oldArg = loopBody.getArgument(m.first + numIndVars);
@@ -1274,7 +1363,103 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
       for (int operandIdx : operandsToRewrite) {
         yieldOperands.push_back(mapping.lookup(yieldOp.getOperand(operandIdx)));
       }
-      builder.create<scf::YieldOp>(op->getLoc(), yieldOperands);
+      [[maybe_unused]] auto newYieldOp =
+          builder.create<scf::YieldOp>(op->getLoc(), yieldOperands);
+
+      auto parentOp = newYieldOp->getParentOp();
+      LLVM_DEBUG({
+        unsigned numYieldArgsAdded =
+            newYieldOp.getNumOperands() - yieldOp.getNumOperands();
+        llvm::errs() << "Added " << numYieldArgsAdded
+                     << " operands to the loop yield\n";
+        llvm::errs() << "newYieldOp:" << newYieldOp << "\n";
+        llvm::errs() << "parentOp:";
+        parentOp->dumpPretty();
+        llvm::errs() << "\n";
+      });
+
+#if 1
+      // Fixup the init argument list of the parent loop if necessary.
+      if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
+        unsigned numIterArgs = forOp.getRegionIterArgs().size();
+        unsigned numYieldOperands = newYieldOp->getNumOperands();
+        assert(numIterArgs <= numYieldOperands);
+
+        if (numIterArgs < numYieldOperands) {
+          // We have more yield operands than loop initialization arguments.
+          // Create new "dummy" initialization arguments for the loop.
+          SmallVector<Value> newOperands;
+          for (unsigned idx = numIterArgs; idx < numYieldOperands; ++idx) {
+            Value operand = newYieldOp->getOperand(idx);
+            Type operandTy = operand.getType();
+            auto insertPt = builder.saveInsertionPoint();
+            builder.setInsertionPoint(forOp->getPrevNode());
+            auto constantOp = builder.create<arith::ConstantOp>(
+                builder.getUnknownLoc(), operandTy,
+                builder.getZeroAttr(operandTy));
+            builder.restoreInsertionPoint(insertPt);
+            // llvm::errs() << "constantOp: " << *constantOp << "\n";
+            newOperands.push_back(constantOp);
+            ++idx;
+          }
+
+          if (newOperands.empty())
+            continue;
+
+          scf::ForOp newForOp = replaceForOpWithNewSignature(
+              builder, forOp, newOperands, replacements);
+
+          LLVM_DEBUG({
+            unsigned numArgsAdded =
+                newForOp.getNumResults() - forOp.getNumResults();
+            llvm::errs() << "Added " << numArgsAdded
+                         << " arguments to the loop\n";
+            llvm::errs() << "newForOp (" << &newForOp << "): ";
+            newForOp->dumpPretty();
+            llvm::errs() << "\n";
+          });
+
+          deadOps.push_back(forOp.getOperation());
+
+          // Add rematerializations for loop results in the slice.
+          if (newForOp->getNumResults() > forOp.getNumResults()) {
+            unsigned oldIdx = 0;
+            unsigned newIdx = forOp.getNumResults();
+            for (auto res : forOp.getResults()) {
+              if (newIdx >= newForOp.getNumResults())
+                break;
+
+              LLVM_DEBUG({
+                llvm::errs() << "at line: " << __LINE__ << "\n";
+                llvm::errs() << "res: " << res << "\n";
+              });
+
+              if (slice.count(res)) {
+                LLVM_DEBUG({
+                  llvm::errs() << "oldIdx: " << oldIdx << "\n";
+                  llvm::errs() << "newIdx: " << newIdx << "\n";
+                });
+                Value oldRes = forOp.getResult(oldIdx);
+                Value newRes = newForOp.getResult(newIdx);
+
+                LLVM_DEBUG({
+                  llvm::errs() << "oldRes: " << oldRes << "\n";
+                  llvm::errs() << "newRes: " << newRes << "\n";
+                });
+
+                mapping.map(forOp.getResult(oldIdx),
+                            newForOp.getResult(newIdx));
+                addRematValue(forOp.getResult(oldIdx), layout[res],
+                              newForOp.getResult(newIdx));
+                ++newIdx;
+              }
+              ++oldIdx;
+            }
+          }
+        }
+      }
+#endif
+
       op->erase();
       continue;
     }
@@ -1289,6 +1474,7 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
                     cvt.getResult());
       continue;
     }
+
     Operation *newOp = builder.clone(*op, mapping);
     for (auto [old, newV] : llvm::zip(op->getResults(), newOp->getResults())) {
       auto it = layout.find(old);
@@ -1321,6 +1507,12 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
 
   for (Operation *op : deadOps)
     opToDelete.insert(op);
+
+  LLVM_DEBUG({
+    auto mod = convertOp->getParentOfType<ModuleOp>();
+    llvm::errs() << "rewriteSlice DONE:\n";
+    mod->dump();
+  });
 }
 
 void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
@@ -1595,7 +1787,8 @@ void LayoutRematerialization::backwardRematerialization(
   if (!helper.isAssociative()) {
     // We shouldn't rematerize a no associative reduce op if it has multiple
     // use chain.
-    LDBG("  skipped rematerialization due to non-associative reduce in the "
+    LDBG("  skipped rematerialization due to non-associative reduce in "
+         "the "
         "slice");
     return;
   }
@@ -1859,7 +2052,7 @@ void LayoutRematerialization::hoistConvertIntoConditionals(
     return;
 
   // These are the conditional edges above which conversions should be hoisted.
-  // The value represents the `scf.if` op result and the operand/ represents the
+  // The value represents the `scf.if` op result and the operand represents the
   // edge into one of the branches.
   SmallVector<std::pair<Value, OpOperand *>> hoistAbove;