@@ -687,7 +687,8 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
687687 end
688688 end
689689 end
690- push! (function_attributes (llvmf), EnumAttribute (" alwaysinline" , 0 ))
690+
691+ # push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))
691692
692693 needsTape = ! isghostty (TapeT) && ! Core. Compiler. isconstType (TapeT)
693694
@@ -711,22 +712,37 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
711712
712713 swiftself = any (any (map (k-> kind (k)== kind (EnumAttribute (" swiftself" )), collect (parameter_attributes (llvmf, i)))) for i in 1 : length (collect (parameters (llvmf))))
713714
714- _, sret, returnRoots = get_return_info (enzyme_custom_extract_mi (llvmf)[2 ])
715+ miRT = enzyme_custom_extract_mi (llvmf)[2 ]
716+ _, sret, returnRoots = get_return_info (miRT)
715717
716718 if ! forward
719+ funcTy = rev_TT. parameters[isKWCall ? 4 : 2 ]
717720 if needsTape
718721 @assert tape != C_NULL
719- tape_idx = 1 + (kwtup!= = nothing && ! isghostty (kwtup))+ (isKWCall && ! isghostty (rev_TT. parameters[4 ]))
720- innerTy = value_type (parameters (llvmf)[tape_idx+ (sret != = nothing )+ (RT <: Active )])
722+ tape_idx = 1 + (kwtup!= = nothing && ! isghostty (kwtup))+ (isKWCall && ! isghostty (rev_TT. parameters[4 ])) + ! isghostty (funcTy)
723+ trueidx = tape_idx+ (sret != = nothing )+ (returnRoots != = nothing )+ swiftself+ (RT <: Active )
724+ innerTy = value_type (parameters (llvmf)[trueidx])
721725 if innerTy != value_type (tape)
722- if isabstracttype (TapeT)
726+ if isabstracttype (TapeT) || TapeT == Tuple || TapeT . layout == C_NULL
723727 msg = sprint () do io
724728 println (io, " Enzyme : mismatch between innerTy $innerTy and tape type $(value_type (tape)) " )
725729 println (io, " tape_idx=" , tape_idx)
730+ println (io, " true_idx=" , trueidx)
731+ println (io, " isKWCall=" , isKWCall)
732+ println (io, " kwtup=" , kwtup)
733+ println (io, " funcTy=" , funcTy)
734+ println (io, " isghostty(funcTy)=" , isghostty (funcTy))
735+ println (io, " miRT=" , miRT)
726736 println (io, " sret=" , sret)
737+ println (io, " returnRoots=" , returnRoots)
738+ println (io, " swiftself=" , swiftself)
727739 println (io, " RT=" , RT)
728740 println (io, " tape=" , tape)
729- println (io, " llvmf=" , string (llvmf))
741+ println (io, " llvmf=" , string (LLVM. function_type (llvmf)))
742+ println (io, " TapeT=" , TapeT)
743+ println (io, " mi=" , mi)
744+ println (io, " ami=" , ami)
745+ println (io, " rev_TT =" , rev_TT)
730746 end
731747 throw (AssertionError (msg))
732748 end
@@ -749,7 +765,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
749765 val = LLVM. Value (API. EnzymeGradientUtilsDiffe (gutils, orig, B))
750766 else
751767 llety = convert (LLVMType, eltype (RT))
752- ptr_val = invert_pointer (gutils, operands (orig)[1 ], B)
768+ ptr_val = invert_pointer (gutils, operands (orig)[1 + ! isghostty (funcTy) ], B)
753769 val = UndefValue (LLVM. LLVMType (API. EnzymeGetShadowType (width, llety)))
754770 for idx in 1 : width
755771 ev = (width == 1 ) ? ptr_val : extract_value! (B, ptr_val, idx- 1 )
@@ -769,8 +785,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,
769785 if any_jltypes (llty)
770786 emit_writebarrier! (B, get_julia_inner_types (B, al0, val))
771787 end
772-
773- insert! (args, 1 + (kwtup!= = nothing && ! isghostty (kwtup))+ (isKWCall && ! isghostty (rev_TT. parameters[4 ])), al)
788+ insert! (args, 1 + (! isghostty (funcTy))+ (kwtup!= = nothing && ! isghostty (kwtup))+ (isKWCall && ! isghostty (rev_TT. parameters[4 ])), al)
774789 end
775790 end
776791
0 commit comments