Skip to content

Commit f7ecafa

Browse files
committed
Enhance Privacy.all_methods
1 parent 1aff976 commit f7ecafa

File tree

2 files changed

+93
-75
lines changed

2 files changed

+93
-75
lines changed

src/piracy.jl

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,46 @@
11
module Piracy
22

3-
using Test: @test, @test_broken
4-
using ..Aqua: walkmodules
5-
6-
const DEFAULT_PKGS = (Base.PkgId(Base), Base.PkgId(Core))
7-
8-
function all_methods!(
9-
mod::Module,
10-
done_callables::Base.IdSet{Any}, # cached to prevent duplicates
11-
result::Vector{Method},
12-
filter_default::Bool,
13-
)::Vector{Method}
14-
for name in names(mod; all = true, imported = true)
15-
# names can list undefined symbols which cannot be eval'd
16-
isdefined(mod, name) || continue
17-
18-
# Skip closures
19-
startswith(String(name), "#") && continue
20-
val = getfield(mod, name)
21-
22-
if !in(val, done_callables)
23-
# In old versions of Julia, Vararg errors when methods is called on it
24-
val === Vararg && continue
25-
for method in methods(val)
26-
# Default filtering removes all methods defined in DEFAULT_PKGs,
27-
# since these may pirate each other.
28-
if !(filter_default && in(Base.PkgId(method.module), DEFAULT_PKGS))
29-
push!(result, method)
30-
end
31-
end
32-
push!(done_callables, val)
3+
import Test
4+
5+
# based on Test/Test.jl#detect_ambiguities
6+
# https://github.com/JuliaLang/julia/blob/v1.9.1/stdlib/Test/src/Test.jl#L1838-L1896
7+
function all_methods(mods::Module...; skip_deprecated::Bool)
8+
meths = Method[]
9+
mods = collect(mods)::Vector{Module}
10+
11+
function examine(mt::Core.MethodTable)
12+
examine(Base.MethodList(mt))
13+
end
14+
function examine(ml::Base.MethodList)
15+
for m in ml
16+
Test.is_in_mods(m.module, true, mods) || continue
17+
push!(meths, m)
3318
end
3419
end
35-
result
36-
end
3720

38-
function all_methods(mod::Module; filter_default::Bool = true)
39-
result = Method[]
40-
done_callables = Base.IdSet()
41-
walkmodules(mod) do mod
42-
all_methods!(mod, done_callables, result, filter_default)
21+
work = Base.loaded_modules_array()
22+
filter!(mod -> mod === parentmodule(mod), work) # some items in loaded_modules_array are not top modules (really just Base)
23+
while !isempty(work)
24+
mod = pop!(work)
25+
for name in names(mod; all = true)
26+
(skip_deprecated && Base.isdeprecated(mod, name)) && continue
27+
isdefined(mod, name) || continue
28+
f = Base.unwrap_unionall(getfield(mod, name))
29+
if isa(f, Module) && f !== mod && parentmodule(f) === mod && nameof(f) === name
30+
push!(work, f)
31+
elseif isa(f, DataType) &&
32+
isdefined(f.name, :mt) &&
33+
parentmodule(f) === mod &&
34+
nameof(f) === name &&
35+
f.name.mt !== Symbol.name.mt &&
36+
f.name.mt !== DataType.name.mt
37+
examine(f.name.mt)
38+
end
39+
end
4340
end
44-
return result
41+
examine(Symbol.name.mt)
42+
examine(DataType.name.mt)
43+
return meths
4544
end
4645

