@@ -3,26 +3,6 @@ module {
33 // Define a collection of kernel operation definitions
44 kernel.defn_collection {
55
6- // GEMM operation definition with linalg.generic representation: result(i,j) = C(i,j) + sum_k A(i,k) * B(k,j)
7- kernel.defn @simple_gemm_linalg (%A: tensor <?x?xf32 >, %B: tensor <?x?xf32 >, %C: tensor <?x?xf32 >) -> tensor <?x?xf32 > {
8- // Implementation using linalg.generic; %A and %B are passed as ins, %C as the outs operand
9- %result = linalg.generic {
10- indexing_maps = [
11- affine_map <(i , j , k ) -> (i , k )>, // A(i,k)
12- affine_map <(i , j , k ) -> (k , j )>, // B(k,j)
13- affine_map <(i , j , k ) -> (i , j )> // C(i,j)
14- ],
15- iterator_types = [" parallel" , " parallel" , " reduction" ] // NOTE(review): each string has a leading space; linalg expects "parallel"/"reduction" - presumably an extraction artifact, confirm
16- } ins (%A , %B : tensor <?x?xf32 >, tensor <?x?xf32 >)
17- outs (%C : tensor <?x?xf32 >) {
18- ^bb0 (%a: f32 , %b: f32 , %c: f32 ): // one scalar per operand, in ins-then-outs order: %a from A, %b from B, %c from C
19- %product = arith.mulf %a , %b : f32
20- %result = arith.addf %product , %c : f32 // NOTE(review): region-local value reuses the name %result that the enclosing linalg.generic also binds
21- linalg.yield %result : f32 // yields the region-local sum, not the tensor result
22- } -> tensor <?x?xf32 >
23- kernel.yield %result : tensor <?x?xf32 > // here %result is the tensor produced by linalg.generic
24- }
25-
266 // GEMM operation definition with arbitrary code implementation
277 kernel.defn @gemm (%A: tensor <?x?xf32 >, %B: tensor <?x?xf32 >, %C: tensor <?x?xf32 >) {
288 // This could include arbitrary code to implement the GEMM operation
@@ -89,6 +69,27 @@ module {
8969 } -> tensor <?x?x?xf32 >
9070 kernel.yield
9171 }
72+
73+ // GEMM kernel definition as a linalg.generic contraction: result(i,j) = C(i,j) + sum_k A(i,k) * B(k,j)
74+ kernel.defn @simple_gemm_linalg(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
75+   // %A and %B are the linalg inputs (ins); %C is the output operand (outs) whose elements seed the accumulation
76+   %result = linalg.generic {
77+     indexing_maps = [
78+       affine_map<(i, j, k) -> (i, k)>, // A(i,k)
79+       affine_map<(i, j, k) -> (k, j)>, // B(k,j)
80+       affine_map<(i, j, k) -> (i, j)>  // C(i,j)
81+     ],
82+     iterator_types = ["parallel", "parallel", "reduction"] // i and j are parallel, k is the reduced dimension
83+   } ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
84+     outs(%C : tensor<?x?xf32>) {
85+   ^bb0(%a: f32, %b: f32, %c: f32): // %a, %b: elements of A and B; %c: current element of the output operand
86+     %product = arith.mulf %a, %b : f32
87+     %sum = arith.addf %product, %c : f32 // named %sum so it does not reuse the outer %result name
88+     linalg.yield %sum : f32
89+   } -> tensor<?x?xf32>
90+   kernel.yield %result : tensor<?x?xf32>
91+ }
92+
9293
9394 // Index of maximum absolute value operation definition with arbitrary code
9495 kernel.defn @iamax (%X: tensor <?xf32 >) -> tensor <i32 > {
@@ -195,26 +196,6 @@ module {
195196 kernel.yield %result : tensor <f32 >
196197 }
197198
198- // Func that implements the simple GEMM inline: result(i,j) = C(i,j) + sum_k A(i,k) * B(k,j)
199- func.func @simple_gemm (%A: tensor <?x?xf32 >, %B: tensor <?x?xf32 >, %C: tensor <?x?xf32 >) -> tensor <?x?xf32 > {
200- // Implementation using linalg.generic; %A and %B are passed as ins, %C as the outs operand
201- %result = linalg.generic {
202- indexing_maps = [
203- affine_map <(i , j , k ) -> (i , k )>, // A(i,k)
204- affine_map <(i , j , k ) -> (k , j )>, // B(k,j)
205- affine_map <(i , j , k ) -> (i , j )> // C(i,j)
206- ],
207- iterator_types = [" parallel" , " parallel" , " reduction" ] // NOTE(review): each string has a leading space; linalg expects "parallel"/"reduction" - presumably an extraction artifact, confirm
208- } ins (%A , %B : tensor <?x?xf32 >, tensor <?x?xf32 >)
209- outs (%C : tensor <?x?xf32 >) {
210- ^bb0 (%a: f32 , %b: f32 , %c: f32 ): // one scalar per operand, in ins-then-outs order: %a from A, %b from B, %c from C
211- %product = arith.mulf %a , %b : f32
212- %result = arith.addf %product , %c : f32 // NOTE(review): region-local value reuses the name %result that the enclosing linalg.generic also binds
213- linalg.yield %result : f32 // yields the region-local sum, not the tensor result
214- } -> tensor <?x?xf32 >
215- return %result : tensor <?x?xf32 > // here %result is the tensor produced by linalg.generic
216- }
217-
218199 // Mathematical definitions (commented, for reference)
219200 // kernel.defn @gemm(...) {
220201 // C(i,j) += alpha * A(i,k) * B(k,j);
@@ -236,4 +217,25 @@ module {
236217 // result = sum_i |x_i|;
237218 // }
238219 }
220+
221+ // Standalone function form of the simple GEMM: returns C(i,j) + sum_k A(i,k) * B(k,j) via linalg.generic
222+ func.func @simple_gemm(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
223+   // %A and %B are the linalg inputs (ins); %C is the output operand (outs) whose elements seed the accumulation
224+   %result = linalg.generic {
225+     indexing_maps = [
226+       affine_map<(i, j, k) -> (i, k)>, // A(i,k)
227+       affine_map<(i, j, k) -> (k, j)>, // B(k,j)
228+       affine_map<(i, j, k) -> (i, j)>  // C(i,j)
229+     ],
230+     iterator_types = ["parallel", "parallel", "reduction"] // i and j are parallel, k is the reduced dimension
231+   } ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
232+     outs(%C : tensor<?x?xf32>) {
233+   ^bb0(%a: f32, %b: f32, %c: f32): // %a, %b: elements of A and B; %c: current element of the output operand
234+     %product = arith.mulf %a, %b : f32
235+     %sum = arith.addf %product, %c : f32 // named %sum so it does not reuse the outer %result name
236+     linalg.yield %sum : f32
237+   } -> tensor<?x?xf32>
238+   return %result : tensor<?x?xf32>
239+ }
240+
239241}
0 commit comments