@@ -81,8 +81,8 @@ bool areIteratorTypesEquivalent(ArrayAttr firstTypes, ArrayAttr secondTypes) {
8181 return false ;
8282
8383 for (auto typePair : llvm::zip (firstTypes, secondTypes)) {
84- auto firstType = std::get<0 >(typePair).cast <StringAttr >().getValue ();
85- auto secondType = std::get<1 >(typePair).cast <StringAttr >().getValue ();
84+ auto firstType = std::get<0 >(typePair).cast <linalg::IteratorTypeAttr >().getValue ();
85+ auto secondType = std::get<1 >(typePair).cast <linalg::IteratorTypeAttr >().getValue ();
8686
8787 if (firstType != secondType)
8888 return false ;
@@ -102,32 +102,43 @@ FailureOr<StringRef> matchGenericWithDefn(
102102 unsigned numInputs = genericOp.getNumDpsInputs ();
103103 unsigned numOutputs = genericOp.getNumDpsInits ();
104104
105+ // Variables to capture the match result
106+ StringRef matchedOpName;
107+
108+ SmallVector<kernel::DefnOp> defnOps;
109+
110+ collectionOp.walk ([&](kernel::DefnOp defnOp) {
111+ defnOps.push_back (defnOp);
112+ });
113+
114+ bool foundMatch = false ;
115+
105116 // Walk through each defn in the collection
106- for (Operation &op : collectionOp.getDefns ()) {
107- auto defnOp = cast<kernel::DefnOp>(op);
108- StringRef opName = defnOp.getSymName ();
117+ for (auto defnOp : defnOps) {
109118
119+ StringRef opName = defnOp.getSymName ();
110120 // Check for linalg.generic in the defn's body
111- bool foundMatch = false ;
112- defnOp.getBody ().walk ([&](GenericOp candidateOp) {
113- // Skip if already found a match
114- if (foundMatch)
115- return ;
116-
117- // Check if this linalg.generic matches our target
118- if (candidateOp.getNumDpsInputs () == numInputs &&
119- candidateOp.getNumDpsInits () == numOutputs &&
120- areIndexingMapsEquivalent (candidateOp.getIndexingMapsAttr (), indexingMaps) &&
121- areIteratorTypesEquivalent (candidateOp.getIteratorTypesAttr (), iteratorTypes) &&
122- areRegionsEquivalent (candidateOp.getRegion (), genericOp.getRegion ())) {
123- foundMatch = true ;
124- }
121+ GenericOp candidateOp;
122+
123+ defnOp.walk ([&](GenericOp genericOp) {
124+ candidateOp = genericOp; // TODO: Add checks to make sure there is only single linalg.generic in the defn
125125 });
126126
127- if (foundMatch)
128- return opName;
127+ // Check if this linalg.generic matches our target
128+ if (candidateOp.getNumDpsInputs () == numInputs &&
129+ candidateOp.getNumDpsInits () == numOutputs &&
130+ areIndexingMapsEquivalent (candidateOp.getIndexingMapsAttr (), indexingMaps) &&
131+ areIteratorTypesEquivalent (candidateOp.getIteratorTypesAttr (), iteratorTypes) &&
132+ areRegionsEquivalent (candidateOp.getRegion (), genericOp.getRegion ())) {
133+ foundMatch = true ;
134+ matchedOpName = opName;
135+ }
136+
137+ if (foundMatch) {
138+ return matchedOpName;
139+ }
129140 }
130-
141+
131142 return failure ();
132143}
133144
@@ -140,19 +151,82 @@ class LinalgGenericToKernelPattern : public OpRewritePattern<GenericOp> {
140151
141152 LogicalResult matchAndRewrite (GenericOp genericOp,
142153 PatternRewriter &rewriter) const override {
154+
155+ auto module = genericOp->getParentOfType <ModuleOp>();
156+ // Check if the parent of the generic op is a kernel.defn
157+ if (auto parentOp = genericOp->getParentOp ()) {
158+ if (isa<kernel::DefnOp>(parentOp)) {
159+ return failure ();
160+ }
161+ }
162+
143163 // Try to match with a defn in the collection
144164 auto matchResult = matchGenericWithDefn (genericOp, collectionOp);
145165 if (failed (matchResult))
146166 return failure ();
147167
148168 StringRef opName = *matchResult;
149169
150- // For now, just emit a diagnostic indicating we found a match
151- // In the future, this would create the appropriate kernel operation
152- genericOp.emitRemark () << " Matched linalg.generic with kernel pattern: " << opName;
170+ // Find the matched kernel.defn operation
171+ kernel::DefnOp matchedDefnOp;
172+ // Use const_cast to work around the const issue
173+ const_cast <kernel::DefnCollectionOp&>(collectionOp).walk ([&](kernel::DefnOp defnOp) {
174+ if (defnOp.getSymName () == opName) {
175+ matchedDefnOp = defnOp;
176+ return WalkResult::interrupt ();
177+ }
178+ return WalkResult::advance ();
179+ });
180+
181+ if (!matchedDefnOp) {
182+ return failure ();
183+ }
184+
185+ // Check if the kernel.defn already exists in the target module
186+ kernel::DefnOp existingDefn;
187+ module .walk ([&](kernel::DefnOp defnOp) {
188+ if (defnOp.getSymName () == opName) {
189+ // Check if this defn is inside a defn_collection (template) or at module level (callable)
190+ if (!defnOp->getParentOfType <kernel::DefnCollectionOp>()) {
191+ existingDefn = defnOp;
192+ return WalkResult::interrupt ();
193+ }
194+ }
195+ return WalkResult::advance ();
196+ });
197+
198+ // If the kernel.defn doesn't exist in the module, copy it
199+ if (!existingDefn) {
200+ // Clone the matched kernel.defn operation
201+ rewriter.setInsertionPointToStart (module .getBody ());
202+ auto clonedDefn = rewriter.clone (*matchedDefnOp.getOperation ());
203+ (void )clonedDefn; // Suppress unused variable warning
204+ }
205+
206+ // Create kernel.launch operation to replace the genericOp
207+ Location loc = genericOp.getLoc ();
208+
209+ // Set insertion point to the genericOp location
210+ rewriter.setInsertionPoint (genericOp);
211+
212+ // Get operands from the generic operation (inputs and outputs)
213+ SmallVector<Value> operands;
214+ operands.append (genericOp.getInputs ().begin (), genericOp.getInputs ().end ());
215+ operands.append (genericOp.getOutputs ().begin (), genericOp.getOutputs ().end ());
216+
217+ // Get result types from the generic operation
218+ TypeRange resultTypes = genericOp.getResultTypes ();
219+
220+ // Create the kernel.launch operation
221+ auto launchOp = rewriter.create <kernel::LaunchOp>(
222+ loc,
223+ resultTypes,
224+ opName,
225+ operands
226+ );
153227
154- // TODO: Create the appropriate kernel operation based on the matched pattern
155- // This would require implementing kernel operations in the kernel dialect
228+ // Replace the generic operation with the launch operation
229+ rewriter. replaceOp (genericOp, launchOp. getResults ());
156230
157231 return success ();
158232 }
0 commit comments