diff --git a/src/Aqua.jl b/src/Aqua.jl index 2f203016..fcc359ba 100644 --- a/src/Aqua.jl +++ b/src/Aqua.jl @@ -4,6 +4,8 @@ using Base: PkgId, UUID using Pkg: Pkg, TOML using Test +const JULIA_HAS_EXTENSIONS = isdefined(Base, :get_extension) # introduced in v1.9 + try findnext('a', "a", 1) catch diff --git a/src/ambiguities.jl b/src/ambiguities.jl index 5c9057cc..373d6288 100644 --- a/src/ambiguities.jl +++ b/src/ambiguities.jl @@ -18,6 +18,12 @@ false-positive. callable (sometimes also called "functor") and the constructor. That is to say, `MyModule.MyType` means to ignore ambiguities between `(::MyType)(x, y::Int)` and `(::MyType)(x::Int, y)`. +- `extension_combinations = :default`: In julia versions 1.8 and before, + this keyword argument is ignored. In julia versions 1.9 and later, + it contains the combinations of extensions to test, as a vector of + vectors of strings or symbols. The default value `:default` tests + a small amount of combinations, at least *no extensions* and *all + extensions*. - `recursive::Bool = true`: Passed to `Test.detect_ambiguities`. Note that the default here (`true`) is different from `detect_ambiguities`. This is for testing ambiguities in methods @@ -84,27 +90,87 @@ function reprexclude(exspecs::Vector{ExcludeSpec}) return string("Aqua.ExcludeSpec[", join(itemreprs, ", "), "]") end -function _test_ambiguities(packages::Vector{PkgId}; broken::Bool = false, kwargs...) - num_ambiguities, strout, strerr = _find_ambiguities(packages; kwargs...) +function _test_ambiguities( + packages::Vector{PkgId}; + broken::Bool = false, + extension_combinations = :default, + kwargs..., +) + if extension_combinations == :default + extension_combinations = Vector{String}[] + push!(extension_combinations, String[]) + @static if JULIA_HAS_EXTENSIONS + all_exts = String[] + for pkg in setdiff(packages, [PkgId(Base), PkgId(Core)]) + exts, _, _ = get_extension_data_from_toml(pkg) + for e in keys(exts) + push!(extension_combinations, [e]) + end + push!(extension_combinations, collect(keys(exts))) + append!(all_exts, collect(keys(exts))) + end + push!(extension_combinations, all_exts) + end + unique!(extension_combinations) + end + for extensions in extension_combinations + @info "Testing ambiguities with extensions: $extensions" + num_ambiguities, strout, strerr = + _find_ambiguities(packages; extensions = extensions, kwargs...) - print(stderr, strerr) - print(stdout, strout) + print(stderr, strerr) + print(stdout, strout) - if broken - @test_broken num_ambiguities == 0 - else - @test num_ambiguities == 0 + if broken + @test_broken num_ambiguities == 0 + else + @test num_ambiguities == 0 + end end + end function _find_ambiguities( packages::Vector{PkgId}; color::Union{Bool,Nothing} = nothing, exclude::AbstractVector = [], + extensions::AbstractVector = [], # Options to be passed to `Test.detect_ambiguities`: detect_ambiguities_options..., ) + packages = copy(packages) + extdeppackages = PkgId[] + @static if JULIA_HAS_EXTENSIONS + extensions = String.(extensions) + for ext in extensions + found = false + for pkg in setdiff(packages, [PkgId(Base), PkgId(Core)]) + exts, weakdeps, deps = get_extension_data_from_toml(pkg) + if haskey(exts, ext) + found = true + extdeps = exts[ext] isa String ? [exts[ext]] : exts[ext] + for extdepname in extdeps + if haskey(deps, extdepname) + push!(extdeppackages, deps[extdepname]) + elseif haskey(weakdeps, extdepname) + push!(extdeppackages, weakdeps[extdepname]) + else + error( + "Extension $ext depends on $extdepname, but it is not found.", + ) + end + end + push!(packages, PkgId(Base.uuid5(pkg.uuid, ext), ext)) + break + end + end + found && continue + error("Extension $ext is not found.") + end + end + packages_repr = reprpkgids(collect(packages)) + extdeppackages_repr = reprpkgids(collect(extdeppackages)) options_repr = checked_repr((; recursive = true, detect_ambiguities_options...)) exclude_repr = reprexclude(normalize_and_check_exclude(exclude)) @@ -115,6 +181,7 @@ function _find_ambiguities( using Aqua Aqua.test_ambiguities_impl( $packages_repr, + $extdeppackages_repr, $options_repr, $exclude_repr, ) || exit(1) @@ -143,7 +210,7 @@ end function reprpkgids(packages::Vector{PkgId}) packages_repr = sprint() do io - println(io, '[') + println(io, "Base.PkgId[") for pkg in packages println(io, reprpkgid(pkg)) end @@ -200,9 +267,11 @@ end function test_ambiguities_impl( packages::Vector{PkgId}, + extdeppackages::Vector{PkgId}, options::NamedTuple, exspecs::Vector{ExcludeSpec}, ) + deps = map(Base.require, extdeppackages) modules = map(Base.require, packages) @debug "Testing method ambiguities" modules ambiguities = detect_ambiguities(modules...; options...) diff --git a/src/piracy.jl b/src/piracy.jl index a552b623..34e53dd2 100644 --- a/src/piracy.jl +++ b/src/piracy.jl @@ -1,47 +1,60 @@ module Piracy -using Test: @test, @test_broken -using ..Aqua: walkmodules - -const DEFAULT_PKGS = (Base.PkgId(Base), Base.PkgId(Core)) - -function all_methods!( - mod::Module, - done_callables::Base.IdSet{Any}, # cached to prevent duplicates - result::Vector{Method}, - filter_default::Bool, -)::Vector{Method} - for name in names(mod; all = true, imported = true) - # names can list undefined symbols which cannot be eval'd - isdefined(mod, name) || continue - - # Skip closures - startswith(String(name), "#") && continue - val = getfield(mod, name) - - if !in(val, done_callables) - # In old versions of Julia, Vararg errors when methods is called on it - val === Vararg && continue - for method in methods(val) - # Default filtering removes all methods defined in DEFAULT_PKGs, - # since these may pirate each other. - if !(filter_default && in(Base.PkgId(method.module), DEFAULT_PKGS)) - push!(result, method) - end - end - push!(done_callables, val) +using Aqua: JULIA_HAS_EXTENSIONS + +if VERSION >= v"1.6-" + using Test: is_in_mods +else + function is_in_mods(m::Module, recursive::Bool, mods) + while true + m in mods && return true + recursive || return false + p = parentmodule(m) + p === m && return false + m = p end end - result end -function all_methods(mod::Module; filter_default::Bool = true) - result = Method[] - done_callables = Base.IdSet() - walkmodules(mod) do mod - all_methods!(mod, done_callables, result, filter_default) +# based on Test/Test.jl#detect_ambiguities +# https://github.com/JuliaLang/julia/blob/v1.9.1/stdlib/Test/src/Test.jl#L1838-L1896 +function all_methods(mods::Module...; skip_deprecated::Bool = true) + meths = Method[] + mods = collect(mods)::Vector{Module} + + function examine(mt::Core.MethodTable) + examine(Base.MethodList(mt)) + end + function examine(ml::Base.MethodList) + for m in ml + is_in_mods(m.module, true, mods) || continue + push!(meths, m) + end end - return result + + work = Base.loaded_modules_array() + filter!(mod -> mod === parentmodule(mod), work) # some items in loaded_modules_array are not top modules (really just Base) + while !isempty(work) + mod = pop!(work) + for name in names(mod; all = true) + (skip_deprecated && Base.isdeprecated(mod, name)) && continue + isdefined(mod, name) || continue + f = Base.unwrap_unionall(getfield(mod, name)) + if isa(f, Module) && f !== mod && parentmodule(f) === mod && nameof(f) === name + push!(work, f) + elseif isa(f, DataType) && + isdefined(f.name, :mt) && + parentmodule(f) === mod && + nameof(f) === name && + f.name.mt !== Symbol.name.mt && + f.name.mt !== DataType.name.mt + examine(f.name.mt) + end + end + end + examine(Symbol.name.mt) + examine(DataType.name.mt) + return meths end ################################## @@ -141,7 +154,7 @@ function is_foreign_method(@nospecialize(T::DataType), pkg::Base.PkgId; treat_as # fallback to general code return !(T in treat_as_own) && - !(T <: Function && T.instance in treat_as_own) && + !(T <: Function && isdefined(T, :instance) && T.instance in treat_as_own) && is_foreign(T, pkg; treat_as_own = treat_as_own) end @@ -149,6 +162,11 @@ end function is_pirate(meth::Method; treat_as_own = Union{Function,Type}[]) method_pkg = Base.PkgId(meth.module) + # Package extensions behave as the package itself + @static if JULIA_HAS_EXTENSIONS + method_pkg = get(Base.EXT_PRIMED, method_pkg, method_pkg) + end + signature = Base.unwrap_unionall(meth.sig) # the first parameter in the signature is the function type, and it @@ -162,12 +180,9 @@ function is_pirate(meth::Method; treat_as_own = Union{Function,Type}[]) ) end -hunt(mod::Module; from::Module = mod, kwargs...) = - hunt(Base.PkgId(mod); from = from, kwargs...) - -function hunt(pkg::Base.PkgId; from::Module, kwargs...) - filter(all_methods(from)) do method - Base.PkgId(method.module) === pkg && is_pirate(method; kwargs...) +function hunt(mod::Module; skip_deprecated::Bool = true, kwargs...) + filter(all_methods(mod; skip_deprecated = skip_deprecated)) do method + method.module === mod && is_pirate(method; kwargs...) end end @@ -182,6 +197,7 @@ See [Julia documentation](https://docs.julialang.org/en/v1/manual/style-guide/#A # Keyword Arguments - `broken::Bool = false`: If true, it uses `@test_broken` instead of `@test`. +- `skip_deprecated::Bool = true`: If true, it does not check deprecated methods. - `treat_as_own = Union{Function, Type}[]`: The types in this container are considered to be "owned" by the module `m`. This is useful for testing packages that deliberately commit some type piracy, e.g. modules diff --git a/src/utils.jl b/src/utils.jl index 0179714e..0b1c57a4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -96,6 +96,23 @@ catch end end +function get_extension_data_from_toml(pkg::PkgId) + root_project_path = root_project_or_failed_lazytest(pkg) + root_project_path isa LazyTestResult && + return Dict{String,Any}(), Dict{String,Any}(), Dict{String,Any}() + + @debug "Parsing `$root_project_path`" + prj = TOML.parsefile(root_project_path) + raw_exts = get(prj, "extensions", Dict{String,Any}()) + + raw_weakdeps = get(prj, "weakdeps", Dict{String,Any}()) + weakdeps = Dict(name => PkgId(UUID(uuid), name) for (name, uuid) in raw_weakdeps) + + raw_deps = get(prj, "deps", Dict{String,Any}()) + deps = Dict(name => PkgId(UUID(uuid), name) for (name, uuid) in raw_deps) + return raw_exts, weakdeps, deps +end + const _project_key_order = [ "name", "uuid", diff --git a/test/pkgs/PiracyForeignProject/src/PiracyForeignProject.jl b/test/pkgs/PiracyForeignProject/src/PiracyForeignProject.jl index 03dea9e6..fc2c5c5e 100644 --- a/test/pkgs/PiracyForeignProject/src/PiracyForeignProject.jl +++ b/test/pkgs/PiracyForeignProject/src/PiracyForeignProject.jl @@ -3,4 +3,8 @@ module PiracyForeignProject struct ForeignType end struct ForeignParameterizedType{T} end +struct ForeignNonSingletonType + x::Int +end + end diff --git a/test/test_piracy.jl b/test/test_piracy.jl index f586c433..78de01d3 100644 --- a/test/test_piracy.jl +++ b/test/test_piracy.jl @@ -2,7 +2,7 @@ push!(LOAD_PATH, joinpath(@__DIR__, "pkgs", "PiracyForeignProject")) baremodule PiracyModule -using PiracyForeignProject: ForeignType, ForeignParameterizedType +using PiracyForeignProject: ForeignType, ForeignParameterizedType, ForeignNonSingletonType using Base: Base, @@ -44,6 +44,8 @@ export MyUnion Base.findfirst(::Set{Vector{Char}}, ::Int) = 1 Base.findfirst(::Union{Foo,Bar{Set{Unsigned}},UInt}, ::Tuple{Vararg{String}}) = 1 Base.findfirst(::AbstractChar, ::Set{T}) where {Int <: T <: Integer} = 1 +(::ForeignType)(x::Int8) = x + 1 +(::ForeignNonSingletonType)(x::Int8) = x + 1 # Piracy, but not for `ForeignType in treat_as_own` Base.findmax(::ForeignType, x::Int) = x + 1 @@ -55,29 +57,27 @@ Base.findmin(::ForeignParameterizedType{Int}, x::Int) = x + 1 Base.findmin(::Set{Vector{ForeignParameterizedType{Int}}}, x::Int) = x + 1 Base.findmin(::Union{Foo,ForeignParameterizedType{Int}}, x::Int) = x + 1 -# Assign them names in this module so they can be found by all_methods -a = Base.findfirst -b = Base.findlast -c = Base.findmax -d = Base.findmin end # PiracyModule using Aqua: Piracy -using PiracyForeignProject: ForeignType, ForeignParameterizedType +using PiracyForeignProject: ForeignType, ForeignParameterizedType, ForeignNonSingletonType # Get all methods - test length meths = filter(Piracy.all_methods(PiracyModule)) do m m.module == PiracyModule end -# 2 Foo constructors -# 2 from f -# 1 from MyUnion -# 6 from findlast -# 3 from findfirst -# 3 from findmax -# 3 from findmin -@test length(meths) == 2 + 2 + 1 + 6 + 3 + 3 + 3 +@test length(meths) == + 2 + # Foo constructors + 1 + # Bar constructor + 2 + # f + 1 + # MyUnion + 6 + # findlast + 3 + # findfirst + 1 + # ForeignType callable + 1 + # ForeignNonSingletonType callable + 3 + # findmax + 3 # findmin # Test what is foreign BasePkg = Base.PkgId(Base) @@ -90,24 +90,36 @@ ThisPkg = Base.PkgId(PiracyModule) @test !Piracy.is_foreign(Set{Int}, CorePkg; treat_as_own = []) # Test what is pirate -pirates = filter(m -> Piracy.is_pirate(m), meths) -@test length(pirates) == 3 + 3 + 3 +pirates = Piracy.hunt(PiracyModule) +@test length(pirates) == + 3 + # findfirst + 3 + # findmax + 3 + # findmin + 1 + # ForeignType callable + 1 # ForeignNonSingletonType callable @test all(pirates) do m - m.name in [:findfirst, :findmax, :findmin] + m.name in [:findfirst, :findmax, :findmin, :ForeignType, :ForeignNonSingletonType] end # Test what is pirate (with treat_as_own=[ForeignType]) -pirates = filter(m -> Piracy.is_pirate(m; treat_as_own = [ForeignType]), meths) -@test length(pirates) == 3 + 3 +pirates = Piracy.hunt(PiracyModule, treat_as_own = [ForeignType]) +@test length(pirates) == + 3 + # findfirst + 3 + # findmin + 1 # ForeignNonSingletonType callable @test all(pirates) do m - m.name in [:findfirst, :findmin] + m.name in [:findfirst, :findmin, :ForeignNonSingletonType] end # Test what is pirate (with treat_as_own=[ForeignParameterizedType]) -pirates = filter(m -> Piracy.is_pirate(m; treat_as_own = [ForeignParameterizedType]), meths) -@test length(pirates) == 3 + 3 +pirates = Piracy.hunt(PiracyModule, treat_as_own = [ForeignParameterizedType]) +@test length(pirates) == + 3 + # findfirst + 3 + # findmax + 1 + # ForeignType callable + 1 # ForeignNonSingletonType callable @test all(pirates) do m - m.name in [:findfirst, :findmax] + m.name in [:findfirst, :findmax, :ForeignType, :ForeignNonSingletonType] end # Test what is pirate (with treat_as_own=[ForeignType, ForeignParameterizedType]) @@ -115,24 +127,33 @@ pirates = filter( m -> Piracy.is_pirate(m; treat_as_own = [ForeignType, ForeignParameterizedType]), meths, ) -@test length(pirates) == 3 +@test length(pirates) == + 3 + # findfirst + 1 # ForeignNonSingletonType callable @test all(pirates) do m - m.name in [:findfirst] + m.name in [:findfirst, :ForeignNonSingletonType] end # Test what is pirate (with treat_as_own=[Base.findfirst, Base.findmax]) -pirates = - filter(m -> Piracy.is_pirate(m; treat_as_own = [Base.findfirst, Base.findmax]), meths) -@test length(pirates) == 3 +pirates = Piracy.hunt(PiracyModule, treat_as_own = [Base.findfirst, Base.findmax]) +@test length(pirates) == + 3 + # findmin + 1 + # ForeignType callable + 1 # ForeignNonSingletonType callable @test all(pirates) do m - m.name in [:findmin] + m.name in [:findmin, :ForeignType, :ForeignNonSingletonType] end # Test what is pirate (excluding a cover of everything) pirates = filter( m -> Piracy.is_pirate( m; - treat_as_own = [ForeignType, ForeignParameterizedType, Base.findfirst], + treat_as_own = [ + ForeignType, + ForeignParameterizedType, + ForeignNonSingletonType, + Base.findfirst, + ], ), meths, )