diff --git a/src/StructFieldParamsTesting.jl b/src/StructFieldParamsTesting.jl index 410a53e..4e510d1 100644 --- a/src/StructFieldParamsTesting.jl +++ b/src/StructFieldParamsTesting.jl @@ -1,7 +1,7 @@ module StructFieldParamsTesting -export test_all_structs_have_fully_specified_fields -export test_all_fields_fully_specified, field_is_fully_specified +export test_all_structs_have_fully_specified_fields, print_all_structs_have_fully_specified_fields +export test_all_fields_fully_specified, field_is_fully_specified, print_all_fields_fully_specified using MacroTools: @capture using Markdown: Markdown diff --git a/src/check_struct_fields.jl b/src/check_struct_fields.jl index 234dbf0..9aa1389 100644 --- a/src/check_struct_fields.jl +++ b/src/check_struct_fields.jl @@ -21,6 +21,147 @@ function field_is_fully_specified(pkg::Module, struct_expr, field_name; location report_error = false, location = location) end +function print_all_fields_fully_specified(pkg::Module, struct_expr; location = nothing) + (struct_name, T, fields_dict) = _extract_struct_field_types(pkg::Module, struct_expr) + + # Extract field order from original struct expression + field_order = _extract_field_order(struct_expr) + + # Check each field and collect unspecified ones + unspecified_fields = Dict{Symbol, Any}() + fully_specified_fields = Dict{Symbol, Any}() + corrected_fields = Dict{Symbol, Any}() + + for (field_name, field_type_expr) in fields_dict + is_specified = check_field_type_fully_specified( + pkg, struct_name, field_name, T, field_type_expr, + report_error = false, location = location) + + if is_specified + fully_specified_fields[field_name] = field_type_expr + corrected_fields[field_name] = field_type_expr + else + unspecified_fields[field_name] = field_type_expr + # Generate corrected type + try + TypeObj = Base.eval(pkg, quote + $(field_type_expr) where {$(T...)} + end) + complete_type = single_unwrap_unionall(TypeObj) + corrected_fields[field_name] = complete_type + catch e + corrected_fields[field_name] = field_type_expr + end + end + end + + # Print original struct definition + println("Original struct definition:") + if location !== nothing + printstyled("#= $location =#\n"; color=:light_black) + end + _print_struct_definition_colored(struct_name, T, fields_dict, field_order, struct_expr, unspecified_fields, :red) + println() + + # Only show corrected version if there are issues + if !isempty(unspecified_fields) + println("Fully-specified version:") + _print_struct_definition_colored(struct_name, T, corrected_fields, field_order, struct_expr, unspecified_fields, :green) + end +end + +function _extract_field_order(struct_expr) + @capture( + struct_expr, + struct name_{T__} <: S_ fields__ end | struct name_ <: S_ fields__ end | + struct name_{T__} fields__ end | struct name_ fields__ end | + mutable struct name_{T__} <: S_ fields__ end | mutable struct name_ <: S_ fields__ end | + mutable struct name_{T__} fields__ end | mutable struct name_ fields__ end + ) || error("Invalid struct expression: $(struct_expr)") + + fields_split = split_field.(fields) + filter!(x -> x !== nothing, fields_split) + return [f[1] for f in fields_split] +end + +function _print_struct_definition_colored(struct_name, T, fields_dict, field_order, original_expr, unspecified_fields, color) + # Determine if it's mutable by looking at the original expression + is_mutable = false + if original_expr isa Expr && original_expr.head == :struct && length(original_expr.args) >= 1 + is_mutable = original_expr.args[1] isa Bool ? original_expr.args[1] : false + end + + # Build struct definition string + prefix = is_mutable ? "mutable struct " : "struct " + + # Handle type parameters + if isempty(T) + struct_header = "$prefix$struct_name" + else + struct_header = "$prefix$struct_name{$(join(T, ", "))}" + end + + println(struct_header) + + # Print fields in original order with color for changed fields + for field_name in field_order + field_type_expr = fields_dict[field_name] + field_line = if field_type_expr == Any + " $field_name" + else + " $field_name::$field_type_expr" + end + + # Color the line if this field was unspecified + if field_name in keys(unspecified_fields) + if color == :red + printstyled(field_line, "\n"; color=:red) + elseif color == :green + printstyled(field_line, "\n"; color=:green) + else + println(field_line) + end + else + println(field_line) + end + end + + println("end") +end + +function _print_struct_definition(struct_name, T, fields_dict, field_order, original_expr) + # Determine if it's mutable by looking at the original expression + is_mutable = false + if original_expr isa Expr && original_expr.head == :struct && length(original_expr.args) >= 1 + is_mutable = original_expr.args[1] isa Bool ? original_expr.args[1] : false + end + + # Build struct definition string + prefix = is_mutable ? "mutable struct " : "struct " + + # Handle type parameters + if isempty(T) + struct_header = "$prefix$struct_name" + else + struct_header = "$prefix$struct_name{$(join(T, ", "))}" + end + + println(struct_header) + + # Print fields in original order + for field_name in field_order + field_type_expr = fields_dict[field_name] + if field_type_expr == Any + println(" $field_name") + else + println(" $field_name::$field_type_expr") + end + end + + println("end") +end + + function _extract_struct_field_types(pkg::Module, struct_expr) @capture( struct_expr, @@ -66,16 +207,14 @@ function check_field_type_fully_specified( @debug "Type is a DataType: $(TypeObj)" return true end - if typeof(TypeObj) == Union - # TODO: Handle every branch of the union. - # For now, just skip these fields. - return true - end if TypeObj == Type # TODO: FOR NOW, to avoid noisy result return true end - @assert typeof(TypeObj) === UnionAll "$(TypeObj) is not a UnionAll. Got $(typeof(TypeObj))." + @assert ( + typeof(TypeObj) === UnionAll || + typeof(TypeObj) === Union + ) "$(TypeObj) is not a UnionAll. Got $(typeof(TypeObj))." num_type_params = _count_unionall_free_parameters(TypeObj) num_expr_args = _count_type_expr_params(mod, field_type_expr) @@ -90,6 +229,14 @@ function check_field_type_fully_specified( end return success end +function check_field_type_fully_specified(mod, TypeObj, field_type_expr) + num_type_params = _count_unionall_free_parameters(TypeObj) + num_expr_args = _count_type_expr_params(mod, field_type_expr) + # "Less than or equal to" in order to support fully constrained parameters in the expr. + # E.g.: `Vector{T} where T<:Int` has 0 free type params but 1 param in the expression. + success = num_type_params <= num_expr_args + return success +end # TODO(type-alias): What do we actually want to do for alias types? # E.g. @@ -103,9 +250,14 @@ recursive_unwrap_unionall(@nospecialize(T)) = Base.unwrap_unionall(T) recursive_unwrap_unionall(T::UnionAll) = recursive_unwrap_unionall(Base.unwrap_unionall(T)) recursive_unwrap_unionall(T::Union) = Union{recursive_unwrap_unionall(T.a), recursive_unwrap_unionall(T.b)} +single_unwrap_unionall(@nospecialize(T)) = Base.unwrap_unionall(T) +single_unwrap_unionall(T::Union) = Union{single_unwrap_unionall(T.a), single_unwrap_unionall(T.b)} + # Get free TypeVar names (without constraints): # Foo{Int, X<:Integer, Y} where {X, Y} => [:X, :Y] -function type_param_names(TypeObj) +type_param_names(TypeObj) = Symbol[] +type_param_names(TypeObj::Union) = Symbol[type_param_names(TypeObj.a)..., type_param_names(TypeObj.b)...] +function type_param_names(TypeObj::UnionAll) names = Symbol[] while typeof(TypeObj) === UnionAll push!(names, TypeObj.var.name) @@ -162,13 +314,13 @@ function field_type_not_complete_message( @assert num_type_params <= length(type_params) type_params = type_params[1:num_type_params] expr_args = num_expr_args == 0 ? Symbol[] : type_params[1:num_expr_args] - complete_type = Base.unwrap_unionall(TypeObj) + complete_type = single_unwrap_unionall(TypeObj) # TODO(type-alias): see comment on recursive_unwrap_unionall complete_type_recursive = recursive_unwrap_unionall(TypeObj) s = num_type_params == 1 ? "" : "s" - print(io, """ - In struct `$(mod).$(struct_name)`, the field `$(field_name)` does not have a fully \ - specified type:\n + print(io, """\n + In struct `$(mod).$(struct_name)`, + the field `$(field_name)` does not have a fully specified type:\n \t$(field_name)::$(field_type_expr)\n """ ) @@ -181,7 +333,7 @@ function field_type_not_complete_message( print(io, """ The complete type is:\n \t$(complete_type)\n - which expects $(num_type_params) type parameter$(s): `$(join(type_params, ", "))`.\n + which expects $(num_type_params) additional type parameter$(s): `$(join(type_params, ", "))`.\n """) if string(complete_type) != string(complete_type_recursive) print(io, """ @@ -189,36 +341,34 @@ function field_type_not_complete_message( \t$(complete_type_recursive)\n """) end - print(io, """ - The current definition `$(field_type_expr)` specifies \ - $(num_expr_args == 0 ? "no type parameters." : "only $(num_expr_args) type parameters: \ - `$(join(expr_args, ", "))`.") - """) - - print(io, """ - This means the `$(field_name)` field currently has an abstract type, - and any access to it is type unstable will therefore cause a dynamic dispatch. + This means the `$(field_name)` field currently has an abstract type, and any access to it, + like `x.$(field_name)`, is type unstable and will therefore cause a dynamic dispatch. - If this was a mistake, possibly caused by a change to the `$(typename)` type that \ - introduced new parameters to it, please make sure that your field `$(field_name)` is \ + If this was a mistake, possibly caused by a change to the `$(typename)` type that + introduced new parameters to it, please make sure that your field `$(field_name)` is fully concrete, with all parameters specified. - If, instead, this type instability is on purpose, please fully specify the omitted \ - type parameters to silence this message. You can write that as `$(complete_type)`, or \ - possibly in a shorter alias form which this message can't always detect. (E.g. you can \ + If, instead, this type instability is on purpose, please fully specify the omitted + type parameters to silence this message. You can write that as `$(complete_type)`, or + possibly in a shorter alias form which this message can't always detect. (E.g. you can write `Vector{T} where T` instead of `Array{T, 1} where T`.) """) return io end _count_unionall_free_parameters(@nospecialize(::Any)) = 0 +function _count_unionall_free_parameters(TypeObj::Union) + aa = TypeObj.a + bb = TypeObj.b + return _count_unionall_free_parameters(aa) + _count_unionall_free_parameters(bb) +end function _count_unionall_free_parameters(TypeObj::UnionAll) return _count_unionall_free_parameters(Base.unwrap_unionall(TypeObj)) end function _count_unionall_free_parameters(TypeObj::DataType) count = 0 - for param in @show TypeObj.parameters + for param in TypeObj.parameters # only `TypeVars` can be free parameters, but in `T<:ConcreteType` # don't consider `T` as a free parameter if param isa TypeVar && !isconcretetype(param.ub) diff --git a/src/whole_module_checks.jl b/src/whole_module_checks.jl index af3c0f0..6ec940a 100644 --- a/src/whole_module_checks.jl +++ b/src/whole_module_checks.jl @@ -1,18 +1,67 @@ +mutable struct StructStats + total::Int + unspecified::Int + + StructStats() = new(0, 0) +end + function test_all_structs_have_fully_specified_fields(pkg::Module; verbose::Bool=false) - @testset "$(pkg)" verbose=verbose begin - dir = pkgdir(pkg) - @assert !isnothing(dir) "No file found for Module `$(pkg)`." - entrypoint = joinpath(dir, "src", "$(nameof(pkg)).jl") - @assert ispath(entrypoint) "Package $(pkg) source not found: $entrypoint" - test_file(pkg, entrypoint) + dir = pkgdir(pkg) + @assert !isnothing(dir) "No file found for Module `$(pkg)`." + entrypoint = joinpath(dir, "src", "$(nameof(pkg)).jl") + @assert ispath(entrypoint) "Package $(pkg) source not found: $entrypoint" + @testset "$(pkg).jl" verbose=verbose begin + test_all_structs_have_fully_specified_fields(pkg, entrypoint; verbose) end end -function test_file(pkg::Module, filename::String) +function print_all_structs_have_fully_specified_fields(pkg::Module) + dir = pkgdir(pkg) + @assert !isnothing(dir) "No file found for Module `$(pkg)`." + entrypoint = joinpath(dir, "src", "$(nameof(pkg)).jl") + @assert ispath(entrypoint) "Package $(pkg) source not found: $entrypoint" + + # Use a shared stats object across all files + struct_stats = StructStats() + print_all_structs_have_fully_specified_fields(pkg, entrypoint, struct_stats) + + # Print summary only once at the end + if struct_stats.total > 0 + fully_specified = struct_stats.total - struct_stats.unspecified + println("Module `$(pkg)` has $(fully_specified)/$(struct_stats.total) structs with fully specified fields") + end +end + +function test_all_structs_have_fully_specified_fields( + mod::Module, filename::AbstractString; verbose::Bool=false +) # Parse the file and call handle_parsed_expression on each expression. contents = read(filename, String) walk_string(contents, filename) do parsed, loc - handle_parsed_expression(pkg, parsed, loc) + handle_parsed_expression(mod, parsed, loc; verbose) + end +end + +function print_all_structs_have_fully_specified_fields( + mod::Module, filename::AbstractString, struct_stats::Union{StructStats, Nothing} = nothing +) + # Parse the file and call handle_parsed_expression_print on each expression. + contents = read(filename, String) + + # If no stats provided, create one and mark that we should print summary + should_print_summary = struct_stats === nothing + if struct_stats === nothing + struct_stats = StructStats() + end + + walk_string(contents, filename) do parsed, loc + handle_parsed_expression_print(mod, parsed, loc, struct_stats) + end + + # Print summary only if this was the top-level call + if should_print_summary && struct_stats.total > 0 + fully_specified = struct_stats.total - struct_stats.unspecified + println("Module `$(mod)` has $(fully_specified)/$(struct_stats.total) structs with fully specified fields") end end @@ -38,28 +87,78 @@ function walk_string(mapexpr::Function, code::AbstractString, filename::Abstract end end -handle_parsed_expression(::Module, x::Any, _loc) = nothing -function handle_parsed_expression(pkg::Module, parsed::Expr, loc) +handle_parsed_expression(::Module, x::Any, _loc; kw...) = nothing +function handle_parsed_expression(mod::Module, parsed::Expr, loc; verbose::Bool=false) # @show loc if parsed.head == :struct # DO THE THING location = "$(loc.file):$(loc.line)" - test_all_fields_fully_specified(pkg, parsed; location) + test_all_fields_fully_specified(mod, parsed; location) elseif parsed.head == :call && parsed.args[1] == :include # Follow includes to more files new_file = joinpath(dirname(String(loc.file)), parsed.args[2]) - return test_file(pkg, new_file) + return test_all_structs_have_fully_specified_fields(mod, new_file) elseif parsed.head == :module modname = parsed.args[2] - inner_mod = Core.eval(pkg, modname) - @testset "$(inner_mod)" begin + inner_mod = Core.eval(mod, modname) + @testset "$(inner_mod)" verbose=verbose begin for expr in parsed.args - handle_parsed_expression(inner_mod, expr, loc) + handle_parsed_expression(inner_mod, expr, loc; verbose) + end + end + else + for expr in parsed.args + handle_parsed_expression(mod, expr, loc) + end + end +end + +handle_parsed_expression_print(::Module, x::Any, _loc, stats) = nothing +function handle_parsed_expression_print(mod::Module, parsed::Expr, loc, stats) + if parsed.head == :struct + # Count all structs and check if they have unspecified fields + location = "$(loc.file):$(loc.line)" + + # Check if struct has any unspecified fields first + (struct_name, T, fields_dict) = StructFieldParamsTesting._extract_struct_field_types(mod, parsed) + has_unspecified = false + + for (field_name, field_type_expr) in fields_dict + is_specified = StructFieldParamsTesting.check_field_type_fully_specified( + mod, struct_name, field_name, T, field_type_expr, + report_error = false, location = location) + + if !is_specified + has_unspecified = true + break end end + + # Update statistics + stats.total += 1 + if has_unspecified + stats.unspecified += 1 + end + + # Only print if there are issues + if has_unspecified + print_all_fields_fully_specified(mod, parsed; location) + println() # Add spacing between structs + end + elseif parsed.head == :call && parsed.args[1] == :include + # Follow includes to more files + new_file = joinpath(dirname(String(loc.file)), parsed.args[2]) + return print_all_structs_have_fully_specified_fields(mod, new_file, stats) + elseif parsed.head == :module + modname = parsed.args[2] + inner_mod = Core.eval(mod, modname) + println("=== Module: $(inner_mod) ===") + for expr in parsed.args + handle_parsed_expression_print(inner_mod, expr, loc, stats) + end else for expr in parsed.args - handle_parsed_expression(pkg, expr, loc) + handle_parsed_expression_print(mod, expr, loc, stats) end end end