diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir index 318edca73cce1..8be4e1b79c52c 100644 --- a/mlir/test/Dialect/Linalg/hoisting.mlir +++ b/mlir/test/Dialect/Linalg/hoisting.mlir @@ -1,76 +1,210 @@ // RUN: mlir-opt -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s -// CHECK-LABEL: func @hoist_vector_transfer_pairs( -// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref, -// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref, -// CHECK-SAME: %[[MEMREF2:[a-zA-Z0-9]*]]: memref, -// CHECK-SAME: %[[MEMREF3:[a-zA-Z0-9]*]]: memref, -// CHECK-SAME: %[[MEMREF4:[a-zA-Z0-9]*]]: memref, -// CHECK-SAME: %[[MEMREF5:[a-zA-Z0-9]*]]: memref, -// CHECK-SAME: %[[VAL:[a-zA-Z0-9]*]]: index, -// CHECK-SAME: %[[LB:[a-zA-Z0-9]*]]: index, -// CHECK-SAME: %[[UB:[a-zA-Z0-9]*]]: index, -// CHECK-SAME: %[[STEP:[a-zA-Z0-9]*]]: index, -// CHECK-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1 -func.func @hoist_vector_transfer_pairs( - %memref0: memref, %memref1: memref, %memref2: memref, - %memref3: memref, %memref4: memref, %memref5: memref, - %val: index, %lb : index, %ub : index, %step: index, %cmp: i1) { +///---------------------------------------------------------------------------------------- +/// Tests for vector.transfer_read + vector.transfer_write pairs +/// +/// * Nested in double loops +/// * Indices depend on induction variables +///---------------------------------------------------------------------------------------- + +// CHECK-LABEL: func @mem_use_outside +// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref, +// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index) +func.func @mem_use_outside(%mem: memref, %lb : index, %ub : index, %step: index) { + %pad = arith.constant 0.0 : f32 + +// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { +// CHECK: %[[READ:.*]] = 
vector.transfer_read %[[MEM]][%[[I]], %[[I]]], %[[PAD]] : memref, vector<1xf32> +// CHECK: %[[SCF:.*]] = scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[VAL_5:.*]] = %[[READ]]) -> (vector<1xf32>) { +// CHECK: %[[USE:.*]] = "val_use"(%[[VAL_5]]) : (vector<1xf32>) -> vector<1xf32> +// CHECK: scf.yield %[[USE]] : vector<1xf32> +// CHECK: } +// CHECK: vector.transfer_write %[[SCF]], %[[MEM]][%[[I]], %[[I]]] : vector<1xf32>, memref +// CHECK: "mem_use"(%[[MEM]]) : (memref) -> () +// CHECK: } + scf.for %i = %lb to %ub step %step { + scf.for %j = %lb to %ub step %step { + %read = vector.transfer_read %mem[%i, %i], %pad: memref, vector<1xf32> + %use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32> + vector.transfer_write %use, %mem[%i, %i] : vector<1xf32>, memref + } + } + "mem_use"(%mem) : (memref) -> () + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.hoist_redundant_vector_transfers %0 + : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @mem_use_inside_outer_loop +// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref, +// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index) +func.func @mem_use_inside_outer_loop(%mem: memref, %lb : index, %ub : index, %step: index) { + %pad = arith.constant 0.0 : f32 + +// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { +// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[I]], %[[I]]], %[[PAD]] : memref, vector<1xf32> +// CHECK: %[[SCF:.*]] = scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[VAL_5:.*]] = %[[READ]]) -> (vector<1xf32>) { +// CHECK: 
%[[USE:.*]] = "val_use"(%[[VAL_5]]) : (vector<1xf32>) -> vector<1xf32> +// CHECK: scf.yield %[[USE]] : vector<1xf32> +// CHECK: } +// CHECK: vector.transfer_write %[[SCF]], %[[MEM]]{{\[}}%[[I]], %[[I]]] : vector<1xf32>, memref +// CHECK: "mem_use"(%[[MEM]]) : (memref) -> () +// CHECK: } + scf.for %i = %lb to %ub step %step { + scf.for %j = %lb to %ub step %step { + %read = vector.transfer_read %mem[%i, %i], %pad: memref, vector<1xf32> + %use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32> + vector.transfer_write %use, %mem[%i, %i] : vector<1xf32>, memref + } + "mem_use"(%mem) : (memref) -> () + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.hoist_redundant_vector_transfers %0 + : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +///---------------------------------------------------------------------------------------- +/// Tests for vector.transfer_read + vector.transfer_write pairs +/// +/// * Nested in double loops +/// * Indices are constant +///---------------------------------------------------------------------------------------- + +// CHECK-LABEL: func @negative_mem_use_inside_inner_loop_before_write +// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref, +// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index) +func.func @negative_mem_use_inside_inner_loop_before_write(%mem: memref, %lb : index, %ub : index, %step: index) { %c0 = arith.constant 0 : index - %cst = arith.constant 0.0 : f32 + %pad = arith.constant 0.0 : f32 + +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { +// CHECK: 
scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { +// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref, vector<1xf32> +// CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32> +// CHECK: "mem_use"(%[[MEM]]) : (memref) -> () +// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref +// CHECK: } +// CHECK: } + scf.for %i = %lb to %ub step %step { + scf.for %j = %lb to %ub step %step { + %read = vector.transfer_read %mem[%c0, %c0], %pad: memref, vector<1xf32> + %use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32> + "mem_use"(%mem) : (memref) -> () + vector.transfer_write %use, %mem[%c0, %c0] : vector<1xf32>, memref + } + } + return +} -// CHECK: vector.transfer_read %{{.*}} : memref, vector<1xf32> -// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) { -// CHECK: vector.transfer_read %{{.*}} : memref, vector<2xf32> -// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) { -// CHECK: vector.transfer_read %{{.*}} : memref, vector<3xf32> -// CHECK: vector.transfer_read %{{.*}} : memref, vector<4xf32> -// CHECK: "some_crippling_use"(%[[MEMREF4]]) : (memref) -> () -// CHECK: vector.transfer_read %{{.*}} : memref, vector<5xf32> -// CHECK: "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32> -// CHECK: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32> -// CHECK: "some_use"(%[[MEMREF2]], %{{.*}}) : (memref, vector<3xf32>) -> vector<3xf32> -// CHECK: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32> -// CHECK: "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32> -// CHECK: vector.transfer_write %{{.*}} : vector<3xf32>, memref -// CHECK: vector.transfer_write %{{.*}} : vector<4xf32>, memref -// CHECK: vector.transfer_write %{{.*}} : vector<5xf32>, memref -// CHECK: "some_crippling_use"(%[[MEMREF3]]) : (memref) -> () -// CHECK: scf.yield {{.*}} : 
vector<1xf32>, vector<2xf32> +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.hoist_redundant_vector_transfers %0 + : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @negative_mem_use_inside_inner_loop_after_write +// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref, +// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index) +func.func @negative_mem_use_inside_inner_loop_after_write(%mem: memref, %lb : index, %ub : index, %step: index) { + %c0 = arith.constant 0 : index + %pad = arith.constant 0.0 : f32 + +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { +// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { +// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref, vector<1xf32> +// CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32> +// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref +// CHECK: "mem_use"(%[[MEM]]) : (memref) -> () +// CHECK: } +// CHECK: } + scf.for %i = %lb to %ub step %step { + scf.for %j = %lb to %ub step %step { + %r3 = vector.transfer_read %mem[%c0, %c0], %pad: memref, vector<1xf32> + %u3 = "val_use"(%r3) : (vector<1xf32>) -> vector<1xf32> + vector.transfer_write %u3, %mem[%c0, %c0] : vector<1xf32>, memref + "mem_use"(%mem) : (memref) -> () + } + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["func.func"]} in 
%arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.hoist_redundant_vector_transfers %0 + : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @negative_mem_use_inside_inner_loop_before_read +// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref, +// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index) +func.func @negative_mem_use_inside_inner_loop_before_read(%mem: memref, %lb : index, %ub : index, %step: index) { + %c0 = arith.constant 0 : index + %pad = arith.constant 0.0 : f32 + +// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { +// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { +// CHECK: "mem_use"(%[[MEM]]) : (memref) -> () +// CHECK: vector.transfer_read %{{.*}} : memref, vector<1xf32> +// CHECK: "val_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32> +// CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref // CHECK: } -// CHECK: vector.transfer_write %{{.*}} : vector<2xf32>, memref -// CHECK: "unrelated_use"(%[[MEMREF0]]) : (memref) -> () -// CHECK: scf.yield {{.*}} : vector<1xf32> // CHECK: } -// CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref -// CHECK: "unrelated_use"(%[[MEMREF1]]) : (memref) -> () scf.for %i = %lb to %ub step %step { scf.for %j = %lb to %ub step %step { - %r0 = vector.transfer_read %memref1[%c0, %c0], %cst: memref, vector<1xf32> - %r1 = vector.transfer_read %memref0[%i, %i], %cst: memref, vector<2xf32> - %r2 = vector.transfer_read %memref2[%c0, %c0], %cst: memref, vector<3xf32> - %r3 = vector.transfer_read %memref3[%c0, %c0], %cst: memref, vector<4xf32> - "some_crippling_use"(%memref4) : (memref) -> () - %r4 = vector.transfer_read %memref4[%c0, %c0], %cst: memref, vector<5xf32> - %r5 = vector.transfer_read %memref5[%c0, %c0], %cst: memref, vector<6xf32> - "some_crippling_use"(%memref5) : (memref) -> () - %u0 = "some_use"(%r0) : (vector<1xf32>) 
-> vector<1xf32> - %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32> - %u2 = "some_use"(%memref2, %r2) : (memref, vector<3xf32>) -> vector<3xf32> - %u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32> - %u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32> - %u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32> - vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref - vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref - vector.transfer_write %u2, %memref2[%c0, %c0] : vector<3xf32>, memref - vector.transfer_write %u3, %memref3[%c0, %c0] : vector<4xf32>, memref - vector.transfer_write %u4, %memref4[%c0, %c0] : vector<5xf32>, memref - vector.transfer_write %u5, %memref5[%c0, %c0] : vector<6xf32>, memref - "some_crippling_use"(%memref3) : (memref) -> () + "mem_use"(%mem) : (memref) -> () + %read = vector.transfer_read %mem[%c0, %c0], %pad: memref, vector<1xf32> + %use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32> + vector.transfer_write %use, %mem[%c0, %c0] : vector<1xf32>, memref } - "unrelated_use"(%memref0) : (memref) -> () } - "unrelated_use"(%memref1) : (memref) -> () return } @@ -86,6 +220,12 @@ module attributes {transform.with_named_sequence} { // ----- +///---------------------------------------------------------------------------------------- +/// Other tests +/// +/// TODO: Document +///---------------------------------------------------------------------------------------- + // CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint( // CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref, // CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref,