diff --git a/Project.toml b/Project.toml index 820336a319..2d8409a575 100644 --- a/Project.toml +++ b/Project.toml @@ -22,6 +22,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433" Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0" +ScopedSettings = "6ffd3f19-5aa5-475d-9277-f0318686a530" ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63" Scratch = "6c6a2e73-6563-6170-7368-637461726353" Sockets = "6462fe0b-24de-5631-8697-dd941f90decc" @@ -49,6 +50,7 @@ YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df" [sources] ReactantCore = {path = "lib/ReactantCore"} +ScopedSettings = {url = "https://github.com/avik-pal/ScopedSettings.jl", rev = "ap/union_types"} [extensions] ReactantAbstractFFTsExt = "AbstractFFTs" @@ -101,6 +103,7 @@ Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.15" Reactant_jll = "0.0.237" +ScopedSettings = "0.1.1" ScopedValues = "1.3.0" Scratch = "1.2" Sockets = "1.10" diff --git a/docs/src/api/config.md b/docs/src/api/config.md index a3915c078c..d6ffcc33a0 100644 --- a/docs/src/api/config.md +++ b/docs/src/api/config.md @@ -6,11 +6,6 @@ CollapsedDocStrings = true ## Scoped Values -!!! warning - - Currently options are scattered in the form of global variables and scoped values. We - are in the process of migrating all of them into scoped values. - ```@docs Reactant.with_config ``` diff --git a/src/Compiler.jl b/src/Compiler.jl index 2932390a5f..53743ce4ca 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -27,9 +27,42 @@ import Reactant: OptimizeCommunicationOptions, ShardyPropagationOptions, Compile import ..ReactantCore: correct_maybe_bcast_call -const DEBUG_PRINT_CODEGEN = Ref(false) -const DEBUG_DISABLE_RESHARDING = Ref(false) -const DEBUG_ALIASED_BUFFER_ASSIGNMENT_ERROR = Ref(false) +using ScopedSettings: ScopedSetting, GetPreference + +const DEBUG_PRINT_CODEGEN = ScopedSetting( + GetPreference(Reactant, "debug_print_codegen", false) +) +const DEBUG_DISABLE_RESHARDING = ScopedSetting( + GetPreference(Reactant, "debug_disable_resharding", false) +) +const DEBUG_ALIASED_BUFFER_ASSIGNMENT_ERROR = ScopedSetting( + GetPreference(Reactant, "debug_aliased_buffer_assignment_error", false) +) +const DEBUG_KERNEL = ScopedSetting(GetPreference(Reactant, "debug_kernel", false)) +const DUMP_LLVMIR = ScopedSetting(GetPreference(Reactant, "debug_dump_llvmir", false)) +const DUMP_FAILED_LOCKSTEP = ScopedSetting( + GetPreference(Reactant, "debug_dump_failed_lockstep", false) +) +const SROA_ATTRIBUTOR = ScopedSetting(GetPreference(Reactant, "sroa_attributor", false)) + +const WHILE_CONCAT = ScopedSetting(GetPreference(Reactant, "while_concat_passes", false)) +const DUS_TO_CONCAT = ScopedSetting(GetPreference(Reactant, "dus_to_concat_passes", false)) +const SUM_TO_REDUCEWINDOW = ScopedSetting( + GetPreference(Reactant, "sum_to_reducewindow_passes", false) +) +const SUM_TO_CONV = ScopedSetting(GetPreference(Reactant, "sum_to_conv_passes", false)) +const AGGRESSIVE_SUM_TO_CONV = ScopedSetting( + GetPreference(Reactant, "aggressive_sum_to_conv_passes", false) +) +const AGGRESSIVE_PROPAGATION = ScopedSetting( + GetPreference(Reactant, "aggressive_propagation_passes", false) +) +const DUS_SLICE_SIMPLIFY = ScopedSetting( + GetPreference(Reactant, "dus_slice_simplify_passes", true) +) +const CONCATS_TO_DUS = ScopedSetting(GetPreference(Reactant, "concats_to_dus_passes", true)) + +const OpenMP = ScopedSetting(GetPreference(Reactant, "lower_jit_to_openmp", true)) const DEBUG_BUFFER_POINTERS_STORE_DICT = Base.IdDict() @@ -684,15 +717,6 @@ function create_result( return Meta.quot(tocopy) end -const WHILE_CONCAT = Ref(false) -const DUS_TO_CONCAT = Ref(false) -const SUM_TO_REDUCEWINDOW = Ref(false) -const SUM_TO_CONV = Ref(false) -const AGGRESSIVE_SUM_TO_CONV = Ref(false) -const AGGRESSIVE_PROPAGATION = Ref(false) -const DUS_SLICE_SIMPLIFY = Ref(true) -const CONCATS_TO_DUS = Ref(false) - # Optimization passes via transform dialect function optimization_passes( compile_options::CompileOptions; @@ -1436,12 +1460,6 @@ function cubinFeatures() return "+ptx$ptx" end -const DEBUG_KERNEL = Ref{Bool}(false) -const DUMP_LLVMIR = Ref{Bool}(false) -const DUMP_FAILED_LOCKSTEP = Ref{Bool}(false) -const OpenMP = Ref{Bool}(true) -const SROA_ATTRIBUTOR = Ref{Bool}(true) - function activate_raising!(is_raising::Bool) stack = get!(task_local_storage(), :reactant_is_raising) do Bool[] diff --git a/src/Configuration.jl b/src/Configuration.jl index 5b0eaa00af..11a81f9c17 100644 --- a/src/Configuration.jl +++ b/src/Configuration.jl @@ -1,4 +1,4 @@ -using ScopedValues: ScopedValues, ScopedValue +using ScopedValues: ScopedValues export with_config export DotGeneralAlgorithmPreset, PrecisionConfig, DotGeneralAlgorithm @@ -63,8 +63,12 @@ function with_config( end # Lower to ApproxTopK -const LOWER_PARTIALSORT_TO_APPROX_TOP_K = ScopedValue(false) -const FALLBACK_APPROX_TOP_K_LOWERING = ScopedValue(true) +const LOWER_PARTIALSORT_TO_APPROX_TOP_K = ScopedSetting( + GetPreference(Reactant, "lower_partialsort_to_approx_top_k", false) +) +const FALLBACK_APPROX_TOP_K_LOWERING = ScopedSetting( + GetPreference(Reactant, "fallback_approx_top_k_lowering", true) +) # DotGeneral Attributes Configuration """ @@ -88,13 +92,13 @@ end Base.@deprecate_binding DotGeneralPrecision PrecisionConfig -const DOT_GENERAL_PRECISION = ScopedValue{ +const DOT_GENERAL_PRECISION = ScopedSetting{ Union{PrecisionConfig.T,Nothing,Tuple{PrecisionConfig.T,PrecisionConfig.T}} }( PrecisionConfig.DEFAULT ) -const CONVOLUTION_PRECISION = ScopedValue{ +const CONVOLUTION_PRECISION = ScopedSetting{ Union{PrecisionConfig.T,Nothing,Tuple{PrecisionConfig.T,PrecisionConfig.T}} }( PrecisionConfig.DEFAULT @@ -224,7 +228,7 @@ The following functions are available: TF32_TF32_F32_X3 end -const DOT_GENERAL_ALGORITHM = ScopedValue{ +const DOT_GENERAL_ALGORITHM = ScopedSetting{ Union{DotGeneralAlgorithmPreset.T,Nothing,DotGeneralAlgorithm} }( DotGeneralAlgorithmPreset.DEFAULT diff --git a/src/Ops.jl b/src/Ops.jl index 805126ab46..dfa1fbf054 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -99,10 +99,16 @@ function mlir_type(::Type{<:MissingTracedValue})::MLIR.IR.Type return MLIR.IR.TensorType(Int[], MLIR.IR.Type(Bool)) end -const DEBUG_MODE::Ref{Bool} = Ref(false) -const LARGE_CONSTANT_THRESHOLD = Ref(100 << 20) # 100 MiB -const LARGE_CONSTANT_RAISE_ERROR = Ref(true) -const GATHER_GETINDEX_DISABLED = Ref(false) +const DEBUG_MODE = ScopedSetting(false) +const LARGE_CONSTANT_THRESHOLD = ScopedSetting( + GetPreference(Reactant, "large_constant_threshold", 100 << 20) # 100 MiB +) +const LARGE_CONSTANT_RAISE_ERROR = ScopedSetting( + GetPreference(Reactant, "large_constant_raise_error", true) +) +const GATHER_GETINDEX_DISABLED = ScopedSetting( + GetPreference(Reactant, "gather_getindex_disabled", false) +) function with_debug(f) old = DEBUG_MODE[] diff --git a/src/Reactant.jl b/src/Reactant.jl index d6da2c92dd..4100846f85 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -11,6 +11,8 @@ using Functors: Functors, @leaf using Adapt: Adapt, WrappedArray using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)` +using ScopedSettings: ScopedSetting, GetPreference + export @allowscalar # re-exported from GPUArraysCore is_extension_loaded(::Val) = false diff --git a/src/mlir/IR/IR.jl b/src/mlir/IR/IR.jl index 3d429aeaf5..e948e6b25b 100644 --- a/src/mlir/IR/IR.jl +++ b/src/mlir/IR/IR.jl @@ -26,6 +26,8 @@ export @affinemap using Random: randstring +using ScopedSettings: ScopedSetting, GetPreference + function mlirIsNull(val) return val.ptr == C_NULL end diff --git a/src/mlir/IR/Pass.jl b/src/mlir/IR/Pass.jl index d0e69d7fdb..2523716dd1 100644 --- a/src/mlir/IR/Pass.jl +++ b/src/mlir/IR/Pass.jl @@ -65,9 +65,11 @@ function enable_verifier!(pm, enable=true) end # Where to dump the MLIR modules -const DUMP_MLIR_DIR = Ref{Union{Nothing,String}}(nothing) +const DUMP_MLIR_DIR = ScopedSetting{Union{Nothing,String}}( + GetPreference(Reactant, "dump_mlir_dir", nothing) +) # Whether to always dump MLIR, regardless of failure -const DUMP_MLIR_ALWAYS = Ref{Bool}(false) +const DUMP_MLIR_ALWAYS = ScopedSetting(GetPreference(Reactant, "dump_mlir_always", false)) # Counter for dumping MLIR modules const MLIR_DUMP_COUNTER = Threads.Atomic{Int}(0) diff --git a/src/utils.jl b/src/utils.jl index bed592af38..882736efdc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -525,7 +525,7 @@ function safe_print(name, x) return ccall(:jl_, Cvoid, (Any,), name * " " * string(x)) end -const DEBUG_INTERP = Ref(false) +const DEBUG_INTERP = ScopedSetting(GetPreference(Reactant, "debug_interpreter", false)) # Rewrite type unstable calls to recurse into call_with_reactant to ensure # they continue to use our interpreter. Reset the derived return type