Skip to content

Commit ede4493

Browse files
authored
feat: more dispatches for any/all (EnzymeAD#834)
1 parent ff32540 commit ede4493

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/TracedRArray.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -787,8 +787,10 @@ for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRN
787787
end
788788
end
789789

790-
Base.all(f::Function, x::AnyTracedRArray) = mapreduce(f, &, x)
791-
Base.any(f::Function, x::AnyTracedRArray) = mapreduce(f, |, x)
790+
Base._all(f, x::AnyTracedRArray, dims) = mapreduce(f, &, x; dims)
791+
Base._all(f, x::AnyTracedRArray, dims::Colon) = mapreduce(f, &, x; dims)
792+
Base._any(f, x::AnyTracedRArray, dims) = mapreduce(f, |, x; dims)
793+
Base._any(f, x::AnyTracedRArray, dims::Colon) = mapreduce(f, |, x; dims)
792794

793795
# outer repeat
794796
function Base._RepeatInnerOuter.repeat_outer(

0 commit comments

Comments
 (0)