Commit ccd1a5b

[mlir][sparse] Fix a bug in rewriting for the convert op.

The code to retrieve the number of entries isn't correct.

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D137795

1 parent 6a8d894
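For context (not part of the commit message): in the sparse_tensor dialect's storage scheme, the entry count of a compressed level lives in the pointers buffer, while the indices buffer holds coordinates. The code removed below loaded position 1 of an indices buffer, so it read a coordinate rather than the count. A minimal sketch of the op the fix switches to, assuming a hypothetical 1-D compressed encoding and illustrative names:

// Sketch, not part of the commit: a hypothetical 1-D compressed tensor
// [0, 5.0, 0, 7.0] is stored as
//   pointers = [0, 2]   (pointers[1] holds the number of stored entries, 2)
//   indices  = [1, 3]   (indices[1] holds the coordinate of the 2nd entry, 3)
// so loading indices[1] yields a coordinate, not the entry count.
#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

func.func @nnz(%arg0: tensor<4xf64, #SparseVector>) -> index {
  // Returns the entry count directly, as an index value.
  %nnz = sparse_tensor.number_of_entries %arg0 : tensor<4xf64, #SparseVector>
  return %nnz : index
}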

3 files changed, +36 -45 lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 1 addition & 6 deletions
@@ -666,12 +666,7 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
     }
 
     // Retrieve NNZ.
-    auto ptrTp =
-        MemRefType::get(dynShape, getPointerOverheadType(rewriter, encSrc));
-    Value p0 =
-        rewriter.create<ToIndicesOp>(loc, ptrTp, src, rewriter.getIndexAttr(0));
-    Value c1 = constantIndex(rewriter, loc, 1);
-    Value nnz = rewriter.create<memref::LoadOp>(loc, p0, c1);
+    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
     nnz =
         rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), nnz);
 
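The effect on the generated IR is easiest to see in the updated tests below. Roughly, for the 2x4 dense-to-CSR case, the sort of the COO intermediate now takes its count from the new query op. A sketch under an assumed two-level encoding standing in for the COO intermediate (names and the #COO attribute are illustrative, not taken from the commit):

#COO = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>

func.func @sort_coo(%coo: tensor<2x4xf64, #COO>) {
  %i0 = sparse_tensor.indices %coo {dimension = 0 : index} : tensor<2x4xf64, #COO> to memref<?xindex>
  %i1 = sparse_tensor.indices %coo {dimension = 1 : index} : tensor<2x4xf64, #COO> to memref<?xindex>
  // New: the entry count comes from a dedicated query op ...
  %nnz = sparse_tensor.number_of_entries %coo : tensor<2x4xf64, #COO>
  // ... instead of the old, incorrect  memref.load %i0[%c1]  of a coordinate.
  %v = sparse_tensor.values %coo : tensor<2x4xf64, #COO> to memref<?xf64>
  sparse_tensor.sort %nnz, %i0, %i1 jointly %v : memref<?xindex>, memref<?xindex> jointly memref<?xf64>
  return
}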

mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir

Lines changed: 34 additions & 36 deletions
@@ -105,35 +105,34 @@ func.func @sparse_convert_complex(%arg0: tensor<100xcomplex<f64>>) -> tensor<100
 
 // CHECK-RWT-LABEL: func.func @sparse_convert_2d(
 // CHECK-RWT-SAME: %[[T0:.*]]: tensor<2x4xf64>) -> tensor<2x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> {
-// CHECK-RWT: %[[VAL_1:.*]] = arith.constant 1 : index
-// CHECK-RWT: %[[VAL_2:.*]] = bufferization.alloc_tensor()
-// CHECK-RWT: %[[VAL_3:.*]] = sparse_tensor.foreach in %[[T0]] init(%[[VAL_2]])
-// CHECK-RWT: ^bb0(%[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index, %[[VAL_6:.*]]: f64, %[[VAL_7:.*]]: tensor
-// CHECK-RWT: %[[CMP:.*]] = arith.cmpf une, %[[VAL_6]]
+// CHECK-RWT: %[[T1:.*]] = bufferization.alloc_tensor()
+// CHECK-RWT: %[[T2:.*]] = sparse_tensor.foreach in %[[T0]] init(%[[T1]])
+// CHECK-RWT: ^bb0(%[[L0I0:.*]]: index, %[[L0I1:.*]]: index, %[[L0V:.*]]: f64, %[[L0T:.*]]: tensor
+// CHECK-RWT: %[[CMP:.*]] = arith.cmpf une, %[[L0V]]
 // CHECK-RWT: %[[IFR:.*]] = scf.if %[[CMP]]
-// CHECK-RWT: %[[Y1:.*]] = sparse_tensor.insert %[[VAL_6]] into %[[VAL_7]]
-// CHECK-RWT: scf.yield %[[Y1]]
+// CHECK-RWT: %[[L0T2:.*]] = sparse_tensor.insert %[[L0V]] into %[[L0T]]{{\[}}%[[L0I0]], %[[L0I1]]]
+// CHECK-RWT: scf.yield %[[L0T2]]
 // CHECK-RWT: } else {
-// CHECK-RWT: scf.yield %[[VAL_7]]
+// CHECK-RWT: scf.yield %[[L0T]]
 // CHECK-RWT: }
 // CHECK-RWT: sparse_tensor.yield %[[IFR]]
 // CHECK-RWT: }
-// CHECK-RWT: %[[COO:.*]] = sparse_tensor.load %[[VAL_3]] hasInserts
+// CHECK-RWT: %[[COO:.*]] = sparse_tensor.load %[[T2]] hasInserts
 // CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index}
 // CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index}
-// CHECK-RWT: %[[VAL_13:.*]] = memref.load %[[I0]]{{\[}}%[[VAL_1]]] : memref<?xindex>
-// CHECK-RWT: %[[VAL_14:.*]] = sparse_tensor.values %[[COO]]
-// CHECK-RWT: sparse_tensor.sort %[[VAL_13]], %[[I0]], %[[I1]] jointly %[[VAL_14]] : memref<?xindex>, memref<?xindex> jointly memref<?xf64>
-// CHECK-RWT: %[[VAL_15:.*]] = bufferization.alloc_tensor()
-// CHECK-RWT: %[[VAL_16:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[VAL_15]])
-// CHECK-RWT: ^bb0(%[[VAL_17:.*]]: index, %[[VAL_18:.*]]: index, %[[VAL_19:.*]]: f64, %[[VAL_20:.*]]: tensor
-// CHECK-RWT: %[[VAL_21:.*]] = sparse_tensor.insert %[[VAL_19]] into %[[VAL_20]]{{\[}}%[[VAL_17]], %[[VAL_18]]]
-// CHECK-RWT: sparse_tensor.yield %[[VAL_21]]
+// CHECK-RWT: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[COO]]
+// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[COO]]
+// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V]] : memref<?xindex>, memref<?xindex> jointly memref<?xf64>
+// CHECK-RWT: %[[T3:.*]] = bufferization.alloc_tensor()
+// CHECK-RWT: %[[T4:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T3]])
+// CHECK-RWT: ^bb0(%[[L1I0:.*]]: index, %[[L1I1:.*]]: index, %[[L1V:.*]]: f64, %[[L1T:.*]]: tensor
+// CHECK-RWT: %[[L1T2:.*]] = sparse_tensor.insert %[[L1V]] into %[[L1T]]{{\[}}%[[L1I0]], %[[L1I1]]]
+// CHECK-RWT: sparse_tensor.yield %[[L1T2]]
 // CHECK-RWT: }
-// CHECK-RWT: %[[VAL_22:.*]] = sparse_tensor.load %[[VAL_16]] hasInserts
-// CHECK-RWT: %[[VAL_24:.*]] = sparse_tensor.convert %[[VAL_22]]
+// CHECK-RWT: %[[T5:.*]] = sparse_tensor.load %[[T4]] hasInserts
+// CHECK-RWT: %[[T6:.*]] = sparse_tensor.convert %[[T5]]
 // CHECK-RWT: bufferization.dealloc_tensor %[[COO]]
-// CHECK-RWT: return %[[VAL_24]]
+// CHECK-RWT: return %[[T6]]
 // CHECK-RWT: }
 func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64> to tensor<2x4xf64, #CSR>
@@ -169,30 +168,29 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
 // CHECK: return %[[T]] : !llvm.ptr<i8>
 
 // CHECK-RWT-LABEL: func.func @sparse_constant() -> tensor<8x7xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> {
-// CHECK-RWT: %[[VAL_0:.*]] = arith.constant 1 : index
-// CHECK-RWT: %[[VAL_1:.*]] = arith.constant sparse<{{\[\[}}0, 0], [1, 6]], [1.000000e+00, 5.000000e+00]> : tensor<8x7xf32>
+// CHECK-RWT: %[[F0:.*]] = arith.constant sparse<{{\[\[}}0, 0], [1, 6]], [1.000000e+00, 5.000000e+00]> : tensor<8x7xf32>
 // CHECK-RWT: %[[T0:.*]] = bufferization.alloc_tensor()
-// CHECK-RWT: %[[T1:.*]] = sparse_tensor.foreach in %[[VAL_1]] init(%[[T0]])
-// CHECK-RWT: ^bb0(%[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index, %[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: tensor
-// CHECK-RWT: %[[T2:.*]] = sparse_tensor.insert %[[VAL_6]] into %[[VAL_7]]{{\[}}%[[VAL_4]], %[[VAL_5]]]
-// CHECK-RWT: sparse_tensor.yield %[[T2]]
+// CHECK-RWT: %[[T1:.*]] = sparse_tensor.foreach in %[[F0]] init(%[[T0]])
+// CHECK-RWT: ^bb0(%[[L0I0:.*]]: index, %[[L0I1:.*]]: index, %[[L0V:.*]]: f32, %[[L0T:.*]]: tensor
+// CHECK-RWT: %[[L0T2:.*]] = sparse_tensor.insert %[[L0V]] into %[[L0T]]{{\[}}%[[L0I0]], %[[L0I1]]]
+// CHECK-RWT: sparse_tensor.yield %[[L0T2]]
 // CHECK-RWT: }
 // CHECK-RWT: %[[COO:.*]] = sparse_tensor.load %[[T1]] hasInserts
 // CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index}
 // CHECK-RWT: %[[I1:.*]] = sparse_tensor.indices %[[COO]] {dimension = 1 : index}
-// CHECK-RWT: %[[VAL_13:.*]] = memref.load %[[I0]]{{\[}}%[[VAL_0]]] : memref<?xindex>
+// CHECK-RWT: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[COO]]
 // CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[COO]]
-// CHECK-RWT: sparse_tensor.sort %[[VAL_13]], %[[I0]], %[[I1]] jointly %[[V]] : memref<?xindex>, memref<?xindex> jointly memref<?xf32>
-// CHECK-RWT: %[[VAL_15:.*]] = bufferization.alloc_tensor()
-// CHECK-RWT: %[[VAL_16:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[VAL_15]])
-// CHECK-RWT: ^bb0(%[[VAL_17:.*]]: index, %[[VAL_18:.*]]: index, %[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: tensor
-// CHECK-RWT: %[[VAL_21:.*]] = sparse_tensor.insert %[[VAL_19]] into %[[VAL_20]]{{\[}}%[[VAL_17]], %[[VAL_18]]]
-// CHECK-RWT: sparse_tensor.yield %[[VAL_21]]
+// CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V]] : memref<?xindex>, memref<?xindex> jointly memref<?xf32>
+// CHECK-RWT: %[[T3:.*]] = bufferization.alloc_tensor()
+// CHECK-RWT: %[[T4:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T3]])
+// CHECK-RWT: ^bb0(%[[L1I0:.*]]: index, %[[L1I1:.*]]: index, %[[L1V:.*]]: f32, %[[L1T:.*]]: tensor
+// CHECK-RWT: %[[L1T2:.*]] = sparse_tensor.insert %[[L1V]] into %[[L1T]]{{\[}}%[[L1I0]], %[[L1I1]]]
+// CHECK-RWT: sparse_tensor.yield %[[L1T2]]
 // CHECK-RWT: }
-// CHECK-RWT: %[[VAL_22:.*]] = sparse_tensor.load %[[VAL_16]] hasInserts
-// CHECK-RWT: %[[VAL_24:.*]] = sparse_tensor.convert %[[VAL_22]]
+// CHECK-RWT: %[[T5:.*]] = sparse_tensor.load %[[T4]] hasInserts
+// CHECK-RWT: %[[T6:.*]] = sparse_tensor.convert %[[T5]]
 // CHECK-RWT: bufferization.dealloc_tensor %[[COO]]
-// CHECK-RWT: return %[[VAL_24]]
+// CHECK-RWT: return %[[T6]]
 // CHECK-RWT: }
 func.func @sparse_constant() -> tensor<8x7xf32, #CSR>{
   // Initialize a tensor.
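The @sparse_constant function under test is truncated above. A sketch of what it plausibly contains, reconstructed from the constant the CHECK lines capture as %F0 (illustrative, not copied from the test file):

#CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>

func.func @sparse_constant() -> tensor<8x7xf32, #CSR> {
  // Initialize a tensor with two stored entries, matching %F0 above.
  %t = arith.constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]> : tensor<8x7xf32>
  // Convert to CSR; the rewriting builds a sorted COO intermediate first.
  %0 = sparse_tensor.convert %t : tensor<8x7xf32> to tensor<8x7xf32, #CSR>
  return %0 : tensor<8x7xf32, #CSR>
}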

mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir

Lines changed: 1 addition & 3 deletions
@@ -86,11 +86,9 @@ func.func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor
 // CHECK-RWT-LABEL: func.func @sparse_convert(
 // CHECK-RWT-SAME: %[[A:.*]]: tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 64, indexBitWidth = 64 }>>)
 // CHECK-RWT-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-RWT-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK-RWT: %[[D:.*]] = tensor.dim %[[A]], %[[C0]]
 // CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[A]] {dimension = 0 : index}
-// CHECK-RWT: %[[NNZr:.*]] = memref.load %[[I0]]{{\[}}%[[C1]]] : memref<?xi64>
-// CHECK-RWT: %[[NNZ:.*]] = arith.index_cast %[[NNZr]] : i64 to index
+// CHECK-RWT: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[A]]
 // CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[A]]
 // CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]] jointly %[[V]]
 // CHECK-RWT: %[[DST:.*]] = bufferization.alloc_tensor(%[[D]])
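With 64-bit overhead types, the old lowering additionally needed an i64-to-index cast, which the new op makes unnecessary. A minimal before/after sketch using the encoding from the CHECK lines above (function and value names are illustrative):

#SparseVector64 = #sparse_tensor.encoding<{
  dimLevelType = [ "compressed" ], pointerBitWidth = 64, indexBitWidth = 64
}>

func.func @nnz_1d(%a: tensor<?xf32, #SparseVector64>) -> index {
  // Old (removed by this commit): load a 64-bit value out of the indices
  // buffer and cast it, i.e.
  //   %i0   = sparse_tensor.indices %a {dimension = 0 : index}
  //           : tensor<?xf32, #SparseVector64> to memref<?xi64>
  //   %nnzr = memref.load %i0[%c1] : memref<?xi64>
  //   %nnz  = arith.index_cast %nnzr : i64 to index
  // New: a single query op, independent of the overhead bit widths.
  %nnz = sparse_tensor.number_of_entries %a : tensor<?xf32, #SparseVector64>
  return %nnz : index
}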
