Skip to content

Commit dcbb34a

Browse files
authored
Foldl broadcast speedup (#2630)
1 parent df2a49a commit dcbb34a

File tree

1 file changed

+56
-6
lines changed

1 file changed

+56
-6
lines changed

src/compiler/interpreter.jl

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -999,6 +999,28 @@ end
999999
end
10001000
end
10011001

1002+
@inline function override_bc_foldl(op, init, itr)
1003+
# Unroll the while loop once; if init is known, the call to op may
1004+
# be evaluated at compile time
1005+
y = iterate(itr)
1006+
y === nothing && return init
1007+
v = op(init, y[1])
1008+
1009+
if same_sized(itr.args)
1010+
@inbounds @simd for I in 2:length(itr)
1011+
val = overload_broadcast_getindex(itr, I)
1012+
v = op(v, val)
1013+
end
1014+
else
1015+
while true
1016+
y = iterate(itr, y[2])
1017+
y === nothing && break
1018+
v = op(v, y[1])
1019+
end
1020+
end
1021+
return v
1022+
end
1023+
10021024
struct MultiOp{Position, NumUsed, F1, F2}
10031025
f1::F1
10041026
f2::F2
@@ -1025,17 +1047,19 @@ end
10251047
end
10261048
end
10271049

1028-
@inline function bc_or_array_or_number_ty(@nospecialize(Ty::Type))::Bool
1029-
if Ty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing}
1030-
return all(bc_or_array_or_number_ty, Ty.parameters[4].parameters)
1050+
@inline function bc_or_array_or_number_ty(@nospecialize(Ty::Type), midnothing::Bool=true)::Bool
1051+
if ( midnothing && Ty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing}) ||
1052+
(!midnothing && Ty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle})
1053+
return all(Base.Fix2(bc_or_array_or_number_ty, midnothing), Ty.parameters[4].parameters)
10311054
else
10321055
return Ty <: AbstractArray || Ty <: Number || Ty <: Base.RefValue
10331056
end
10341057
end
10351058

1036-
@inline function has_array(@nospecialize(Ty::Type))::Bool
1037-
if Ty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing}
1038-
return any(has_array, Ty.parameters[4].parameters)
1059+
@inline function has_array(@nospecialize(Ty::Type), midnothing::Bool=true)::Bool
1060+
if ( midnothing && Ty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing}) ||
1061+
(!midnothing && Ty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle})
1062+
return any(Base.Fix2(has_array, midnothing), Ty.parameters[4].parameters)
10391063
else
10401064
return Ty <: AbstractArray
10411065
end
@@ -1157,6 +1181,32 @@ function abstract_call_known(
11571181
)
11581182
end
11591183
end
1184+
1185+
if f === Base._foldl_impl && length(argtypes) == 4
1186+
1187+
bcty = widenconst(argtypes[4])
1188+
1189+
1190+
if widenconst(argtypes[3]) <: Base._InitialValue &&
1191+
bcty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle} && ndims(bcty) >= 2 &&
1192+
bc_or_array_or_number_ty(bcty, false) && has_array(bcty, false)
1193+
1194+
arginfo2 = ArgInfo(
1195+
fargs isa Nothing ? nothing :
1196+
[:(Enzyme.Compiler.Interpreter.override_bc_foldl), fargs[2:end]...],
1197+
[Core.Const(Enzyme.Compiler.Interpreter.override_bc_foldl), argtypes[2:end]...],
1198+
)
1199+
1200+
return Base.@invoke abstract_call_known(
1201+
interp::AbstractInterpreter,
1202+
Enzyme.Compiler.Interpreter.override_bc_foldl::Any,
1203+
arginfo2::ArgInfo,
1204+
si::StmtInfo,
1205+
sv::AbsIntState,
1206+
max_methods::Int,
1207+
)
1208+
end
1209+
end
11601210
end
11611211

11621212
@static if VERSION < v"1.11.0-"

0 commit comments

Comments
 (0)