@@ -605,19 +605,67 @@ class LinalgGenericToKernelPattern : public OpRewritePattern<GenericOp> {
605605
606606 llvm::errs () << " DEBUG: Successfully mapped all kernel arguments, creating kernel.launch\n " ;
607607
608+ // Get kernel function signature types for casting
609+ auto kernelFuncType = matchedDefnOp.getFunctionType ();
610+ auto kernelInputTypes = kernelFuncType.getInputs ();
611+ auto kernelResultTypes = kernelFuncType.getResults ();
612+
613+ // Cast operands to match kernel signature types if needed
614+ SmallVector<Value> castedOperands;
615+ for (size_t i = 0 ; i < operands.size (); ++i) {
616+ Value operand = operands[i];
617+ Type expectedType = (i < kernelInputTypes.size ()) ? kernelInputTypes[i] : operand.getType ();
618+
619+ if (operand.getType () != expectedType) {
620+ // Insert tensor.cast for type conversion
621+ if (isa<RankedTensorType>(operand.getType ()) && isa<RankedTensorType>(expectedType)) {
622+ llvm::errs () << " DEBUG: Casting operand " << i << " from " << operand.getType ()
623+ << " to " << expectedType << " \n " ;
624+ auto castOp = rewriter.create <tensor::CastOp>(loc, expectedType, operand);
625+ castedOperands.push_back (castOp.getResult ());
626+ } else {
627+ // For non-tensor types, use the operand as-is
628+ castedOperands.push_back (operand);
629+ }
630+ } else {
631+ castedOperands.push_back (operand);
632+ }
633+ }
634+
608635 // Get result types from the generic operation
609- TypeRange resultTypes = genericOp.getResultTypes ();
636+ TypeRange originalResultTypes = genericOp.getResultTypes ();
610637
611- // Create the kernel.launch operation
638+ // Create the kernel.launch operation with casted operands and kernel result types
612639 auto launchOp = rewriter.create <kernel::LaunchOp>(
613640 loc,
614- resultTypes,
641+ kernelResultTypes, // Use kernel result types for the launch op
615642 opName,
616- operands
643+ castedOperands // Use casted operands
617644 );
618645
619- // Replace the generic operation with the launch operation
620- rewriter.replaceOp (genericOp, launchOp.getResults ());
646+ // Cast results back to original types if needed
647+ SmallVector<Value> finalResults;
648+ for (size_t i = 0 ; i < launchOp.getResults ().size (); ++i) {
649+ Value result = launchOp.getResult (i);
650+ Type originalType = (i < originalResultTypes.size ()) ? originalResultTypes[i] : result.getType ();
651+
652+ if (result.getType () != originalType) {
653+ // Insert tensor.cast to convert back to original type
654+ if (isa<RankedTensorType>(result.getType ()) && isa<RankedTensorType>(originalType)) {
655+ llvm::errs () << " DEBUG: Casting result " << i << " from " << result.getType ()
656+ << " to " << originalType << " \n " ;
657+ auto castOp = rewriter.create <tensor::CastOp>(loc, originalType, result);
658+ finalResults.push_back (castOp.getResult ());
659+ } else {
660+ finalResults.push_back (result);
661+ }
662+ } else {
663+ finalResults.push_back (result);
664+ }
665+ }
666+
667+ // Replace the generic operation with the final results
668+ rewriter.replaceOp (genericOp, finalResults);
621669
622670 return success ();
623671 }
0 commit comments