Skip to content

Commit c4577b1

Browse files
committed
[mlir][bufferization]-Add lit tests for unhandled cases in EmptyTensorElimination
In many cases EmptyTensorElimination cannot transform or eliminate the empty tensor which is being inserted into the `SubsetInsertionOpInterface`. There are two major reasons for that: 1- Failing when trying to find a legal/suitable insertion point for the `subsetExtract` which is about to replace the empty tensor. However, we may try to handle this issue by moving the values which are responsible for building the `subsetExtract` near the empty tensor (which is about to be eliminated), thus increasing the probability of finding a legal insertion point. 2- The EmptyTensorElimination transform replaces all of the tensor.empty's uses at once in one application, rather than replacing only the specific use which was visited in the use-def chain (when traversing from the tensor.insert_slice). This scenario of replacing all the uses of the tensor.empty may lead to additional read effects after bufferization of the specific subset extract/subview, which should not be the case. Both cases may result in many copies in the subsequent bufferization which cannot be canonicalized away. The first case can be noticed when having a `tensor.empty` followed by a `SubsetInsertionOpInterface` (or, in simple words, a `tensor.insert_slice`), which has been lowered from `tensor/tosa.concat`. The second case can be noticed when having a `tensor.empty` with many uses, leading to applying the transformation only once, since all of the uses have been replaced at once. This MR only adds the lit tests for the cases shown above (NFC), to emphasize how the transform works; the coming MRs will upload slight changes to handle these cases.
1 parent 323bedd commit c4577b1

File tree

1 file changed

+98
-0
lines changed

1 file changed

+98
-0
lines changed

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,3 +365,101 @@ func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32
365365
bufferization.materialize_in_destination %selected in restrict writable %m : (tensor<5xf32>, memref<5xf32>) -> ()
366366
return
367367
}
368+
369+
// -----
370+
371+
// `EmptyTensorElimination` fails to find a valid insertion
372+
// point for the new injected `SubsetExtraction`.
// NOTE(review): the destination `tensor.empty` (%cancatenated_empty) is
// defined only *after* the two `linalg.fill` producers, so the replacing
// `tensor.extract_slice` would have no insertion point that dominates the
// fills — presumably why no empty tensor is eliminated here; confirm
// against the transform's insertion-point search.
373+
// All three tensor.empty ops bufferize to allocations, and a copy remains.
// CHECK-LABEL: func.func @fail_to_eliminate_any_empty_tensors
374+
func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
375+
%cst_1 = arith.constant 1.0 : f32
376+
%cst_2 = arith.constant 2.0 : f32
377+
// CHECK: memref.alloc
378+
// CHECK: memref.alloc
379+
// CHECK: memref.alloc
380+
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
381+
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
382+
%empty_2 = tensor.empty() : tensor<5x6x64xf32>
383+
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
384+
// Destination of both insert_slices below; note it is created after the
// fills that produce the inserted values ("cancatenated" is a typo for
// "concatenated", kept since it is only a local SSA name).
%cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
385+
// CHECK: memref.copy
386+
%inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
387+
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
388+
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
389+
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
390+
return %inserted_slice_2 : tensor<5x6x128xf32>
391+
}
392+
393+
// -----
394+
395+
// Same IR as the test above, except the destination `tensor.empty` is
// created *before* the fills. With the destination available early, one of
// the empty tensors can be eliminated: only two allocations remain (vs.
// three above), though one copy still survives.
// CHECK-LABEL: func.func @succeed_to_eliminate_one_empty_tensor
396+
func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
397+
%cst_1 = arith.constant 1.0 : f32
398+
%cst_2 = arith.constant 2.0 : f32
399+
// CHECK: memref.alloc
400+
// CHECK: memref.alloc
401+
// Destination created first, so a `tensor.extract_slice` replacing an empty
// tensor has a dominating insertion point.
%cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
402+
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
403+
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
404+
%empty_2 = tensor.empty() : tensor<5x6x64xf32>
405+
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
406+
// CHECK: memref.copy
407+
%inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
408+
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
409+
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
410+
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
411+
return %inserted_slice_2 : tensor<5x6x128xf32>
412+
}
413+
414+
// -----
415+
416+
// `EmptyTensorElimination` replaces all of the uses of the tensor.empty
417+
// with the new injected `SubsetExtraction`, without considering which
418+
// specific use was tracked through the use-def chain, sometimes creating
419+
// non-existent bufferization conflicts.
420+
// NOTE(review): "mutli" in the function name is a typo for "multi"; it is
// kept because the CHECK/CHECK-ELIM labels below reference it verbatim.
421+
// CHECK-ELIM-LABEL: func.func @mutli_use_of_the_same_tensor_empty
422+
// CHECK-LABEL: func.func @mutli_use_of_the_same_tensor_empty
423+
func.func @mutli_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> {
424+
%cst_1 = arith.constant 1.0 : f32
425+
%cst_2 = arith.constant 2.0 : f32
426+
%cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
427+
// Single empty tensor used as the destination of BOTH fills below; after
// elimination both fills write into the same extracted slice (VAL_3).
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
428+
// CHECK-ELIM: %[[VAL_3:.*]] = tensor.extract_slice
429+
// CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_3]]
430+
// CHECK-ELIM: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]]
431+
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
432+
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
433+
// CHECK: memref.copy
434+
%inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
435+
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
436+
// CHECK: memref.copy
437+
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
438+
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
439+
return %inserted_slice_2 : tensor<5x6x128xf32>
440+
}
441+
442+
// -----
443+
444+
// `%empty_1` has two uses: the destination of the `linalg.fill` (the use
// tracked from the `tensor.insert_slice`) and a *read* operand of the
// `linalg.generic`. Replacing all uses at once makes the injected subset
// extraction appear to be read by the generic op — a read effect that did
// not exist in the original IR — presumably forcing the alloc and copy
// checked below; confirm against the transform's all-uses replacement.
// CHECK-LABEL: func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read
445+
func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: tensor<5x6x128xf32> , %arg2: tensor<5x6x64xf32>)
446+
-> (tensor<5x6x128xf32>, tensor<5x6x64xf32>) {
447+
%cst_1 = arith.constant 1.0 : f32
448+
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
449+
// CHECK: memref.alloc
450+
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
451+
// Second use of %empty_1: read (ins) operand, unrelated to the
// insert_slice chain below.
%res_2 = linalg.generic{
452+
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
453+
iterator_types = ["parallel", "parallel", "parallel"]
454+
}
455+
ins(%empty_1 : tensor<5x6x64xf32>)
456+
outs(%arg2 :tensor<5x6x64xf32>) {
457+
^bb0(%in: f32, %out: f32):
458+
%res = arith.addf %in, %in : f32
459+
linalg.yield %res : f32
460+
} -> tensor<5x6x64xf32>
461+
// CHECK: memref.copy
462+
%inserted_slice_1 = tensor.insert_slice %res_1 into %arg1[0, 0, 0][5, 6, 64][1, 1, 1]
463+
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
464+
return %inserted_slice_1, %res_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32>
465+
}

0 commit comments

Comments
 (0)