Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 65 additions & 13 deletions mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s

// TODO: Replace %arg0 with %vec
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall, do you think it would make sense to refactor all args the same way you did in vector-transfer-flatten.mlir ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huge +1 - I will update the comment to reflect that.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you updated the file AND accepted my suggestion generating repetition 😕

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yikes! Fixed in the latest commit :)


///----------------------------------------------------------------------------------------
/// vector.transfer_write -> vector.transpose + vector.transfer_write
/// [Pattern: TransferWritePermutationLowering]
Expand Down Expand Up @@ -31,6 +33,31 @@ func.func @xfer_write_transposing_permutation_map(
return
}

// Even with out-of-bounds, it is safe to apply this pattern
// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_out_of_bounds
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x8xi16>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: Argument definitions are aligned with function body in the test below. Although xfer_write_transposing_permutation_map also follows this formatting. Both are valid I guess. 🤷

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's prioritise consistency - I will update this before landing. Thanks!

// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x?xi16>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[TR:.*]] = vector.transpose %[[ARG_0]], [1, 0] : vector<4x8xi16> to vector<8x4xi16>
// Expect the in_bounds attribute to be preserved. Since we don't print it when
// all flags are "false", it should not appear in the output.
// CHECK-NOT: in_bounds
// CHECK: vector.transfer_write
// CHECK-NOT: permutation_map
// CHECK-SAME: %[[TR]], %[[MEM]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] : vector<8x4xi16>, memref<2x2x?x?xi16>
func.func @xfer_write_transposing_permutation_map_out_of_bounds(
%arg0: vector<4x8xi16>,
%mem: memref<2x2x?x?xi16>) {

%c0 = arith.constant 0 : index
vector.transfer_write %arg0, %mem[%c0, %c0, %c0, %c0] {
in_bounds = [false, false],
permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
} : vector<4x8xi16>, memref<2x2x?x?xi16>

return
}

// CHECK-LABEL: func.func @xfer_write_transposing_permutation_map_with_mask_scalable
// CHECK-SAME: %[[ARG_0:.*]]: vector<4x[8]xi16>,
// CHECK-SAME: %[[MEM:.*]]: memref<2x2x?x4xi16>,
Expand Down Expand Up @@ -83,19 +110,44 @@ func.func @xfer_write_transposing_permutation_map_masked(
/// * vector.broadcast + vector.transpose + vector.transfer_write with a map
/// which _is_ a permutation of a minor identity

// CHECK-LABEL: func @permutation_with_mask_xfer_write_fixed_width(
// CHECK: %[[vec:.*]] = arith.constant dense<-2.000000e+00> : vector<7x1xf32>
// CHECK: %[[mask:.*]] = arith.constant dense<[true, false, true, false, true, true, true]> : vector<7xi1>
// CHECK: %[[b:.*]] = vector.broadcast %[[mask]] : vector<7xi1> to vector<1x7xi1>
// CHECK: %[[tp:.*]] = vector.transpose %[[b]], [1, 0] : vector<1x7xi1> to vector<7x1xi1>
// CHECK: vector.transfer_write %[[vec]], %{{.*}}[%{{.*}}, %{{.*}}], %[[tp]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
func.func @permutation_with_mask_xfer_write_fixed_width(%mem : memref<?x?xf32>, %base1 : index,
%base2 : index) {

%fn1 = arith.constant -2.0 : f32
%vf0 = vector.splat %fn1 : vector<7xf32>
%mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1]> : vector<7xi1>
vector.transfer_write %vf0, %mem[%base1, %base2], %mask
// CHECK-LABEL: func.func @xfer_write_non_transposing_permutation_map(
// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
// CHECK-SAME: %[[VEC:.*]]: vector<7xf32>,
// CHECK-SAME: %[[BASE_1:.*]]: index, %[[BASE_2:.*]]: index) {
// CHECK: %[[BC:.*]] = vector.broadcast %[[VEC]] : vector<7xf32> to vector<1x7xf32>
// CHECK: %[[TR:.*]] = vector.transpose %[[BC]], [1, 0] : vector<1x7xf32> to vector<7x1xf32>
// CHECK: vector.transfer_write %[[TR]], %[[MEM]]{{\[}}%[[BASE_1]], %[[BASE_2]]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
func.func @xfer_write_non_transposing_permutation_map(
%mem : memref<?x?xf32>,
%arg0 : vector<7xf32>,
%base1 : index,
%base2 : index) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: I'd suggest to use %idx1 to match with other tests such as masked_permutation_xfer_write_fixed_width and your refacto of transfer-flatten.
I am not unhappy if you make it as part of a separate commit to tackle the above TODO. Also getting rid of %dim args and use %idx.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you don't mind me expanding this PR, I'll do it here. Ultimately, this needs to happen and it's all about making the review process as smooth/easy as possible. If you are happy then I'm also happy :)

Also getting rid of %dim args and use %idx.

This one I'm a bit unsure of - %dim is used for mask dimension rather than xfer_{read|write} index 🤔


vector.transfer_write %arg0, %mem[%base1, %base2]
{permutation_map = affine_map<(d0, d1) -> (d0)>}
: vector<7xf32>, memref<?x?xf32>
return
}

// Even with out-of-bounds, it is safe to apply this pattern
// CHECK-LABEL: func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
// CHECK-SAME: %[[VEC:.*]]: vector<7xf32>,
// CHECK-SAME: %[[BASE_1:.*]]: index, %[[BASE_2:.*]]: index,
// CHECK-SAME: %[[MASK:.*]]: vector<7xi1>) {
// CHECK: %[[BC_VEC:.*]] = vector.broadcast %[[VEC]] : vector<7xf32> to vector<1x7xf32>
// CHECK: %[[BC_MASK:.*]] = vector.broadcast %[[MASK]] : vector<7xi1> to vector<1x7xi1>
// CHECK: %[[TR_MASK:.*]] = vector.transpose %[[BC_MASK]], [1, 0] : vector<1x7xi1> to vector<7x1xi1>
// CHECK: %[[TR_VEC:.*]] = vector.transpose %[[BC_VEC]], [1, 0] : vector<1x7xf32> to vector<7x1xf32>
// CHECK: vector.transfer_write %[[TR_VEC]], %[[MEM]]{{\[}}%[[BASE_1]], %[[BASE_2]]], %[[TR_MASK]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
func.func @xfer_write_non_transposing_permutation_map_with_mask_out_of_bounds(
%mem : memref<?x?xf32>,
%arg0 : vector<7xf32>,
%base1 : index,
%base2 : index,
%mask : vector<7xi1>) {

vector.transfer_write %arg0, %mem[%base1, %base2], %mask
{permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [false]}
: vector<7xf32>, memref<?x?xf32>
return
Expand Down