@@ -1727,6 +1727,7 @@ end
17271727 else
17281728 " Unknown object of type" * " " * string (TT)
17291729 end
1730+ @assert ! illegal
17301731 illegalVal = cur
17311732 illegal = true
17321733 return make_batched (ncur, prevbb)
@@ -1770,6 +1771,7 @@ end
17701771 end
17711772
17721773 cur2 = if changed
1774+ @assert ! illegal
17731775 illegalVal = cur
17741776 illegal = true
17751777 # TODO replace with correct insertions/splats
@@ -1942,8 +1944,10 @@ end
19421944 return make_batched (ncur, prevbb)
19431945 end
19441946
1945- illegal = true
1946- illegalVal = cur
1947+ if ! illegal
1948+ illegal = true
1949+ illegalVal = cur
1950+ end
19471951 return ncur
19481952 end
19491953
@@ -7070,10 +7074,48 @@ end
70707074 ctx = LLVM. context (mod)
70717075 for f in functions (mod), bb in blocks (f), inst in instructions (bb)
70727076 fn = isa (inst, LLVM. CallInst) ? LLVM. called_operand (inst) : nothing
7077+
7078+ if ! API. HasFromStack (inst) && isa (inst, LLVM. AllocaInst)
7079+
7080+ calluse = nothing
7081+ for u in LLVM. uses (inst)
7082+ u = LLVM. user (u)
7083+ if isa (u, LLVM. CallInst) && operands (u)[1 ] == inst
7084+
7085+ sretkind = kind (if LLVM. version (). major >= 12
7086+ TypeAttribute (" sret" , LLVM. Int32Type ())
7087+ else
7088+ EnumAttribute (" sret" )
7089+ end )
7090+ hassret = false
7091+ llvmfn = LLVM. called_operand (u)
7092+ if llvmfn isa LLVM. Function
7093+ for attr in collect (parameter_attributes (llvmfn, 1 ))
7094+ if kind (attr) == sretkind
7095+ hassret = true
7096+ break
7097+ end
7098+ end
7099+ end
7100+ if hassret
7101+ calluse = u
7102+ end
7103+ end
7104+ end
7105+ if calluse isa LLVM. CallInst
7106+ _, RT = enzyme_custom_extract_mi (calluse, false )
7107+ if RT != = nothing
7108+ llrt, sret, returnRoots = get_return_info (RT)
7109+ if ! (sret isa Nothing) && ! is_sret_union (RT)
7110+ metadata (inst)[" enzymejl_allocart" ] = MDNode (LLVM. Metadata[MDString (string (convert (UInt, unsafe_to_pointer (RT))))])
7111+ end
7112+ end
7113+ end
7114+ end
70737115
70747116 if ! API. HasFromStack (inst) &&
70757117 ((isa (inst, LLVM. CallInst) &&
7076- (! isa (fn, LLVM. Function) || isempty (blocks (fn))) ) || isa (inst, LLVM. LoadInst))
7118+ (! isa (fn, LLVM. Function) || isempty (blocks (fn))) ) || isa (inst, LLVM. LoadInst) || isa (inst, LLVM . AllocaInst) )
70777119 legal, source_typ, byref = abs_typeof (inst)
70787120 codegen_typ = value_type (inst)
70797121 if legal
0 commit comments