# Post-processing helpers for the generated protobuf bindings
# (deps/ReactantExtra/make-proto-bindings.jl).
#
# `generate_bindings` runs them right after `remove_proto_headers(output_dir)` and
# before `write_main_module(output_dir, proto_rel_paths)`, in this order:
#
#     make_structs_mutable(output_dir)
#     convert_large_structs_to_dict(output_dir)

"""
    make_structs_mutable(output_dir::String)

Convert all `struct` declarations to `mutable struct` in the generated Julia files
found under `output_dir`, so proto message types can be modified after construction.

Every `.jl` file except `Proto.jl` is processed; a file is rewritten only when its
content actually changes. Declarations that are already `mutable struct` are left
untouched, so running the pass more than once is harmless.
"""
function make_structs_mutable(output_dir::String)
    println("\n Converting structs to mutable structs...")
    for (root, _, files) in walkdir(output_dir)
        for file in files
            endswith(file, ".jl") || continue
            file == "Proto.jl" && continue

            path = joinpath(root, file)
            content = read(path, String)

            # The optional `mutable` prefix is consumed together with `struct`, so an
            # existing `mutable struct` maps onto itself instead of turning into
            # `mutable mutable struct`.
            new_content = replace(
                content, r"\b(?:mutable\s+)?struct\s+" => "mutable struct "
            )

            content == new_content || write(path, new_content)
        end
    end
    return nothing
end

"""
    convert_large_structs_to_dict(output_dir::String; min_fields::Int=8)

Convert structs with more than `min_fields` fields to use a single
`Dict{Symbol,Any}` storage field (`__data`). This is intended to reduce compile time
and memory usage for large proto structs.

For every converted struct this pass also

- emits a table of per-field default values, a keyword constructor, and
  `getproperty` / `setproperty!` / `propertynames` methods backed by the dict;
- rewrites positional `return Name(arg1, ..., argN)` calls into keyword calls;
- comments out the struct's `PB.default_values`, `PB.field_numbers` and
  `PB.reserved_fields` definitions.

Only declarations spelled `mutable struct Name` are recognised, so
`make_structs_mutable` has to run first. Every `.jl` file under `output_dir` except
`Proto.jl` is processed, and a file is rewritten only when its content changes.
"""
function convert_large_structs_to_dict(output_dir::String; min_fields::Int=8)
    println("\n Converting large structs to dict-based storage...")
    for (root, _, files) in walkdir(output_dir)
        for file in files
            endswith(file, ".jl") || continue
            file == "Proto.jl" && continue

            path = joinpath(root, file)
            original_content = read(path, String)
            content = _convert_large_structs(original_content, min_fields)

            content == original_content || write(path, content)
        end
    end
    return nothing
end

# Apply the dict-storage rewrite to the text of one generated file and return the
# new text.
function _convert_large_structs(source::AbstractString, min_fields::Int)
    # Shape: mutable struct Name\n    field1::Type1\n    field2::Type2\n...\nend
    struct_pattern = r"(mutable struct\s+)(var\"[^\"]+\"|[A-Za-z_][A-Za-z0-9_]*)\s*\n((?:\s+[a-z_][a-z0-9_]*::[^\n]+\n)+)end"
    field_pattern = r"\s+([a-z_][a-z0-9_]*)::([^\n]+)"

    content = source
    converted_structs = String[]

    # Matches are taken from the untouched `source`; each replacement is applied to
    # the evolving `content`.
    for m in eachmatch(struct_pattern, source)
        field_matches = collect(eachmatch(field_pattern, m.captures[3]))
        length(field_matches) > min_fields || continue

        struct_name = String(m.captures[2])
        field_names = [String(fm.captures[1]) for fm in field_matches]
        field_types = [strip(fm.captures[2]) for fm in field_matches]
        push!(converted_structs, struct_name)

        new_struct = _dict_struct_definition(
            m.captures[1], struct_name, field_names, field_types
        )
        content = replace(content, m.match => new_struct)
        content = _use_keyword_constructor(content, struct_name, field_names)
    end

    # Only the structs converted above lose their PB.* metadata methods.
    for struct_name in converted_structs
        content = _comment_out_pb_methods(content, struct_name)
    end
    return content
end

