From 40176253d93bd46178677cfc690cc99a607b7c3f Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sat, 26 Apr 2025 12:10:22 +0200 Subject: [PATCH] [mlir][sparse_tensor] Fix memory leak in `sparse_index_dense.mlir` --- .../SparseTensor/CPU/sparse_index_dense.mlir | 41 +++++++++---------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index_dense.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index_dense.mlir index 371b3f359f3bf..407c06077da9f 100644 --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index_dense.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_index_dense.mlir @@ -65,8 +65,9 @@ module { // // Kernel that uses index in the index notation (conjunction). // - func.func @sparse_index_1d_conj(%arga: tensor<8xi64, #SparseVector>, - %out: tensor<8xi64>) -> tensor<8xi64> { + func.func @sparse_index_1d_conj(%arga: tensor<8xi64, #SparseVector>) + -> tensor<8xi64> { + %out = tensor.empty() : tensor<8xi64> %r = linalg.generic #trait_1d ins(%arga: tensor<8xi64, #SparseVector>) outs(%out: tensor<8xi64>) { @@ -82,8 +83,9 @@ module { // // Kernel that uses index in the index notation (disjunction). // - func.func @sparse_index_1d_disj(%arga: tensor<8xi64, #SparseVector>, - %out: tensor<8xi64>) -> tensor<8xi64> { + func.func @sparse_index_1d_disj(%arga: tensor<8xi64, #SparseVector>) + -> tensor<8xi64> { + %out = tensor.empty() : tensor<8xi64> %r = linalg.generic #trait_1d ins(%arga: tensor<8xi64, #SparseVector>) outs(%out: tensor<8xi64>) { @@ -99,8 +101,9 @@ module { // // Kernel that uses indices in the index notation (conjunction). // - func.func @sparse_index_2d_conj(%arga: tensor<3x4xi64, #SparseMatrix>, - %out: tensor<3x4xi64>) -> tensor<3x4xi64> { + func.func @sparse_index_2d_conj(%arga: tensor<3x4xi64, #SparseMatrix>) + -> tensor<3x4xi64> { + %out = tensor.empty() : tensor<3x4xi64> %r = linalg.generic #trait_2d ins(%arga: tensor<3x4xi64, #SparseMatrix>) outs(%out: tensor<3x4xi64>) { @@ -119,8 +122,9 @@ module { // // Kernel that uses indices in the index notation (disjunction). // - func.func @sparse_index_2d_disj(%arga: tensor<3x4xi64, #SparseMatrix>, - %out: tensor<3x4xi64>) -> tensor<3x4xi64> { + func.func @sparse_index_2d_disj(%arga: tensor<3x4xi64, #SparseMatrix>) + -> tensor<3x4xi64> { + %out = tensor.empty() : tensor<3x4xi64> %r = linalg.generic #trait_2d ins(%arga: tensor<3x4xi64, #SparseMatrix>) outs(%out: tensor<3x4xi64>) { @@ -161,20 +165,15 @@ module { [ 1, 1, 3, 4 ] ]> : tensor<3x4xi64> %dm = sparse_tensor.convert %m2 : tensor<3x4xi64> to tensor<3x4xi64, #SparseMatrix> - // Setup out tensors. - // Note: Constants bufferize to read-only buffers. - %init_8 = tensor.empty() : tensor<8xi64> - %init_3_4 = tensor.empty() : tensor<3x4xi64> - // Call the kernels. - %0 = call @sparse_index_1d_conj(%sv, %init_8) : (tensor<8xi64, #SparseVector>, tensor<8xi64>) -> tensor<8xi64> - %1 = call @sparse_index_1d_disj(%sv, %init_8) : (tensor<8xi64, #SparseVector>, tensor<8xi64>) -> tensor<8xi64> - %2 = call @sparse_index_1d_conj(%dv, %init_8) : (tensor<8xi64, #SparseVector>, tensor<8xi64>) -> tensor<8xi64> - %3 = call @sparse_index_1d_disj(%dv, %init_8) : (tensor<8xi64, #SparseVector>, tensor<8xi64>) -> tensor<8xi64> - %4 = call @sparse_index_2d_conj(%sm, %init_3_4) : (tensor<3x4xi64, #SparseMatrix>, tensor<3x4xi64>) -> tensor<3x4xi64> - %5 = call @sparse_index_2d_disj(%sm, %init_3_4) : (tensor<3x4xi64, #SparseMatrix>, tensor<3x4xi64>) -> tensor<3x4xi64> - %6 = call @sparse_index_2d_conj(%dm, %init_3_4) : (tensor<3x4xi64, #SparseMatrix>, tensor<3x4xi64>) -> tensor<3x4xi64> - %7 = call @sparse_index_2d_disj(%dm, %init_3_4) : (tensor<3x4xi64, #SparseMatrix>, tensor<3x4xi64>) -> tensor<3x4xi64> + %0 = call @sparse_index_1d_conj(%sv) : (tensor<8xi64, #SparseVector>) -> tensor<8xi64> + %1 = call @sparse_index_1d_disj(%sv) : (tensor<8xi64, #SparseVector>) -> tensor<8xi64> + %2 = call @sparse_index_1d_conj(%dv) : (tensor<8xi64, #SparseVector>) -> tensor<8xi64> + %3 = call @sparse_index_1d_disj(%dv) : (tensor<8xi64, #SparseVector>) -> tensor<8xi64> + %4 = call @sparse_index_2d_conj(%sm) : (tensor<3x4xi64, #SparseMatrix>) -> tensor<3x4xi64> + %5 = call @sparse_index_2d_disj(%sm) : (tensor<3x4xi64, #SparseMatrix>) -> tensor<3x4xi64> + %6 = call @sparse_index_2d_conj(%dm) : (tensor<3x4xi64, #SparseMatrix>) -> tensor<3x4xi64> + %7 = call @sparse_index_2d_disj(%dm) : (tensor<3x4xi64, #SparseMatrix>) -> tensor<3x4xi64> // // Verify result.