Skip to content

Commit 7f9d00f

Browse files
committed
Working pattern matching and replacement for linalg generics
1 parent 6a67379 commit 7f9d00f

File tree

4 files changed

+278
-28
lines changed

4 files changed

+278
-28
lines changed

generic_solver/cublas_example.mlir

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,27 @@
22
module {
33
// Define a collection of kernel operation definitions
44
kernel.defn_collection {
5+
6+
// GEMM operation definition with linalg.generic representation
7+
kernel.defn @simple_gemm_linalg(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
8+
// Implementation using linalg.generic
9+
%result = linalg.generic {
10+
indexing_maps = [
11+
affine_map<(i, j, k) -> (i, k)>, // A(i,k)
12+
affine_map<(i, j, k) -> (k, j)>, // B(k,j)
13+
affine_map<(i, j, k) -> (i, j)> // C(i,j)
14+
],
15+
iterator_types = ["parallel", "parallel", "reduction"]
16+
} ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
17+
outs(%C : tensor<?x?xf32>) {
18+
^bb0(%a: f32, %b: f32, %c: f32):
19+
%product = arith.mulf %a, %b : f32
20+
%result = arith.addf %product, %c : f32
21+
linalg.yield %result : f32
22+
} -> tensor<?x?xf32>
23+
kernel.yield %result : tensor<?x?xf32>
24+
}
25+
526
// GEMM operation definition with arbitrary code implementation
627
kernel.defn @gemm(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) {
728
// This could include arbitrary code to implement the GEMM operation
@@ -173,6 +194,26 @@ module {
173194
} -> tensor<f32>
174195
kernel.yield %result : tensor<f32>
175196
}
197+
198+
//Func that uses simple gemm
199+
func.func @simple_gemm(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
200+
// Implementation using linalg.generic
201+
%result = linalg.generic {
202+
indexing_maps = [
203+
affine_map<(i, j, k) -> (i, k)>, // A(i,k)
204+
affine_map<(i, j, k) -> (k, j)>, // B(k,j)
205+
affine_map<(i, j, k) -> (i, j)> // C(i,j)
206+
],
207+
iterator_types = ["parallel", "parallel", "reduction"]
208+
} ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
209+
outs(%C : tensor<?x?xf32>) {
210+
^bb0(%a: f32, %b: f32, %c: f32):
211+
%product = arith.mulf %a, %b : f32
212+
%result = arith.addf %product, %c : f32
213+
linalg.yield %result : f32
214+
} -> tensor<?x?xf32>
215+
return %result : tensor<?x?xf32>
216+
}
176217

177218
// Mathematical definitions (commented, for reference)
178219
// kernel.defn @gemm(...) {

include/polygeist/Kernel/KernelOps.td

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def Kernel_DefnOp : Kernel_Op<"defn", [
6868
OptionalAttr<DictArrayAttr>:$arg_attrs,
6969
OptionalAttr<DictArrayAttr>:$res_attrs
7070
);
71-
71+
7272
let regions = (region AnyRegion:$body);
7373

7474
let builders = [OpBuilder<(ins
@@ -99,6 +99,83 @@ def Kernel_DefnOp : Kernel_Op<"defn", [
9999
}];
100100
}
101101

102+
//===----------------------------------------------------------------------===//
103+
// LaunchOp
104+
//===----------------------------------------------------------------------===//
105+
106+
def Kernel_LaunchOp : Kernel_Op<"launch",
107+
[CallOpInterface, MemRefsNormalizable,
108+
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
109+
let summary = "kernel launch operation";
110+
let description = [{
111+
The `kernel.launch` operation represents a launch of a kernel that is
112+
within the same symbol scope as the launch. The operands and result types of
113+
the launch must match the specified kernel type. The kernel is encoded as a
114+
symbol reference attribute named "kernel".
115+
116+
Example:
117+
118+
```mlir
119+
%result = kernel.launch @custom_gemm(%A, %B, %C, %alpha) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>, f32) -> tensor<f32>
120+
```
121+
}];
122+
123+
let arguments = (ins FlatSymbolRefAttr:$kernel, Variadic<AnyType>:$operands);
124+
let results = (outs Variadic<AnyType>);
125+
126+
let builders = [
127+
OpBuilder<(ins "DefnOp":$kernel, CArg<"ValueRange", "{}">:$operands), [{
128+
$_state.addOperands(operands);
129+
$_state.addAttribute("kernel", SymbolRefAttr::get(kernel));
130+
$_state.addTypes(kernel.getFunctionType().getResults());
131+
}]>,
132+
OpBuilder<(ins "SymbolRefAttr":$kernel, "TypeRange":$results,
133+
CArg<"ValueRange", "{}">:$operands), [{
134+
$_state.addOperands(operands);
135+
$_state.addAttribute("kernel", kernel);
136+
$_state.addTypes(results);
137+
}]>,
138+
OpBuilder<(ins "StringAttr":$kernel, "TypeRange":$results,
139+
CArg<"ValueRange", "{}">:$operands), [{
140+
build($_builder, $_state, SymbolRefAttr::get(kernel), results, operands);
141+
}]>,
142+
OpBuilder<(ins "StringRef":$kernel, "TypeRange":$results,
143+
CArg<"ValueRange", "{}">:$operands), [{
144+
build($_builder, $_state, StringAttr::get($_builder.getContext(), kernel),
145+
results, operands);
146+
}]>];
147+
148+
let extraClassDeclaration = [{
149+
FunctionType getKernelType();
150+
151+
/// Get the argument operands to the launched kernel.
152+
operand_range getArgOperands() {
153+
return {arg_operand_begin(), arg_operand_end()};
154+
}
155+
156+
MutableOperandRange getArgOperandsMutable() {
157+
return getOperandsMutable();
158+
}
159+
160+
operand_iterator arg_operand_begin() { return operand_begin(); }
161+
operand_iterator arg_operand_end() { return operand_end(); }
162+
163+
/// Return the kernel of this operation.
164+
CallInterfaceCallable getCallableForCallee() {
165+
return (*this)->getAttrOfType<SymbolRefAttr>("kernel");
166+
}
167+
168+
/// Set the kernel for this operation.
169+
void setCalleeFromCallable(CallInterfaceCallable callee) {
170+
(*this)->setAttr("kernel", callee.get<SymbolRefAttr>());
171+
}
172+
}];
173+
174+
let assemblyFormat = [{
175+
$kernel `(` $operands `)` attr-dict `:` functional-type($operands, results)
176+
}];
177+
}
178+
102179
def Kernel_YieldOp : Kernel_Op<"yield", [Pure, HasParent<"DefnOp">,
103180
MemRefsNormalizable, ReturnLike, Terminator]> {
104181
let summary = "Terminator for kernel.defn operation";

lib/polygeist/Kernel/KernelOps.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,64 @@ LogicalResult YieldOp::verify() {
8484
return success();
8585
}
8686

87+
//===----------------------------------------------------------------------===//
88+
// LaunchOp
89+
//===----------------------------------------------------------------------===//
90+
91+
FunctionType LaunchOp::getKernelType() {
92+
// Get the kernel symbol reference
93+
auto kernelAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("kernel");
94+
if (!kernelAttr)
95+
return nullptr;
96+
97+
// Look up the kernel DefnOp in the symbol table
98+
auto *symbolTableOp = (*this)->getParentWithTrait<OpTrait::SymbolTable>();
99+
if (!symbolTableOp)
100+
return nullptr;
101+
102+
auto kernelOp = dyn_cast_or_null<DefnOp>(
103+
SymbolTable::lookupSymbolIn(symbolTableOp, kernelAttr));
104+
if (!kernelOp)
105+
return nullptr;
106+
107+
return kernelOp.getFunctionType();
108+
}
109+
110+
LogicalResult LaunchOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
111+
// Check that the kernel attribute was specified.
112+
auto kernelAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("kernel");
113+
if (!kernelAttr)
114+
return emitOpError("requires a 'kernel' symbol reference attribute");
115+
116+
// Check that the kernel symbol exists and is a DefnOp.
117+
auto kernelOp = symbolTable.lookupNearestSymbolFrom<DefnOp>(*this, kernelAttr);
118+
if (!kernelOp)
119+
return emitOpError() << "'" << kernelAttr.getValue()
120+
<< "' does not reference a valid kernel";
121+
122+
// Verify that the operand and result types match the kernel signature.
123+
auto kernelType = kernelOp.getFunctionType();
124+
if (kernelType.getNumInputs() != getNumOperands())
125+
return emitOpError("incorrect number of operands for kernel");
126+
127+
for (unsigned i = 0, e = kernelType.getNumInputs(); i != e; ++i)
128+
if (getOperand(i).getType() != kernelType.getInput(i))
129+
return emitOpError("operand type mismatch: expected operand type ")
130+
<< kernelType.getInput(i) << ", but provided "
131+
<< getOperand(i).getType() << " for operand number " << i;
132+
133+
if (kernelType.getNumResults() != getNumResults())
134+
return emitOpError("incorrect number of results for kernel");
135+
136+
for (unsigned i = 0, e = kernelType.getNumResults(); i != e; ++i)
137+
if (getResult(i).getType() != kernelType.getResult(i))
138+
return emitOpError("result type mismatch: expected result type ")
139+
<< kernelType.getResult(i) << ", but provided "
140+
<< getResult(i).getType() << " for result number " << i;
141+
142+
return success();
143+
}
144+
87145
//===----------------------------------------------------------------------===//
88146
// TableGen'd op definitions
89147
//===----------------------------------------------------------------------===//

lib/polygeist/Passes/LinalgToKernel.cpp

Lines changed: 101 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ bool areIteratorTypesEquivalent(ArrayAttr firstTypes, ArrayAttr secondTypes) {
8181
return false;
8282

8383
for (auto typePair : llvm::zip(firstTypes, secondTypes)) {
84-
auto firstType = std::get<0>(typePair).cast<StringAttr>().getValue();
85-
auto secondType = std::get<1>(typePair).cast<StringAttr>().getValue();
84+
auto firstType = std::get<0>(typePair).cast<linalg::IteratorTypeAttr>().getValue();
85+
auto secondType = std::get<1>(typePair).cast<linalg::IteratorTypeAttr>().getValue();
8686

8787
if (firstType != secondType)
8888
return false;
@@ -102,32 +102,43 @@ FailureOr<StringRef> matchGenericWithDefn(
102102
unsigned numInputs = genericOp.getNumDpsInputs();
103103
unsigned numOutputs = genericOp.getNumDpsInits();
104104

105+
// Variables to capture the match result
106+
StringRef matchedOpName;
107+
108+
SmallVector<kernel::DefnOp> defnOps;
109+
110+
collectionOp.walk([&](kernel::DefnOp defnOp) {
111+
defnOps.push_back(defnOp);
112+
});
113+
114+
bool foundMatch = false;
115+
105116
// Walk through each defn in the collection
106-
for (Operation &op : collectionOp.getDefns()) {
107-
auto defnOp = cast<kernel::DefnOp>(op);
108-
StringRef opName = defnOp.getSymName();
117+
for (auto defnOp : defnOps) {
109118

119+
StringRef opName = defnOp.getSymName();
110120
// Check for linalg.generic in the defn's body
111-
bool foundMatch = false;
112-
defnOp.getBody().walk([&](GenericOp candidateOp) {
113-
// Skip if already found a match
114-
if (foundMatch)
115-
return;
116-
117-
// Check if this linalg.generic matches our target
118-
if (candidateOp.getNumDpsInputs() == numInputs &&
119-
candidateOp.getNumDpsInits() == numOutputs &&
120-
areIndexingMapsEquivalent(candidateOp.getIndexingMapsAttr(), indexingMaps) &&
121-
areIteratorTypesEquivalent(candidateOp.getIteratorTypesAttr(), iteratorTypes) &&
122-
areRegionsEquivalent(candidateOp.getRegion(), genericOp.getRegion())) {
123-
foundMatch = true;
124-
}
121+
GenericOp candidateOp;
122+
123+
defnOp.walk([&](GenericOp genericOp) {
124+
candidateOp = genericOp; //TODO: Add checks to make sure there is only single linalg.generic in the defn
125125
});
126126

127-
if (foundMatch)
128-
return opName;
127+
// Check if this linalg.generic matches our target
128+
if (candidateOp.getNumDpsInputs() == numInputs &&
129+
candidateOp.getNumDpsInits() == numOutputs &&
130+
areIndexingMapsEquivalent(candidateOp.getIndexingMapsAttr(), indexingMaps) &&
131+
areIteratorTypesEquivalent(candidateOp.getIteratorTypesAttr(), iteratorTypes) &&
132+
areRegionsEquivalent(candidateOp.getRegion(), genericOp.getRegion())) {
133+
foundMatch = true;
134+
matchedOpName = opName;
135+
}
136+
137+
if (foundMatch) {
138+
return matchedOpName;
139+
}
129140
}
130-
141+
131142
return failure();
132143
}
133144

@@ -140,19 +151,82 @@ class LinalgGenericToKernelPattern : public OpRewritePattern<GenericOp> {
140151

141152
LogicalResult matchAndRewrite(GenericOp genericOp,
142153
PatternRewriter &rewriter) const override {
154+
155+
auto module = genericOp->getParentOfType<ModuleOp>();
156+
//Check if the parent of the generic op is a kernel.defn
157+
if (auto parentOp = genericOp->getParentOp()) {
158+
if (isa<kernel::DefnOp>(parentOp)) {
159+
return failure();
160+
}
161+
}
162+
143163
// Try to match with a defn in the collection
144164
auto matchResult = matchGenericWithDefn(genericOp, collectionOp);
145165
if (failed(matchResult))
146166
return failure();
147167

148168
StringRef opName = *matchResult;
149169

150-
// For now, just emit a diagnostic indicating we found a match
151-
// In the future, this would create the appropriate kernel operation
152-
genericOp.emitRemark() << "Matched linalg.generic with kernel pattern: " << opName;
170+
// Find the matched kernel.defn operation
171+
kernel::DefnOp matchedDefnOp;
172+
// Use const_cast to work around the const issue
173+
const_cast<kernel::DefnCollectionOp&>(collectionOp).walk([&](kernel::DefnOp defnOp) {
174+
if (defnOp.getSymName() == opName) {
175+
matchedDefnOp = defnOp;
176+
return WalkResult::interrupt();
177+
}
178+
return WalkResult::advance();
179+
});
180+
181+
if (!matchedDefnOp) {
182+
return failure();
183+
}
184+
185+
// Check if the kernel.defn already exists in the target module
186+
kernel::DefnOp existingDefn;
187+
module.walk([&](kernel::DefnOp defnOp) {
188+
if (defnOp.getSymName() == opName) {
189+
// Check if this defn is inside a defn_collection (template) or at module level (callable)
190+
if (!defnOp->getParentOfType<kernel::DefnCollectionOp>()) {
191+
existingDefn = defnOp;
192+
return WalkResult::interrupt();
193+
}
194+
}
195+
return WalkResult::advance();
196+
});
197+
198+
// If the kernel.defn doesn't exist in the module, copy it
199+
if (!existingDefn) {
200+
// Clone the matched kernel.defn operation
201+
rewriter.setInsertionPointToStart(module.getBody());
202+
auto clonedDefn = rewriter.clone(*matchedDefnOp.getOperation());
203+
(void)clonedDefn; // Suppress unused variable warning
204+
}
205+
206+
// Create kernel.launch operation to replace the genericOp
207+
Location loc = genericOp.getLoc();
208+
209+
// Set insertion point to the genericOp location
210+
rewriter.setInsertionPoint(genericOp);
211+
212+
// Get operands from the generic operation (inputs and outputs)
213+
SmallVector<Value> operands;
214+
operands.append(genericOp.getInputs().begin(), genericOp.getInputs().end());
215+
operands.append(genericOp.getOutputs().begin(), genericOp.getOutputs().end());
216+
217+
// Get result types from the generic operation
218+
TypeRange resultTypes = genericOp.getResultTypes();
219+
220+
// Create the kernel.launch operation
221+
auto launchOp = rewriter.create<kernel::LaunchOp>(
222+
loc,
223+
resultTypes,
224+
opName,
225+
operands
226+
);
153227

154-
// TODO: Create the appropriate kernel operation based on the matched pattern
155-
// This would require implementing kernel operations in the kernel dialect
228+
// Replace the generic operation with the launch operation
229+
rewriter.replaceOp(genericOp, launchOp.getResults());
156230

157231
return success();
158232
}

0 commit comments

Comments
 (0)