Skip to content

Commit f388a3a

Browse files
authored
[mlir][sparse] update doc and examples of the [dis]assemble operations (#88213)
The doc and examples of the [dis]assemble operations did not reflect all the recent changes on order of the operands. Also clarified some of the text.
1 parent 3d46856 commit f388a3a

File tree

4 files changed

+61
-58
lines changed

4 files changed

+61
-58
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 43 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -58,46 +58,44 @@ def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]>,
5858
Arguments<(ins Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels,
5959
TensorOf<[AnyType]>:$values)>,
6060
Results<(outs AnySparseTensor: $result)> {
61-
let summary = "Returns a sparse tensor assembled from the given values and levels";
61+
let summary = "Returns a sparse tensor assembled from the given levels and values";
6262

6363
let description = [{
64-
Assembles the values and per-level coordinate or position arrays into a sparse tensor.
65-
The order and types of provided levels must be consistent with the actual storage
66-
layout of the returned sparse tensor described below.
64+
Assembles the per-level position and coordinate arrays together with
65+
the values arrays into a sparse tensor. The order and types of the
66+
provided levels must be consistent with the actual storage layout of
67+
the returned sparse tensor described below.
6768

68-
- `values : tensor<? x V>`
69-
supplies the value for each stored element in the sparse tensor.
7069
- `levels: [tensor<? x iType>, ...]`
71-
each supplies the sparse tensor coordinates scheme in the sparse tensor for
72-
the corresponding level as specified by `sparse_tensor::StorageLayout`.
73-
74-
This operation can be used to assemble a sparse tensor from external
75-
sources; e.g., when passing two numpy arrays from Python.
76-
77-
Disclaimer: It is the user's responsibility to provide input that can be
78-
correctly interpreted by the sparsifier, which does not perform
79-
any sanity test during runtime to verify data integrity.
70+
supplies the sparse tensor position and coordinate arrays
71+
of the sparse tensor for the corresponding level as specified by
72+
`sparse_tensor::StorageLayout`.
73+
- `values : tensor<? x V>`
74+
supplies the values array for the stored elements in the sparse tensor.
8075

81-
TODO: The returned tensor is allowed (in principle) to have non-identity
82-
dimOrdering/higherOrdering mappings. However, the current implementation
83-
does not yet support them.
76+
This operation can be used to assemble a sparse tensor from an
77+
external source; e.g., by passing numpy arrays from Python. It
78+
is the user's responsibility to provide input that can be correctly
79+
interpreted by the sparsifier, which does not perform any sanity
80+
test to verify data integrity.
8481

8582
Example:
8683

8784
```mlir
88-
%values = arith.constant dense<[ 1.1, 2.2, 3.3 ]> : tensor<3xf64>
89-
%coordinates = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>
90-
%st = sparse_tensor.assemble %values, %coordinates
91-
: tensor<3xf64>, tensor<3x2xindex> to tensor<3x4xf64, #COO>
85+
%pos = arith.constant dense<[0, 3]> : tensor<2xindex>
86+
%index = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>
87+
%values = arith.constant dense<[ 1.1, 2.2, 3.3 ]> : tensor<3xf64>
88+
%s = sparse_tensor.assemble (%pos, %index), %values
89+
: (tensor<2xindex>, tensor<3x2xindex>), tensor<3xf64> to tensor<3x4xf64, #COO>
9290
// yields COO format |1.1, 0.0, 0.0, 0.0|
9391
// of 3x4 matrix |0.0, 0.0, 2.2, 3.3|
9492
// |0.0, 0.0, 0.0, 0.0|
9593
```
9694
}];
9795

9896
let assemblyFormat =
99-
"` ` `(` $levels `)` `,` $values attr-dict"
100-
" `:` `(` type($levels) `)` `,` type($values) `to` type($result)";
97+
"` ` `(` $levels `)` `,` $values attr-dict `:`"
98+
" `(` type($levels) `)` `,` type($values) `to` type($result)";
10199

102100
let hasVerifier = 1;
103101
}
@@ -110,48 +108,47 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
110108
TensorOf<[AnyType]>:$ret_values,
111109
Variadic<AnyIndexingScalarLike>:$lvl_lens,
112110
AnyIndexingScalarLike:$val_len)> {
113-
let summary = "Returns the (values, coordinates) pair disassembled from the input tensor";
111+
let summary = "Copies the levels and values of the given sparse tensor";
114112

115113
let description = [{
116114
The disassemble operation is the inverse of `sparse_tensor::assemble`.
117-
It returns the values and per-level position and coordinate array to the
118-
user from the sparse tensor along with the actual length of the memory used
119-
in each returned buffer. This operation can be used for returning an
120-
disassembled MLIR sparse tensor to frontend; e.g., returning two numpy arrays
121-
to Python.
122-
123-
Disclaimer: It is the user's responsibility to allocate large enough buffers
124-
to hold the sparse tensor. The sparsifier simply copies each field
125-
of the sparse tensor into the user-supplied buffer without bound checking.
115+
It copies the per-level position and coordinate arrays together with
116+
the values array of the given sparse tensor into the user-supplied buffers
117+
along with the actual length of the memory used in each returned buffer.
126118

127-
TODO: the current implementation does not yet support non-identity mappings.
119+
This operation can be used for returning a disassembled MLIR sparse tensor;
120+
e.g., copying the sparse tensor contents into pre-allocated numpy arrays
121+
back to Python. It is the user's responsibility to allocate large enough
122+
buffers of the appropriate types to hold the sparse tensor contents.
123+
The sparsifier simply copies all fields of the sparse tensor into the
124+
user-supplied buffers without any sanity test to verify data integrity.
128125

129126
Example:
130127

131128
```mlir
132129
// input COO format |1.1, 0.0, 0.0, 0.0|
133130
// of 3x4 matrix |0.0, 0.0, 2.2, 3.3|
134131
// |0.0, 0.0, 0.0, 0.0|
135-
%v, %p, %c, %v_len, %p_len, %c_len =
136-
sparse_tensor.disassemble %sp : tensor<3x4xf64, #COO>
137-
out_lvls(%op, %oi) : tensor<2xindex>, tensor<3x2xindex>,
138-
out_vals(%od) : tensor<3xf64> ->
139-
tensor<3xf64>, (tensor<2xindex>, tensor<3x2xindex>), index, (index, index)
140-
// %v = arith.constant dense<[ 1.1, 2.2, 3.3 ]> : tensor<3xf64>
132+
%p, %c, %v, %p_len, %c_len, %v_len =
133+
sparse_tensor.disassemble %s : tensor<3x4xf64, #COO>
134+
out_lvls(%op, %oi : tensor<2xindex>, tensor<3x2xindex>)
135+
out_vals(%od : tensor<3xf64>) ->
136+
(tensor<2xindex>, tensor<3x2xindex>), tensor<3xf64>, (index, index), index
141137
// %p = arith.constant dense<[ 0, 3 ]> : tensor<2xindex>
142138
// %c = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>
143-
// %v_len = 3
139+
// %v = arith.constant dense<[ 1.1, 2.2, 3.3 ]> : tensor<3xf64>
144140
// %p_len = 2
145141
// %c_len = 6 (3x2)
142+
// %v_len = 3
146143
```
147144
}];
148145

149146
let assemblyFormat =
150-
"$tensor `:` type($tensor) "
147+
"$tensor attr-dict `:` type($tensor)"
151148
"`out_lvls` `(` $out_levels `:` type($out_levels) `)` "
152-
"`out_vals` `(` $out_values `:` type($out_values) `)` attr-dict"
153-
"`->` `(` type($ret_levels) `)` `,` type($ret_values) `,` "
154-
"`(` type($lvl_lens) `)` `,` type($val_len)";
149+
"`out_vals` `(` $out_values `:` type($out_values) `)` `->`"
150+
"`(` type($ret_levels) `)` `,` type($ret_values) `,` "
151+
"`(` type($lvl_lens) `)` `,` type($val_len)";
155152

156153
let hasVerifier = 1;
157154
}

