971971 end
972972 LLVM. API. LLVMAddCallSiteAttribute (res, LLVM. API. LLVMAttributeIndex (1 ), attr)
973973 res = load! (B, sty, sret)
974+ if returnRoots != = nothing && VERSION >= v " 1.12"
975+ res = recombine_value! (B, res, returnRoots; must_cache= true )
976+ end
974977 end
975978 if swiftself
976979 attr = EnumAttribute (" swiftself" )
@@ -1000,9 +1003,19 @@ end
10001003 if RT <: Const
10011004 if needsPrimal
10021005 @assert RealRt == fwd_RT
1003- if get_return_info (RealRt)[2 ] != = nothing
1006+ _, prim_sret, prim_roots = get_return_info (RealRt)
1007+ if prim_sret != = nothing
10041008 val = new_from_original (gutils, operands (orig)[1 ])
1005- store! (B, res, val)
1009+
1010+ if prim_roots != = nothing && VERSION >= v " 1.12"
1011+ extract_nonjlvalues_into! (B, value_type (res), val, res)
1012+
1013+ rval = new_from_original (gutils, operands (orig)[2 ])
1014+
1015+ extract_roots_from_value! (B, res, rval)
1016+ else
1017+ store! (B, res, val)
1018+ end
10061019 else
10071020 normalV = res. ref
10081021 end
@@ -1016,12 +1029,28 @@ end
10161029 ST = NTuple{Int (width),ST}
10171030 end
10181031 @assert ST == fwd_RT
1019- if get_return_info (RealRt)[2 ] != = nothing
1032+ _, prim_sret, prim_roots = get_return_info (RealRt)
1033+ if prim_sret != = nothing
10201034 dval_ptr = invert_pointer (gutils, operands (orig)[1 ], B)
1021- for idx = 1 : width
1035+
1036+ droots = if prim_roots != = nothing && VERSION >= v " 1.12"
1037+ @assert ! is_constant_value (gutils, operands (orig)[2 ])
1038+ invert_pointer (gutils, operands (orig)[2 ], B)
1039+ end
1040+
1041+ for idx = 1 : width
10221042 ev = (width == 1 ) ? dval : extract_value! (B, dval, idx - 1 )
10231043 pev = (width == 1 ) ? dval_ptr : extract_value! (B, dval_ptr, idx - 1 )
1024- store! (B, res, pev)
1044+
1045+ if prim_roots != = nothing && VERSION >= v " 1.12"
1046+ extract_nonjlvalues_into! (B, value_type (ev), pev, ev)
1047+
1048+ rval = (width == 1 ) ? droots : extract_value! (B, droots, idx - 1 )
1049+
1050+ extract_roots_from_value! (B, ev, rval)
1051+ else
1052+ store! (B, ev, pev)
1053+ end
10251054 end
10261055 else
10271056 shadowV = res. ref
@@ -1033,16 +1062,42 @@ end
10331062 BatchDuplicated{RealRt,Int (width)}
10341063 end
10351064 @assert ST == fwd_RT
1036- if get_return_info (RealRt)[2 ] != = nothing
1065+
1066+ _, prim_sret, prim_roots = get_return_info (RealRt)
1067+ if prim_sret != = nothing
10371068 val = new_from_original (gutils, operands (orig)[1 ])
1038- store! (B, extract_value! (B, res, 0 ), val)
1069+
1070+ res0 = extract_value! (B, res, 0 )
1071+ if prim_roots != = nothing && VERSION >= v " 1.12"
1072+ extract_nonjlvalues_into! (B, value_type (res0), val, res0)
1073+
1074+ rval = new_from_original (gutils, operands (orig)[2 ])
1075+
1076+ extract_roots_from_value! (B, res0, rval)
1077+ else
1078+ store! (B, res0, val)
1079+ end
10391080
10401081 dval_ptr = invert_pointer (gutils, operands (orig)[1 ], B)
10411082 dval = extract_value! (B, res, 1 )
1042- for idx = 1 : width
1083+
1084+ droots = if prim_roots != = nothing && VERSION >= v " 1.12"
1085+ @assert ! is_constant_value (gutils, operands (orig)[2 ])
1086+ invert_pointer (gutils, operands (orig)[2 ], B)
1087+ end
1088+
1089+ for idx = 1 : width
10431090 ev = (width == 1 ) ? dval : extract_value! (B, dval, idx - 1 )
10441091 pev = (width == 1 ) ? dval_ptr : extract_value! (B, dval_ptr, idx - 1 )
1045- store! (B, ev, pev)
1092+ if prim_roots != = nothing && VERSION >= v " 1.12"
1093+ extract_nonjlvalues_into! (B, value_type (ev), pev, ev)
1094+
1095+ rval = (width == 1 ) ? droots : extract_value! (B, droots, idx - 1 )
1096+
1097+ extract_roots_from_value! (B, ev, rval)
1098+ else
1099+ store! (B, ev, pev)
1100+ end
10461101 end
10471102 else
10481103 normalV = extract_value! (B, res, 0 ). ref
@@ -1781,6 +1836,9 @@ function enzyme_custom_common_rev(
17811836 )
17821837 res = load! (B, sty, sret)
17831838 API. SetMustCache! (res)
1839+ if returnRoots != = nothing && VERSION >= v " 1.12"
1840+ res = recombine_value! (B, res, returnRoots; must_cache= true )
1841+ end
17841842 end
17851843 if swiftself
17861844 attr = EnumAttribute (" swiftself" )
@@ -1888,9 +1946,19 @@ function enzyme_custom_common_rev(
18881946 if needsPrimal
18891947 @assert ! isghostty (RealRt)
18901948 normalV = extract_value! (B, resV, idx)
1891- if get_return_info (RealRt)[2 ] != = nothing
1949+ _, prim_sret, prim_roots = get_return_info (RealRt)
1950+ if prim_sret != = nothing
18921951 val = new_from_original (gutils, operands (orig)[1 ])
1893- store! (B, normalV, val)
1952+
1953+ if prim_roots != = nothing && VERSION >= v " 1.12"
1954+ extract_nonjlvalues_into! (B, value_type (normalV), val, normalV)
1955+
1956+ rval = new_from_original (gutils, operands (orig)[2 ])
1957+
1958+ extract_roots_from_value! (B, normalV, rval)
1959+ else
1960+ store! (B, normalV, val)
1961+ end
18941962 else
18951963 @assert value_type (normalV) == value_type (orig)
18961964 normalV = normalV. ref
@@ -1901,16 +1969,30 @@ function enzyme_custom_common_rev(
19011969 if needsShadowJL
19021970 @assert ! isghostty (RealRt)
19031971 shadowV = extract_value! (B, resV, idx)
1904- if get_return_info (RealRt)[2 ] != = nothing
1972+ _, prim_sret, prim_roots = get_return_info (RealRt)
1973+ if prim_sret != = nothing
19051974 dval = invert_pointer (gutils, operands (orig)[1 ], B)
19061975
1907- for idx = 1 : width
1976+ droots = if prim_roots != = nothing && VERSION >= v " 1.12"
1977+ @assert ! is_constant_value (gutils, operands (orig)[2 ])
1978+ invert_pointer (gutils, operands (orig)[2 ], B)
1979+ end
1980+
1981+ for idx = 1 : width
19081982 to_store =
19091983 (width == 1 ) ? shadowV : extract_value! (B, shadowV, idx - 1 )
19101984
19111985 store_ptr = (width == 1 ) ? dval : extract_value! (B, dval, idx - 1 )
19121986
1913- store! (B, to_store, store_ptr)
1987+ if prim_roots != = nothing && VERSION >= v " 1.12"
1988+ extract_nonjlvalues_into! (B, value_type (to_store), store_ptr, to_store)
1989+
1990+ rval = (width == 1 ) ? droots : extract_value! (B, droots, idx - 1 )
1991+
1992+ extract_roots_from_value! (B, to_store, rval)
1993+ else
1994+ store! (B, to_store, store_ptr)
1995+ end
19141996 end
19151997 shadowV = C_NULL
19161998 else
0 commit comments