@@ -65,8 +65,9 @@ module {
6565 //
6666 // Kernel that uses index in the index notation (conjunction).
6767 //
68- func.func @sparse_index_1d_conj (%arga: tensor <8 xi64 , #SparseVector >,
69- %out: tensor <8 xi64 >) -> tensor <8 xi64 > {
68+ func.func @sparse_index_1d_conj (%arga: tensor <8 xi64 , #SparseVector >)
69+ -> tensor <8 xi64 > {
70+ %out = tensor.empty () : tensor <8 xi64 >
7071 %r = linalg.generic #trait_1d
7172 ins (%arga: tensor <8 xi64 , #SparseVector >)
7273 outs (%out: tensor <8 xi64 >) {
@@ -82,8 +83,9 @@ module {
8283 //
8384 // Kernel that uses index in the index notation (disjunction).
8485 //
85- func.func @sparse_index_1d_disj (%arga: tensor <8 xi64 , #SparseVector >,
86- %out: tensor <8 xi64 >) -> tensor <8 xi64 > {
86+ func.func @sparse_index_1d_disj (%arga: tensor <8 xi64 , #SparseVector >)
87+ -> tensor <8 xi64 > {
88+ %out = tensor.empty () : tensor <8 xi64 >
8789 %r = linalg.generic #trait_1d
8890 ins (%arga: tensor <8 xi64 , #SparseVector >)
8991 outs (%out: tensor <8 xi64 >) {
@@ -99,8 +101,9 @@ module {
99101 //
100102 // Kernel that uses indices in the index notation (conjunction).
101103 //
102- func.func @sparse_index_2d_conj (%arga: tensor <3 x4 xi64 , #SparseMatrix >,
103- %out: tensor <3 x4 xi64 >) -> tensor <3 x4 xi64 > {
104+ func.func @sparse_index_2d_conj (%arga: tensor <3 x4 xi64 , #SparseMatrix >)
105+ -> tensor <3 x4 xi64 > {
106+ %out = tensor.empty () : tensor <3 x4 xi64 >
104107 %r = linalg.generic #trait_2d
105108 ins (%arga: tensor <3 x4 xi64 , #SparseMatrix >)
106109 outs (%out: tensor <3 x4 xi64 >) {
@@ -119,8 +122,9 @@ module {
119122 //
120123 // Kernel that uses indices in the index notation (disjunction).
121124 //
122- func.func @sparse_index_2d_disj (%arga: tensor <3 x4 xi64 , #SparseMatrix >,
123- %out: tensor <3 x4 xi64 >) -> tensor <3 x4 xi64 > {
125+ func.func @sparse_index_2d_disj (%arga: tensor <3 x4 xi64 , #SparseMatrix >)
126+ -> tensor <3 x4 xi64 > {
127+ %out = tensor.empty () : tensor <3 x4 xi64 >
124128 %r = linalg.generic #trait_2d
125129 ins (%arga: tensor <3 x4 xi64 , #SparseMatrix >)
126130 outs (%out: tensor <3 x4 xi64 >) {
@@ -161,20 +165,15 @@ module {
161165 [ 1 , 1 , 3 , 4 ] ]> : tensor <3 x4 xi64 >
162166 %dm = sparse_tensor.convert %m2 : tensor <3 x4 xi64 > to tensor <3 x4 xi64 , #SparseMatrix >
163167
164- // Setup out tensors.
165- // Note: Constants bufferize to read-only buffers.
166- %init_8 = tensor.empty () : tensor <8 xi64 >
167- %init_3_4 = tensor.empty () : tensor <3 x4 xi64 >
168-
169168 // Call the kernels.
170- %0 = call @sparse_index_1d_conj (%sv , %init_8 ) : (tensor <8 xi64 , #SparseVector >, tensor < 8 x i64 >) -> tensor <8 xi64 >
171- %1 = call @sparse_index_1d_disj (%sv , %init_8 ) : (tensor <8 xi64 , #SparseVector >, tensor < 8 x i64 >) -> tensor <8 xi64 >
172- %2 = call @sparse_index_1d_conj (%dv , %init_8 ) : (tensor <8 xi64 , #SparseVector >, tensor < 8 x i64 >) -> tensor <8 xi64 >
173- %3 = call @sparse_index_1d_disj (%dv , %init_8 ) : (tensor <8 xi64 , #SparseVector >, tensor < 8 x i64 >) -> tensor <8 xi64 >
174- %4 = call @sparse_index_2d_conj (%sm , %init_3_4 ) : (tensor <3 x4 xi64 , #SparseMatrix >, tensor < 3 x 4 x i64 >) -> tensor <3 x4 xi64 >
175- %5 = call @sparse_index_2d_disj (%sm , %init_3_4 ) : (tensor <3 x4 xi64 , #SparseMatrix >, tensor < 3 x 4 x i64 >) -> tensor <3 x4 xi64 >
176- %6 = call @sparse_index_2d_conj (%dm , %init_3_4 ) : (tensor <3 x4 xi64 , #SparseMatrix >, tensor < 3 x 4 x i64 >) -> tensor <3 x4 xi64 >
177- %7 = call @sparse_index_2d_disj (%dm , %init_3_4 ) : (tensor <3 x4 xi64 , #SparseMatrix >, tensor < 3 x 4 x i64 >) -> tensor <3 x4 xi64 >
169+ %0 = call @sparse_index_1d_conj (%sv ) : (tensor <8 xi64 , #SparseVector >) -> tensor <8 xi64 >
170+ %1 = call @sparse_index_1d_disj (%sv ) : (tensor <8 xi64 , #SparseVector >) -> tensor <8 xi64 >
171+ %2 = call @sparse_index_1d_conj (%dv ) : (tensor <8 xi64 , #SparseVector >) -> tensor <8 xi64 >
172+ %3 = call @sparse_index_1d_disj (%dv ) : (tensor <8 xi64 , #SparseVector >) -> tensor <8 xi64 >
173+ %4 = call @sparse_index_2d_conj (%sm ) : (tensor <3 x4 xi64 , #SparseMatrix >) -> tensor <3 x4 xi64 >
174+ %5 = call @sparse_index_2d_disj (%sm ) : (tensor <3 x4 xi64 , #SparseMatrix >) -> tensor <3 x4 xi64 >
175+ %6 = call @sparse_index_2d_conj (%dm ) : (tensor <3 x4 xi64 , #SparseMatrix >) -> tensor <3 x4 xi64 >
176+ %7 = call @sparse_index_2d_disj (%dm ) : (tensor <3 x4 xi64 , #SparseMatrix >) -> tensor <3 x4 xi64 >
178177
179178 //
180179 // Verify result.
0 commit comments