|
| 1 | +// RUN: iree-opt --pass-pipeline='builtin.module(func.func(iree-amdaie-vectorization))' %s | FileCheck %s |
| 2 | + |
| 3 | +// Make sure it's not falling on the InnerParallel pattern |
| 4 | +// CHECK-LABEL: func.func @multi_reduction_innerparallel |
| 5 | +func.func @multi_reduction_innerparallel(%v : vector<4xf16>, %acc: f16) -> f16 { |
| 6 | + // CHECK-NOT: vector.extract %{{.*}}[0] : f16 from vector<4xf16> |
| 7 | + // CHECK-NOT: arith.addf %{{.*}}, %{{.*}} : f16 |
| 8 | + %0 = vector.multi_reduction <add>, %v, %acc[0] : vector<4xf16> to f16 |
| 9 | + return %0 : f16 |
| 10 | +} |
| 11 | + |
| 12 | +// CHECK-LABEL: func.func @multi_reduction_2d |
| 13 | +func.func @multi_reduction_2d(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> { |
| 14 | + // CHECK: vector.reduction <add>, |
| 15 | + // CHECK: vector.reduction <add>, |
| 16 | + // CHECK: vector.reduction <add>, |
| 17 | + // CHECK: vector.reduction <add>, |
| 18 | + // CHECK-NOT: vector.reduction <add>, |
| 19 | + %0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32> |
| 20 | + return %0 : vector<4xf32> |
| 21 | +} |
| 22 | + |
| 23 | +// CHECK-LABEL: func.func @multi_reduction_1d |
| 24 | +func.func @multi_reduction_1d(%v : vector<4xf32>, %acc: f32) -> f32 { |
| 25 | + // CHECK: vector.reduction <add>, %{{.*}}, %{{.*}} : vector<4xf32> into f32 |
| 26 | + %0 = vector.multi_reduction <add>, %v, %acc[0] : vector<4xf32> to f32 |
| 27 | + return %0 : f32 |
| 28 | +} |
| 29 | + |
| 30 | +// CHECK-LABEL: func.func @multi_reduction_1d_mul |
| 31 | +func.func @multi_reduction_1d_mul(%v : vector<4xf32>, %acc: f32) -> f32 { |
| 32 | + // CHECK: vector.reduction <mul>, %{{.*}}, %{{.*}} : vector<4xf32> into f32 |
| 33 | + %0 = vector.multi_reduction <mul>, %v, %acc[0] : vector<4xf32> to f32 |
| 34 | + return %0 : f32 |
| 35 | +} |
| 36 | + |
| 37 | +// CHECK-LABEL: func.func @multi_reduction_1d_bf16 |
| 38 | +func.func @multi_reduction_1d_bf16(%v : vector<32xbf16>, %acc: bf16) -> bf16 { |
| 39 | + // CHECK: vector.reduction <add>, %{{.*}}, %{{.*}} : vector<32xbf16> into bf16 |
| 40 | + %0 = vector.multi_reduction <add>, %v, %acc[0] : vector<32xbf16> to bf16 |
| 41 | + return %0 : bf16 |
| 42 | +} |
0 commit comments