# Build the replacement source for one struct: dict-backed struct, defaults table,
# keyword constructor and property accessors.
function _dict_struct_definition(struct_prefix, struct_name, field_names, field_types)
    # Sanitize the struct name for use inside an identifier: var"Foo.Bar" -> Foo_Bar.
    safe_name = struct_name
    safe_name = replace(safe_name, "var\"" => "")
    safe_name = replace(safe_name, "\"" => "")
    safe_name = replace(safe_name, "." => "_")
    defaults_name = "_$(safe_name)_defaults"

    defaults_entries = [
        " :$(fn) => $(get_default_for_type(ft))" for
        (fn, ft) in zip(field_names, field_types)
    ]
    property_names = join([":$(fn)" for fn in field_names], ", ")

    # NOTE(review): container defaults (`Vector{...}()`, `Dict{...}()`) are created
    # once in the defaults table and returned by reference from `getproperty`, so
    # mutating a defaulted field in place is visible through every instance that has
    # not set that field.
    lines = [
        "$(struct_prefix)$(struct_name)",
        "    __data::Dict{Symbol,Any}",
        "end",
        "",
        "# Default values for $(struct_name) fields",
        "const $(defaults_name) = Dict{Symbol,Any}(",
        join(defaults_entries, ",\n"),
        ")",
        "",
        "# Keyword constructor for $(struct_name)",
        "function $(struct_name)(; kwargs...)",
        "    __data = Dict{Symbol,Any}(kwargs)",
        "    return $(struct_name)(__data)",
        "end",
        "",
        "# Field accessors for $(struct_name)",
        "function Base.getproperty(x::$(struct_name), s::Symbol)",
        "    s === :__data && return getfield(x, :__data)",
        "    d = getfield(x, :__data)",
        "    return get(d, s, get($(defaults_name), s, nothing))",
        "end",
        "function Base.setproperty!(x::$(struct_name), s::Symbol, v)",
        "    getfield(x, :__data)[s] = v",
        "end",
        "Base.propertynames(::$(struct_name)) = ($(property_names),)",
    ]
    return join(lines, "\n")
end

# Rewrite positional `return Name(a, b, ...)` calls (as found in PB.decode) into
# keyword calls `return Name(; f1=a, f2=b, ...)`.
function _use_keyword_constructor(content::AbstractString, struct_name, field_names)
    # \Q...\E quotes the name literally, so the `.` and `"` of var"Foo.Bar" names
    # cannot act as regex syntax. The argument list is matched up to the first `)`;
    # `[]` is fine, but a call whose arguments contain parentheses is not matched
    # in full and is left unchanged by the arity check below.
    constructor_pattern = Regex("return \\Q$(struct_name)\\E\\(([^)]+)\\)")
    for cm in eachmatch(constructor_pattern, content)
        args_str = cm.captures[1]
        # Never rewrite the generated `return Name(__data)`: turning it into a
        # keyword call would make the keyword constructor call itself.
        strip(args_str) == "__data" && continue

        args = split_args(args_str)
        length(args) == length(field_names) || continue

        kwargs = join(
            ("$(fn)=$(strip(arg))" for (fn, arg) in zip(field_names, args)), ", "
        )
        content = replace(content, cm.match => "return $(struct_name)(; $(kwargs))")
    end
    return content
end

# Comment out the single-line PB.default_values / PB.field_numbers /
# PB.reserved_fields definitions of `struct_name`.
function _comment_out_pb_methods(content::AbstractString, struct_name)
    for method in ("default_values", "field_numbers", "reserved_fields")
        # Anchored at line start (multiline mode), so lines that are already
        # commented out are not matched again.
        pattern = Regex(
            "^(PB\\.$(method)\\(::Type\\{\\Q$(struct_name)\\E\\}\\).*?)\$", "m"
        )
        content = replace(content, pattern => s"# \1")
    end
    return content
end

"""
    split_args(s::AbstractString)

Split `s` on commas that are not nested inside `()`, `[]` or `{}` and return the
pieces as a `Vector{String}`.

Pieces are returned verbatim (surrounding whitespace is kept). A trailing piece that
is empty or whitespace-only is dropped.
"""
function split_args(s::AbstractString)
    args = String[]
    current = IOBuffer()
    depth = 0
    for c in s
        if c == ',' && depth == 0
            # Top-level comma: close the current piece.
            push!(args, String(take!(current)))
            continue
        end
        if c in ('(', '[', '{')
            depth += 1
        elseif c in (')', ']', '}')
            depth -= 1
        end
        write(current, c)
    end
    remaining = String(take!(current))
    isempty(strip(remaining)) || push!(args, remaining)
    return args
end

"""
    get_default_for_type(type_str::AbstractString)

Return, as source text, a default value expression for a field declared with the
type `type_str` (surrounding whitespace is ignored).

`Vector{...}` and `Dict{...}` map to an empty container of that type, `String` to
`""`, `Bool` to `false`, and the listed fixed-width numeric types to `zero(T)`.
Everything else, including `Union{Nothing,...}` and enum types (`*.T`), maps to
`nothing`.
"""
function get_default_for_type(type_str::AbstractString)
    type_str = strip(type_str)
    if startswith(type_str, "Vector{") || startswith(type_str, "Dict{")
        return "$(type_str)()"
    elseif type_str == "String"
        return "\"\""
    elseif type_str == "Bool"
        return "false"
    elseif type_str in ("Int32", "Int64", "UInt32", "UInt64", "Float32", "Float64")
        return "zero($(type_str))"
    end
    # Optional fields and enums (whose default member cannot be determined from the
    # type text alone) share the `nothing` fallback.
    return "nothing"
end