4746
##################################
@@ -141,7 +140,7 @@ function is_foreign_method(@nospecialize(T::DataType), pkg::Base.PkgId; treat_as
141140

142141
# fallback to general code
143142
return !(T in treat_as_own) &&
144-
!(T <: Function && T.instance in treat_as_own) &&
143+
!(T <: Function && isdefined(T, :instance) && T.instance in treat_as_own) &&
145144
is_foreign(T, pkg; treat_as_own = treat_as_own)
146145
end
147146

@@ -165,8 +164,8 @@ end
165164
hunt(mod::Module; from::Module = mod, kwargs...) =
166165
hunt(Base.PkgId(mod); from = from, kwargs...)
167166

168-
function hunt(pkg::Base.PkgId; from::Module, kwargs...)
169-
filter(all_methods(from)) do method
167+
function hunt(pkg::Base.PkgId; from::Module, skip_deprecated::Bool = true, kwargs...)
168+
filter(all_methods(from; skip_deprecated = skip_deprecated)) do method
170169
Base.PkgId(method.module) === pkg && is_pirate(method; kwargs...)
171170
end
172171
end
@@ -182,6 +181,7 @@ See [Julia documentation](https://docs.julialang.org/en/v1/manual/style-guide/#A
182181
# Keyword Arguments
183182
- `broken::Bool = false`: If true, it uses `@test_broken` instead of
184183
`@test`.
184+
- `skip_deprecated::Bool = true`: If true, it does not check deprecated methods.
185185
- `treat_as_own = Union{Function, Type}[]`: The types in this container
186186
are considered to be "owned" by the module `m`. This is useful for
187187
testing packages that deliberately commit some type piracy, e.g. modules

test/test_piracy.jl

Lines changed: 52 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -57,31 +57,27 @@ Base.findmin(::ForeignParameterizedType{Int}, x::Int) = x + 1
5757
Base.findmin(::Set{Vector{ForeignParameterizedType{Int}}}, x::Int) = x + 1
5858
Base.findmin(::Union{Foo,ForeignParameterizedType{Int}}, x::Int) = x + 1
5959

60-
# Assign them names in this module so they can be found by all_methods
61-
a = Base.findfirst
62-
b = Base.findlast
63-
c = Base.findmax
64-
d = Base.findmin
6560
end # PiracyModule
6661

6762
using Aqua: Piracy
68-
using PiracyForeignProject: ForeignType, ForeignParameterizedType
63+
using PiracyForeignProject: ForeignType, ForeignParameterizedType, ForeignNonSingletonType
6964

7065
# Get all methods - test length
71-
meths = filter(Piracy.all_methods(PiracyModule)) do m
66+
meths = filter(Piracy.all_methods(PiracyModule; skip_deprecated = true)) do m
7267
m.module == PiracyModule
7368
end
7469

75-
# 2 Foo constructors
76-
# 2 from f
77-
# 1 from MyUnion
78-
# 6 from findlast
79-
# 3 from findfirst
80-
# 1 from ForeignType
81-
# 1 from ForeignNonSingletonType
82-
# 3 from findmax
83-
# 3 from findmin
84-
@test length(meths) == 2 + 2 + 1 + 6 + 3 + 1 + 1 + 3 + 3
70+
@test length(meths) ==
71+
2 + # Foo constructors
72+
1 + # Bar constructor
73+
2 + # f
74+
1 + # MyUnion
75+
6 + # findlast
76+
3 + # findfirst
77+
1 + # ForeignType callable
78+
1 + # ForeignNonSingletonType callable
79+
3 + # findmax
80+
3 # findmin
8581

8682
# Test what is foreign
8783
BasePkg = Base.PkgId(Base)
@@ -95,49 +91,71 @@ ThisPkg = Base.PkgId(PiracyModule)
9591

9692
# Test what is pirate
9793
pirates = filter(m -> Piracy.is_pirate(m), meths)
98-
@test length(pirates) == 3 + 3 + 3 + 1 + 1
99-
@test_broken all(pirates) do m
100-
m.name in [:findfirst, :findmax, :findmin]
94+
@test length(pirates) ==
95+
3 + # findfirst
96+
3 + # findmax
97+
3 + # findmin
98+
1 + # ForeignType callable
99+
1 # ForeignNonSingletonType callable
100+
@test all(pirates) do m
101+
m.name in [:findfirst, :findmax, :findmin, :ForeignType, :ForeignNonSingletonType]
101102
end
102103

103104
# Test what is pirate (with treat_as_own=[ForeignType])
104105
pirates = filter(m -> Piracy.is_pirate(m; treat_as_own = [ForeignType]), meths)
105-
@test_broken length(pirates) == 3 + 3
106-
@test_broken all(pirates) do m
107-
m.name in [:findfirst, :findmin]
106+
@test length(pirates) ==
107+
3 + # findfirst
108+
3 + # findmin
109+
1 # ForeignNonSingletonType callable
110+
@test all(pirates) do m
111+
m.name in [:findfirst, :findmin, :ForeignNonSingletonType]
108112
end
109113

110114
# Test what is pirate (with treat_as_own=[ForeignParameterizedType])
111115
pirates = filter(m -> Piracy.is_pirate(m; treat_as_own = [ForeignParameterizedType]), meths)
112-
@test_broken length(pirates) == 3 + 3
113-
@test_broken all(pirates) do m
114-
m.name in [:findfirst, :findmax]
116+
@test length(pirates) ==
117+
3 + # findfirst
118+
3 + # findmax
119+
1 + # ForeignType callable
120+
1 # ForeignNonSingletonType callable
121+
@test all(pirates) do m
122+
m.name in [:findfirst, :findmax, :ForeignType, :ForeignNonSingletonType]
115123
end
116124

117125
# Test what is pirate (with treat_as_own=[ForeignType, ForeignParameterizedType])
118126
pirates = filter(
119127
m -> Piracy.is_pirate(m; treat_as_own = [ForeignType, ForeignParameterizedType]),
120128
meths,
121129
)
122-
@test_broken length(pirates) == 3
123-
@test_broken all(pirates) do m
124-
m.name in [:findfirst]
130+
@test length(pirates) ==
131+
3 + # findfirst
132+
1 # ForeignNonSingletonType callable
133+
@test all(pirates) do m
134+
m.name in [:findfirst, :ForeignNonSingletonType]
125135
end
126136

127137
# Test what is pirate (with treat_as_own=[Base.findfirst, Base.findmax])
128138
pirates =
129139
filter(m -> Piracy.is_pirate(m; treat_as_own = [Base.findfirst, Base.findmax]), meths)
130-
@test_broken length(pirates) == 3
131-
@test_broken all(pirates) do m
132-
m.name in [:findmin]
140+
@test length(pirates) ==
141+
3 + # findmin
142+
1 + # ForeignType callable
143+
1 # ForeignNonSingletonType callable
144+
@test all(pirates) do m
145+
m.name in [:findmin, :ForeignType, :ForeignNonSingletonType]
133146
end
134147

135148
# Test what is pirate (excluding a cover of everything)
136149
pirates = filter(
137150
m -> Piracy.is_pirate(
138151
m;
139-
treat_as_own = [ForeignType, ForeignParameterizedType, Base.findfirst],
152+
treat_as_own = [
153+
ForeignType,
154+
ForeignParameterizedType,
155+
ForeignNonSingletonType,
156+
Base.findfirst,
157+
],
140158
),
141159
meths,
142160
)
143-
@test_broken length(pirates) == 0
161+
@test length(pirates) == 0

0 commit comments

Comments
 (0)