|
1 | 1 | // RUN: mlir-opt %s -test-vector-sink-patterns -split-input-file | FileCheck %s |
| 2 | +// RUN: mlir-opt -transform-preload-library='transform-library-paths=%p/vector-sink-transform.mlir' -transform-interpreter -split-input-file %s | FileCheck %s |
2 | 3 |
|
3 | 4 | //----------------------------------------------------------------------------- |
4 | 5 | // [Pattern: ReorderElementwiseOpsOnBroadcast] |
@@ -423,3 +424,92 @@ func.func @transpose_elementwise_diff_map_scalable(%a : vector<[4]x6x3x2xf32>, % |
423 | 424 | %r = arith.addf %at, %bt : vector<6x[4]x2x3xf32> |
424 | 425 | return %r : vector<6x[4]x2x3xf32> |
425 | 426 | } |
| 427 | + |
| 428 | +// ----- |
| 429 | + |
| 430 | +//----------------------------------------------------------------------------- |
| 431 | +// [Pattern: ExtractOpFromElementwise] |
| 432 | +//----------------------------------------------------------------------------- |
| 433 | + |
| 434 | +// CHECK-LABEL: @extract_elementwise_scalar |
| 435 | +// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>) |
| 436 | +func.func @extract_elementwise_scalar(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 { |
| 437 | +// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32> |
| 438 | +// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32> |
| 439 | +// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32 |
| 440 | +// CHECK: return %[[RES]] : f32 |
| 441 | + %0 = arith.addf %arg0, %arg1 : vector<4xf32> |
| 442 | + %1 = vector.extract %0[1] : f32 from vector<4xf32> |
| 443 | + return %1 : f32 |
| 444 | +} |
| 445 | + |
| 446 | +// CHECK-LABEL: @extract_elementwise_arg_res_different_types |
| 447 | +// CHECK-SAME: (%[[ARG0:.*]]: vector<4xindex>) |
| 448 | +func.func @extract_elementwise_arg_res_different_types(%arg0: vector<4xindex>) -> i64 { |
| 449 | +// CHECK: %[[EXT:.*]] = vector.extract %[[ARG0]][1] : index from vector<4xindex> |
| 450 | +// CHECK: %[[RES:.*]] = arith.index_cast %[[EXT]] : index to i64 |
| 451 | +// CHECK: return %[[RES]] : i64 |
| 452 | + %0 = arith.index_cast %arg0: vector<4xindex> to vector<4xi64> |
| 453 | + %1 = vector.extract %0[1] : i64 from vector<4xi64> |
| 454 | + return %1 : i64 |
| 455 | +} |
| 456 | + |
| 457 | +// CHECK-LABEL: @extract_elementwise_vec |
| 458 | +// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>) |
| 459 | +func.func @extract_elementwise_vec(%arg0: vector<2x4xf32>, %arg1: vector<2x4xf32>) -> vector<4xf32> { |
| 460 | +// CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32> |
| 461 | +// CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32> |
| 462 | +// CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32> |
| 463 | +// CHECK: return %[[RES]] : vector<4xf32> |
| 464 | + %0 = arith.addf %arg0, %arg1 : vector<2x4xf32> |
| 465 | + %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32> |
| 466 | + return %1 : vector<4xf32> |
| 467 | +} |
| 468 | + |
| 469 | +// CHECK-LABEL: @negative_extract_elementwise_no_single_use |
| 470 | +// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>) |
| 471 | +func.func @negative_extract_elementwise_no_single_use(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> (f32, vector<4xf32>) { |
| 472 | +// Do not propagate extract, as elementwise has other uses. |
| 473 | +// CHECK: %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32> |
| 474 | +// CHECK: %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32> |
| 475 | +// CHECK: return %[[EXT]], %[[ELT]] : f32, vector<4xf32> |
| 476 | + %0 = arith.addf %arg0, %arg1 : vector<4xf32> |
| 477 | + %1 = vector.extract %0[1] : f32 from vector<4xf32> |
| 478 | + return %1, %0 : f32, vector<4xf32> |
| 479 | +} |
| 480 | + |
| 481 | +// CHECK-LABEL: @negative_extract_elementwise_not_one_res |
| 482 | +// CHECK-SAME: (%[[ARG0:.*]]: vector<4xi32>, %[[ARG1:.*]]: vector<4xi32>) |
| 483 | +func.func @negative_extract_elementwise_not_one_res(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 { |
| 484 | +// Do not propagate extract, as elementwise has more than 1 result. |
| 485 | +// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = arith.mulsi_extended %[[ARG0]], %[[ARG1]] : vector<4xi32> |
| 486 | +// CHECK: %[[EXT:.*]] = vector.extract %[[LOW]][1] : i32 from vector<4xi32> |
| 487 | +// CHECK: return %[[EXT]] : i32 |
| 488 | + %low, %hi = arith.mulsi_extended %arg0, %arg1 : vector<4xi32> |
| 489 | + %1 = vector.extract %low[1] : i32 from vector<4xi32> |
| 490 | + return %1 : i32 |
| 491 | +} |
| 492 | + |
| 493 | +// CHECK-LABEL: @negative_extract_not_elementwise |
| 494 | +// CHECK-SAME: (%[[ARG0:.*]]: vector<4xi64>) |
| 495 | +func.func @negative_extract_not_elementwise(%arg0: vector<4xi64>) -> i64 { |
| 496 | +// `test.increment` is not an elemewise op. |
| 497 | +// CHECK: %[[INC:.*]] = test.increment %[[ARG0]] : vector<4xi64> |
| 498 | +// CHECK: %[[RES:.*]] = vector.extract %[[INC]][1] : i64 from vector<4xi64> |
| 499 | +// CHECK: return %[[RES]] : i64 |
| 500 | + %0 = test.increment %arg0: vector<4xi64> |
| 501 | + %1 = vector.extract %0[1] : i64 from vector<4xi64> |
| 502 | + return %1 : i64 |
| 503 | +} |
| 504 | + |
| 505 | +// CHECK-LABEL: @negative_extract_vec_fma |
| 506 | +// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<4xf32>) |
| 507 | +func.func @negative_extract_vec_fma(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> f32 { |
| 508 | +// `vector.fma` doesn't suppport scalars. |
| 509 | +// CHECK: %[[FMA:.*]] = vector.fma %[[ARG0]], %[[ARG1]], %[[ARG2]] : vector<4xf32> |
| 510 | +// CHECK: %[[RES:.*]] = vector.extract %[[FMA]][1] : f32 from vector<4xf32> |
| 511 | +// CHECK: return %[[RES]] : f32 |
| 512 | + %0 = vector.fma %arg0, %arg1, %arg2: vector<4xf32> |
| 513 | + %1 = vector.extract %0[1] : f32 from vector<4xf32> |
| 514 | + return %1 : f32 |
| 515 | +} |
0 commit comments