Skip to content

Commit d765bb9

Browse files
committed
Partial changes: separate files for the kernel library and the test input
1 parent 7f9d00f commit d765bb9

File tree

6 files changed

+312
-48
lines changed

6 files changed

+312
-48
lines changed

generic_solver/cublas_example.mlir

Lines changed: 42 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,6 @@ module {
33
// Define a collection of kernel operation definitions
44
kernel.defn_collection {
55

6-
// GEMM operation definition with linalg.generic representation
7-
kernel.defn @simple_gemm_linalg(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
8-
// Implementation using linalg.generic
9-
%result = linalg.generic {
10-
indexing_maps = [
11-
affine_map<(i, j, k) -> (i, k)>, // A(i,k)
12-
affine_map<(i, j, k) -> (k, j)>, // B(k,j)
13-
affine_map<(i, j, k) -> (i, j)> // C(i,j)
14-
],
15-
iterator_types = ["parallel", "parallel", "reduction"]
16-
} ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
17-
outs(%C : tensor<?x?xf32>) {
18-
^bb0(%a: f32, %b: f32, %c: f32):
19-
%product = arith.mulf %a, %b : f32
20-
%result = arith.addf %product, %c : f32
21-
linalg.yield %result : f32
22-
} -> tensor<?x?xf32>
23-
kernel.yield %result : tensor<?x?xf32>
24-
}
25-
266
// GEMM operation definition with arbitrary code implementation
277
kernel.defn @gemm(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) {
288
// This could include arbitrary code to implement the GEMM operation
@@ -89,6 +69,27 @@ module {
8969
} -> tensor<?x?x?xf32>
9070
kernel.yield
9171
}
72+
73+
// GEMM operation definition with linalg.generic representation
74+
kernel.defn @simple_gemm_linalg(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
75+
// Implementation using linalg.generic
76+
%result = linalg.generic {
77+
indexing_maps = [
78+
affine_map<(i, j, k) -> (i, k)>, // A(i,k)
79+
affine_map<(i, j, k) -> (k, j)>, // B(k,j)
80+
affine_map<(i, j, k) -> (i, j)> // C(i,j)
81+
],
82+
iterator_types = ["parallel", "parallel", "reduction"]
83+
} ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
84+
outs(%C : tensor<?x?xf32>) {
85+
^bb0(%a: f32, %b: f32, %c: f32):
86+
%product = arith.mulf %a, %b : f32
87+
%result = arith.addf %product, %c : f32
88+
linalg.yield %result : f32
89+
} -> tensor<?x?xf32>
90+
kernel.yield %result : tensor<?x?xf32>
91+
}
92+
9293

9394
// Index of maximum absolute value operation definition with arbitrary code
9495
kernel.defn @iamax(%X: tensor<?xf32>) -> tensor<i32> {
@@ -195,26 +196,6 @@ module {
195196
kernel.yield %result : tensor<f32>
196197
}
197198

198-
// Function that uses the simple GEMM
199-
func.func @simple_gemm(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
200-
// Implementation using linalg.generic
201-
%result = linalg.generic {
202-
indexing_maps = [
203-
affine_map<(i, j, k) -> (i, k)>, // A(i,k)
204-
affine_map<(i, j, k) -> (k, j)>, // B(k,j)
205-
affine_map<(i, j, k) -> (i, j)> // C(i,j)
206-
],
207-
iterator_types = ["parallel", "parallel", "reduction"]
208-
} ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
209-
outs(%C : tensor<?x?xf32>) {
210-
^bb0(%a: f32, %b: f32, %c: f32):
211-
%product = arith.mulf %a, %b : f32
212-
%result = arith.addf %product, %c : f32
213-
linalg.yield %result : f32
214-
} -> tensor<?x?xf32>
215-
return %result : tensor<?x?xf32>
216-
}
217-
218199
// Mathematical definitions (commented, for reference)
219200
// kernel.defn @gemm(...) {
220201
// C(i,j) += alpha * A(i,k) * B(k,j);
@@ -236,4 +217,25 @@ module {
236217
// result = sum_i |x_i|;
237218
// }
238219
}
220+
221+
// Function that uses the simple GEMM
222+
func.func @simple_gemm(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
223+
// Implementation using linalg.generic
224+
%result = linalg.generic {
225+
indexing_maps = [
226+
affine_map<(i, j, k) -> (i, k)>, // A(i,k)
227+
affine_map<(i, j, k) -> (k, j)>, // B(k,j)
228+
affine_map<(i, j, k) -> (i, j)> // C(i,j)
229+
],
230+
iterator_types = ["parallel", "parallel", "reduction"]
231+
} ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
232+
outs(%C : tensor<?x?xf32>) {
233+
^bb0(%a: f32, %b: f32, %c: f32):
234+
%product = arith.mulf %a, %b : f32
235+
%result = arith.addf %product, %c : f32
236+
linalg.yield %result : f32
237+
} -> tensor<?x?xf32>
238+
return %result : tensor<?x?xf32>
239+
}
240+
239241
}
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
// Kernel Library - Reusable kernel definitions
2+
// This file contains a collection of kernel definitions that can be loaded
3+
// by the linalg-to-kernel pass and applied to different MLIR modules.
4+
5+
module {
6+
// Collection of kernel operation definitions
7+
kernel.defn_collection {
8+
9+
// Simple GEMM operation definition with linalg.generic representation
10+
kernel.defn @simple_gemm_linalg(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
11+
// Simple matrix multiplication: C = A * B + C
12+
%result = linalg.generic {
13+
indexing_maps = [
14+
affine_map<(d0, d1, d2) -> (d0, d2)>,
15+
affine_map<(d0, d1, d2) -> (d2, d1)>,
16+
affine_map<(d0, d1, d2) -> (d0, d1)>
17+
],
18+
iterator_types = ["parallel", "parallel", "reduction"]
19+
} ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
20+
outs(%C : tensor<?x?xf32>) {
21+
^bb0(%a: f32, %b: f32, %c: f32):
22+
%product = arith.mulf %a, %b : f32
23+
%result = arith.addf %product, %c : f32
24+
linalg.yield %result : f32
25+
} -> tensor<?x?xf32>
26+
kernel.yield %result : tensor<?x?xf32>
27+
}
28+
29+
// Scaled GEMM operation definition with alpha and beta coefficients
30+
kernel.defn @gemm_linalg(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
31+
%alpha = arith.constant 1.0 : f32
32+
%beta = arith.constant 0.0 : f32
33+
34+
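// GEMM with scaling, intended as C = alpha * A * B + beta * C. NOTE(review): the body below computes c = alpha * a * b + beta * c on every reduction step, so beta rescales the running accumulator each iteration; this matches the formula only when beta == 1 - confirm the intended semantics.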
35+
%result = linalg.generic {
36+
indexing_maps = [
37+
affine_map<(d0, d1, d2) -> (d0, d2)>,
38+
affine_map<(d0, d1, d2) -> (d2, d1)>,
39+
affine_map<(d0, d1, d2) -> (d0, d1)>
40+
],
41+
iterator_types = ["parallel", "parallel", "reduction"]
42+
} ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
43+
outs(%C : tensor<?x?xf32>) {
44+
^bb0(%a: f32, %b: f32, %c: f32):
45+
%product = arith.mulf %a, %b : f32
46+
%scaled = arith.mulf %product, %alpha : f32
47+
%scaled_c = arith.mulf %c, %beta : f32
48+
%result = arith.addf %scaled, %scaled_c : f32
49+
linalg.yield %result : f32
50+
} -> tensor<?x?xf32>
51+
kernel.yield %result : tensor<?x?xf32>
52+
}
53+
54+
// Sum of absolute values operation (ASUM)
55+
kernel.defn @asum_linalg(%X: tensor<?xf32>) -> tensor<f32> {
56+
%c0 = arith.constant 0.0 : f32
57+
%init = tensor.empty() : tensor<f32>
58+
%fill = linalg.fill ins(%c0 : f32) outs(%init : tensor<f32>) -> tensor<f32>
59+
60+
// Sum of absolute values: result = sum_i |x_i|
61+
%result = linalg.generic {
62+
indexing_maps = [
63+
affine_map<(d0) -> (d0)>,
64+
affine_map<(d0) -> ()>
65+
],
66+
iterator_types = ["reduction"]
67+
} ins(%X : tensor<?xf32>)
68+
outs(%fill : tensor<f32>) {
69+
^bb0(%in: f32, %out: f32):
70+
%abs_val = math.absf %in : f32
71+
%result = arith.addf %abs_val, %out : f32
72+
linalg.yield %result : f32
73+
} -> tensor<f32>
74+
kernel.yield %result : tensor<f32>
75+
}
76+
77+
// Vector dot product
78+
kernel.defn @dot_linalg(%X: tensor<?xf32>, %Y: tensor<?xf32>) -> tensor<f32> {
79+
%c0 = arith.constant 0.0 : f32
80+
%init = tensor.empty() : tensor<f32>
81+
%fill = linalg.fill ins(%c0 : f32) outs(%init : tensor<f32>) -> tensor<f32>
82+
83+
// Dot product: result = sum_i x_i * y_i
84+
%result = linalg.generic {
85+
indexing_maps = [
86+
affine_map<(d0) -> (d0)>,
87+
affine_map<(d0) -> (d0)>,
88+
affine_map<(d0) -> ()>
89+
],
90+
iterator_types = ["reduction"]
91+
} ins(%X, %Y : tensor<?xf32>, tensor<?xf32>)
92+
outs(%fill : tensor<f32>) {
93+
^bb0(%x: f32, %y: f32, %out: f32):
94+
%product = arith.mulf %x, %y : f32
95+
%result = arith.addf %product, %out : f32
96+
linalg.yield %result : f32
97+
} -> tensor<f32>
98+
kernel.yield %result : tensor<f32>
99+
}
100+
}
101+
}
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// Test input file - contains linalg.generic operations to be matched
2+
// This file does NOT contain kernel.defn_collection - those will be loaded externally
3+
4+
module {
5+
// Function that performs simple matrix multiplication
6+
func.func @simple_gemm(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
7+
// This linalg.generic should match @simple_gemm_linalg from kernel_library.mlir
8+
%result = linalg.generic {
9+
indexing_maps = [
10+
affine_map<(d0, d1, d2) -> (d0, d2)>,
11+
affine_map<(d0, d1, d2) -> (d2, d1)>,
12+
affine_map<(d0, d1, d2) -> (d0, d1)>
13+
],
14+
iterator_types = ["parallel", "parallel", "reduction"]
15+
} ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
16+
outs(%C : tensor<?x?xf32>) {
17+
^bb0(%a: f32, %b: f32, %c: f32):
18+
%product = arith.mulf %a, %b : f32
19+
%result = arith.addf %product, %c : f32
20+
linalg.yield %result : f32
21+
} -> tensor<?x?xf32>
22+
return %result : tensor<?x?xf32>
23+
}
24+
25+
// Function that computes sum of absolute values
26+
func.func @compute_asum(%X: tensor<?xf32>) -> tensor<f32> {
27+
%c0 = arith.constant 0.0 : f32
28+
%init = tensor.empty() : tensor<f32>
29+
%fill = linalg.fill ins(%c0 : f32) outs(%init : tensor<f32>) -> tensor<f32>
30+
31+
// This linalg.generic should match @asum_linalg from kernel_library.mlir
32+
%result = linalg.generic {
33+
indexing_maps = [
34+
affine_map<(d0) -> (d0)>,
35+
affine_map<(d0) -> ()>
36+
],
37+
iterator_types = ["reduction"]
38+
} ins(%X : tensor<?xf32>)
39+
outs(%fill : tensor<f32>) {
40+
^bb0(%in: f32, %out: f32):
41+
%abs_val = math.absf %in : f32
42+
%result = arith.addf %abs_val, %out : f32
43+
linalg.yield %result : f32
44+
} -> tensor<f32>
45+
return %result : tensor<f32>
46+
}
47+
48+
// Function that computes dot product
49+
func.func @compute_dot(%X: tensor<?xf32>, %Y: tensor<?xf32>) -> tensor<f32> {
50+
%c0 = arith.constant 0.0 : f32
51+
%init = tensor.empty() : tensor<f32>
52+
%fill = linalg.fill ins(%c0 : f32) outs(%init : tensor<f32>) -> tensor<f32>
53+
54+
// This linalg.generic should match @dot_linalg from kernel_library.mlir
55+
%result = linalg.generic {
56+
indexing_maps = [
57+
affine_map<(d0) -> (d0)>,
58+
affine_map<(d0) -> (d0)>,
59+
affine_map<(d0) -> ()>
60+
],
61+
iterator_types = ["reduction"]
62+
} ins(%X, %Y : tensor<?xf32>, tensor<?xf32>)
63+
outs(%fill : tensor<f32>) {
64+
^bb0(%x: f32, %y: f32, %out: f32):
65+
%product = arith.mulf %x, %y : f32
66+
%result = arith.addf %product, %out : f32
67+
linalg.yield %result : f32
68+
} -> tensor<f32>
69+
return %result : tensor<f32>
70+
}
71+
}

include/polygeist/Passes/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ createGpuSerializeToHsacoPass(StringRef arch, StringRef features,
7474
std::string rocmPath, bool outputIntermediate);
7575

7676
std::unique_ptr<Pass> createLinalgToKernelPass();
77+
std::unique_ptr<Pass> createLinalgToKernelPass(const std::string& kernelLibraryPath);
7778

7879
void registerGpuSerializeToCubinPass();
7980
void registerGpuSerializeToHsacoPass();

include/polygeist/Passes/Passes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,12 @@ def LinalgToKernel : Pass<"linalg-to-kernel", "mlir::ModuleOp"> {
294294
"tensor::TensorDialect",
295295
"arith::ArithDialect",
296296
];
297+
let options = [
298+
Option<"kernelLibraryPath", "kernel-library-path", "std::string",
299+
/*default=*/"\"\"",
300+
"Path to external MLIR file containing kernel.defn_collection definitions. "
301+
"If empty, looks for kernel.defn_collection in the input module.">
302+
];
297303
}
298304

299305
def ConvertPolygeistToLLVM : Pass<"convert-polygeist-to-llvm", "mlir::ModuleOp"> {

0 commit comments

Comments
 (0)