Skip to content

Commit 7c204f2

Browse files
committed
Working match for linalg kernel match for gemm
1 parent ca12291 commit 7c204f2

File tree

1 file changed

+54
-6
lines changed

1 file changed

+54
-6
lines changed

lib/polygeist/Passes/LinalgToKernel.cpp

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -605,19 +605,67 @@ class LinalgGenericToKernelPattern : public OpRewritePattern<GenericOp> {
605605

606606
llvm::errs() << "DEBUG: Successfully mapped all kernel arguments, creating kernel.launch\n";
607607

608+
// Get kernel function signature types for casting
609+
auto kernelFuncType = matchedDefnOp.getFunctionType();
610+
auto kernelInputTypes = kernelFuncType.getInputs();
611+
auto kernelResultTypes = kernelFuncType.getResults();
612+
613+
// Cast operands to match kernel signature types if needed
614+
SmallVector<Value> castedOperands;
615+
for (size_t i = 0; i < operands.size(); ++i) {
616+
Value operand = operands[i];
617+
Type expectedType = (i < kernelInputTypes.size()) ? kernelInputTypes[i] : operand.getType();
618+
619+
if (operand.getType() != expectedType) {
620+
// Insert tensor.cast for type conversion
621+
if (isa<RankedTensorType>(operand.getType()) && isa<RankedTensorType>(expectedType)) {
622+
llvm::errs() << "DEBUG: Casting operand " << i << " from " << operand.getType()
623+
<< " to " << expectedType << "\n";
624+
auto castOp = rewriter.create<tensor::CastOp>(loc, expectedType, operand);
625+
castedOperands.push_back(castOp.getResult());
626+
} else {
627+
// For non-tensor types, use the operand as-is
628+
castedOperands.push_back(operand);
629+
}
630+
} else {
631+
castedOperands.push_back(operand);
632+
}
633+
}
634+
608635
// Get result types from the generic operation
609-
TypeRange resultTypes = genericOp.getResultTypes();
636+
TypeRange originalResultTypes = genericOp.getResultTypes();
610637

611-
// Create the kernel.launch operation
638+
// Create the kernel.launch operation with casted operands and kernel result types
612639
auto launchOp = rewriter.create<kernel::LaunchOp>(
613640
loc,
614-
resultTypes,
641+
kernelResultTypes, // Use kernel result types for the launch op
615642
opName,
616-
operands
643+
castedOperands // Use casted operands
617644
);
618645

619-
// Replace the generic operation with the launch operation
620-
rewriter.replaceOp(genericOp, launchOp.getResults());
646+
// Cast results back to original types if needed
647+
SmallVector<Value> finalResults;
648+
for (size_t i = 0; i < launchOp.getResults().size(); ++i) {
649+
Value result = launchOp.getResult(i);
650+
Type originalType = (i < originalResultTypes.size()) ? originalResultTypes[i] : result.getType();
651+
652+
if (result.getType() != originalType) {
653+
// Insert tensor.cast to convert back to original type
654+
if (isa<RankedTensorType>(result.getType()) && isa<RankedTensorType>(originalType)) {
655+
llvm::errs() << "DEBUG: Casting result " << i << " from " << result.getType()
656+
<< " to " << originalType << "\n";
657+
auto castOp = rewriter.create<tensor::CastOp>(loc, originalType, result);
658+
finalResults.push_back(castOp.getResult());
659+
} else {
660+
finalResults.push_back(result);
661+
}
662+
} else {
663+
finalResults.push_back(result);
664+
}
665+
}
666+
667+
// Replace the generic operation with the final results
668+
rewriter.replaceOp(genericOp, finalResults);
621669

622670
return success();
623671
}

0 commit comments

Comments
 (0)