999999 end
10001000end
10011001
1002+ @inline function override_bc_foldl (op, init, itr)
1003+ # Unroll the while loop once; if init is known, the call to op may
1004+ # be evaluated at compile time
1005+ y = iterate (itr)
1006+ y === nothing && return init
1007+ v = op (init, y[1 ])
1008+
1009+ if same_sized (itr. args)
1010+ @inbounds @simd for I in 2 : length (itr)
1011+ val = overload_broadcast_getindex (itr, I)
1012+ v = op (v, val)
1013+ end
1014+ else
1015+ while true
1016+ y = iterate (itr, y[2 ])
1017+ y === nothing && break
1018+ v = op (v, y[1 ])
1019+ end
1020+ end
1021+ return v
1022+ end
1023+
10021024struct MultiOp{Position, NumUsed, F1, F2}
10031025 f1:: F1
10041026 f2:: F2
@@ -1025,17 +1047,19 @@ end
10251047 end
10261048end
10271049
1028- @inline function bc_or_array_or_number_ty (@nospecialize (Ty:: Type )):: Bool
1029- if Ty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing}
1030- return all (bc_or_array_or_number_ty, Ty. parameters[4 ]. parameters)
1050+ @inline function bc_or_array_or_number_ty (@nospecialize (Ty:: Type ), midnothing:: Bool = true ):: Bool
1051+ if ( midnothing && Ty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing} ) ||
1052+ (! midnothing && Ty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle} )
1053+ return all (Base. Fix2 (bc_or_array_or_number_ty, midnothing), Ty. parameters[4 ]. parameters)
10311054 else
10321055 return Ty <: AbstractArray || Ty <: Number || Ty <: Base.RefValue
10331056 end
10341057end
10351058
1036- @inline function has_array (@nospecialize (Ty:: Type )):: Bool
1037- if Ty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing}
1038- return any (has_array, Ty. parameters[4 ]. parameters)
1059+ @inline function has_array (@nospecialize (Ty:: Type ), midnothing:: Bool = true ):: Bool
1060+ if ( midnothing && Ty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing} ) ||
1061+ (! midnothing && Ty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle} )
1062+ return any (Base. Fix2 (has_array, midnothing), Ty. parameters[4 ]. parameters)
10391063 else
10401064 return Ty <: AbstractArray
10411065 end
@@ -1157,6 +1181,32 @@ function abstract_call_known(
11571181 )
11581182 end
11591183 end
1184+
1185+ if f === Base. _foldl_impl && length (argtypes) == 4
1186+
1187+ bcty = widenconst (argtypes[4 ])
1188+
1189+
1190+ if widenconst (argtypes[3 ]) <: Base._InitialValue &&
1191+ bcty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle} && ndims (bcty) >= 2 &&
1192+ bc_or_array_or_number_ty (bcty, false ) && has_array (bcty, false )
1193+
1194+ arginfo2 = ArgInfo (
1195+ fargs isa Nothing ? nothing :
1196+ [:(Enzyme. Compiler. Interpreter. override_bc_foldl), fargs[2 : end ]. .. ],
1197+ [Core. Const (Enzyme. Compiler. Interpreter. override_bc_foldl), argtypes[2 : end ]. .. ],
1198+ )
1199+
1200+ return Base. @invoke abstract_call_known (
1201+ interp:: AbstractInterpreter ,
1202+ Enzyme. Compiler. Interpreter. override_bc_foldl:: Any ,
1203+ arginfo2:: ArgInfo ,
1204+ si:: StmtInfo ,
1205+ sv:: AbsIntState ,
1206+ max_methods:: Int ,
1207+ )
1208+ end
1209+ end
11601210 end
11611211
11621212 @static if VERSION < v " 1.11.0-"
0 commit comments