@@ -70,3 +70,39 @@ func.func @update_notinplace(%argb: tensor<10xf32>, %arga: tensor<10xf32, #SV>)
   } -> tensor<10xf32>
   return %0, %argb : tensor<10xf32>, tensor<10xf32>
 }
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#sparse = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64 }>
+
+// linalg.generic with sparse tensors does not necessarily bufferize to
+// element-wise access into the underlying sparse data structures.
+
+// CHECK-LABEL: func @sparse_non_elementwise(
+func.func @sparse_non_elementwise(%arg0: tensor<64x64xf32, #sparse>, %arg1: tensor<64x64xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[alloc0:.*]] = bufferization.alloc_tensor()
+  // CHECK: %[[alloc1:.*]] = bufferization.alloc_tensor()
+  %0 = bufferization.alloc_tensor() : tensor<64x64xf32>
+  // CHECK: %[[generic0:.*]] = linalg.generic {{.*}} outs(%[[alloc1]] : {{.*}})
+  %1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%0 : tensor<64x64xf32>) {
+  ^bb0(%out: f32):
+    linalg.yield %cst : f32
+  } -> tensor<64x64xf32>
+  // CHECK: linalg.generic {{.*}} outs(%[[generic0]] : {{.*}})
+  %2 = linalg.generic {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg2, %arg2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%1 : tensor<64x64xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %4 = arith.mulf %in, %in_0 : f32
+    %5 = arith.addf %out, %4 : f32
+    linalg.yield %5 : f32
+  } -> tensor<64x64xf32>
+  // CHECK: linalg.generic {{.*}} outs(%[[alloc0]] : {{.*}})
+  %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %2 : tensor<64x64xf32, #sparse>, tensor<64x64xf32>) outs(%0 : tensor<64x64xf32>) attrs = {sorted = true} {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %4 = arith.mulf %in, %in_0 : f32
+    linalg.yield %4 : f32
+  } -> tensor<64x64xf32>
+  return %3 : tensor<64x64xf32>
+}
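The comment in the new test points at the distinction between kernels that do and do not map onto an element-wise walk of the stored sparse values. For contrast, a purely element-wise kernel over the same kind of encoding could look like the sketch below; this is a hypothetical example, not part of the commit, and the names @sparse_elementwise_add and #sparse_ew are illustrative only. Here all operands use the identity indexing map and there is no reduction dimension, so every iteration reads exactly one element of each input and writes one element of the output.

#map_ew = affine_map<(d0, d1) -> (d0, d1)>
#sparse_ew = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>

// Hypothetical element-wise counterpart to @sparse_non_elementwise above.
func.func @sparse_elementwise_add(%a: tensor<64x64xf32, #sparse_ew>, %b: tensor<64x64xf32>) -> tensor<64x64xf32> {
  %0 = tensor.empty() : tensor<64x64xf32>
  // Identity maps on all three operands: one element per operand per iteration.
  %1 = linalg.generic {indexing_maps = [#map_ew, #map_ew, #map_ew], iterator_types = ["parallel", "parallel"]} ins(%a, %b : tensor<64x64xf32, #sparse_ew>, tensor<64x64xf32>) outs(%0 : tensor<64x64xf32>) {
  ^bb0(%in: f32, %in_0: f32, %out: f32):
    %2 = arith.addf %in, %in_0 : f32
    linalg.yield %2 : f32
  } -> tensor<64x64xf32>
  return %1 : tensor<64x64xf32>
}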