mlir/test/Dialect/SparseTensor/invalid.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func.func @invalid_pack_mis_position(%values: tensor<6xf64>, %coordinates: tenso
6060

6161
func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>, %values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x1xi32>) {
6262
// expected-error@+1 {{input/output element-types don't match}}
63-
%rv, %rp, %rc, %vl, %pl, %cl = sparse_tensor.disassemble %sp : tensor<100xf32, #SparseVector>
63+
%rp, %rc, %rv, %pl, %cl, %vl = sparse_tensor.disassemble %sp : tensor<100xf32, #SparseVector>
6464
out_lvls(%pos, %coordinates : tensor<2xi32>, tensor<6x1xi32>)
6565
out_vals(%values : tensor<6xf64>)
6666
-> (tensor<2xi32>, tensor<6x1xi32>), tensor<6xf64>, (index, index), index
@@ -73,7 +73,7 @@ func.func @invalid_unpack_type(%sp: tensor<100xf32, #SparseVector>, %values: ten
7373

7474
func.func @invalid_unpack_type(%sp: tensor<100x2xf64, #SparseVector>, %values: tensor<6xf64>, %pos: tensor<2xi32>, %coordinates: tensor<6x3xi32>) {
7575
// expected-error@+1 {{input/output trailing COO level-ranks don't match}}
76-
%rv, %rp, %rc, %vl, %pl, %cl = sparse_tensor.disassemble %sp : tensor<100x2xf64, #SparseVector>
76+
%rp, %rc, %rv, %pl, %cl, %vl = sparse_tensor.disassemble %sp : tensor<100x2xf64, #SparseVector>
7777
out_lvls(%pos, %coordinates : tensor<2xi32>, tensor<6x3xi32> )
7878
out_vals(%values : tensor<6xf64>)
7979
-> (tensor<2xi32>, tensor<6x3xi32>), tensor<6xf64>, (index, index), index
@@ -86,7 +86,7 @@ func.func @invalid_unpack_type(%sp: tensor<100x2xf64, #SparseVector>, %values: t
8686

8787
func.func @invalid_unpack_mis_position(%sp: tensor<2x100xf64, #CSR>, %values: tensor<6xf64>, %coordinates: tensor<6xi32>) {
8888
// expected-error@+1 {{inconsistent number of fields between input/output}}
89-
%rv, %rc, %vl, %pl = sparse_tensor.disassemble %sp : tensor<2x100xf64, #CSR>
89+
%rc, %rv, %cl, %vl = sparse_tensor.disassemble %sp : tensor<2x100xf64, #CSR>
9090
out_lvls(%coordinates : tensor<6xi32>)
9191
out_vals(%values : tensor<6xf64>)
9292
-> (tensor<6xi32>), tensor<6xf64>, (index), index

mlir/test/Dialect/SparseTensor/roundtrip.mlir

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,21 @@ func.func @sparse_pack(%pos: tensor<2xi32>, %index: tensor<6x1xi32>, %data: tens
3333
#SparseVector = #sparse_tensor.encoding<{map = (d0) -> (d0 : compressed), crdWidth=32}>
3434
// CHECK-LABEL: func @sparse_unpack(
3535
// CHECK-SAME: %[[T:.*]]: tensor<100xf64, #
36-
// CHECK-SAME: %[[OD:.*]]: tensor<6xf64>
37-
// CHECK-SAME: %[[OP:.*]]: tensor<2xindex>
38-
// CHECK-SAME: %[[OI:.*]]: tensor<6x1xi32>
36+
// CHECK-SAME: %[[OP:.*]]: tensor<2xindex>,
37+
// CHECK-SAME: %[[OI:.*]]: tensor<6x1xi32>,
38+
// CHECK-SAME: %[[OD:.*]]: tensor<6xf64>)
3939
// CHECK: %[[P:.*]]:2, %[[D:.*]], %[[PL:.*]]:2, %[[DL:.*]] = sparse_tensor.disassemble %[[T]]
4040
// CHECK: return %[[P]]#0, %[[P]]#1, %[[D]]
4141
func.func @sparse_unpack(%sp : tensor<100xf64, #SparseVector>,
42-
%od : tensor<6xf64>,
4342
%op : tensor<2xindex>,
44-
%oi : tensor<6x1xi32>)
43+
%oi : tensor<6x1xi32>,
44+
%od : tensor<6xf64>)
4545
-> (tensor<2xindex>, tensor<6x1xi32>, tensor<6xf64>) {
46-
%rp, %ri, %rd, %vl, %pl, %cl = sparse_tensor.disassemble %sp : tensor<100xf64, #SparseVector>
46+
%rp, %ri, %d, %rpl, %ril, %dl = sparse_tensor.disassemble %sp : tensor<100xf64, #SparseVector>
4747
out_lvls(%op, %oi : tensor<2xindex>, tensor<6x1xi32>)
4848
out_vals(%od : tensor<6xf64>)
4949
-> (tensor<2xindex>, tensor<6x1xi32>), tensor<6xf64>, (index, index), index
50-
return %rp, %ri, %rd : tensor<2xindex>, tensor<6x1xi32>, tensor<6xf64>
50+
return %rp, %ri, %d : tensor<2xindex>, tensor<6x1xi32>, tensor<6xf64>
5151
}
5252

5353
// -----

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ module {
231231
%od = tensor.empty() : tensor<3xf64>
232232
%op = tensor.empty() : tensor<2xi32>
233233
%oi = tensor.empty() : tensor<3x2xi32>
234-
%p, %i, %d, %dl, %pl, %il = sparse_tensor.disassemble %s5 : tensor<10x10xf64, #SortedCOOI32>
234+
%p, %i, %d, %pl, %il, %dl = sparse_tensor.disassemble %s5 : tensor<10x10xf64, #SortedCOOI32>
235235
out_lvls(%op, %oi : tensor<2xi32>, tensor<3x2xi32>)
236236
out_vals(%od : tensor<3xf64>)
237237
-> (tensor<2xi32>, tensor<3x2xi32>), tensor<3xf64>, (i32, i64), index
@@ -244,10 +244,13 @@ module {
244244
%vi = vector.transfer_read %i[%c0, %c0], %i0 : tensor<3x2xi32>, vector<3x2xi32>
245245
vector.print %vi : vector<3x2xi32>
246246

247+
// CHECK-NEXT: 3
248+
vector.print %dl : index
249+
247250
%d_csr = tensor.empty() : tensor<4xf64>
248251
%p_csr = tensor.empty() : tensor<3xi32>
249252
%i_csr = tensor.empty() : tensor<3xi32>
250-
%rp_csr, %ri_csr, %rd_csr, %ld_csr, %lp_csr, %li_csr = sparse_tensor.disassemble %csr : tensor<2x2xf64, #CSR>
253+
%rp_csr, %ri_csr, %rd_csr, %lp_csr, %li_csr, %ld_csr = sparse_tensor.disassemble %csr : tensor<2x2xf64, #CSR>
251254
out_lvls(%p_csr, %i_csr : tensor<3xi32>, tensor<3xi32>)
252255
out_vals(%d_csr : tensor<4xf64>)
253256
-> (tensor<3xi32>, tensor<3xi32>), tensor<4xf64>, (i32, i64), index
@@ -256,6 +259,9 @@ module {
256259
%vd_csr = vector.transfer_read %rd_csr[%c0], %f0 : tensor<4xf64>, vector<3xf64>
257260
vector.print %vd_csr : vector<3xf64>
258261

262+
// CHECK-NEXT: 3
263+
vector.print %ld_csr : index
264+
259265
%bod = tensor.empty() : tensor<6xf64>
260266
%bop = tensor.empty() : tensor<4xindex>
261267
%boi = tensor.empty() : tensor<6x2xindex>

0 commit comments

Comments
 (0)