diff --git a/.github/workflows/Compile.yml b/.github/workflows/Compile.yml index 720b808a..96a09cce 100644 --- a/.github/workflows/Compile.yml +++ b/.github/workflows/Compile.yml @@ -6,6 +6,7 @@ on: - '.github/workflows/Compile.yml' - '.github/workflows/CompileOrRun.yml' - 'sharding/sharded_baroclinic_instability_simulation_compile.jl' + - 'sharding/sharded_ocean_climate_simulation_compile.jl' - 'simulations/baroclinic_instability_simulation_compile.jl' - 'simulations/ocean_climate_simulation_compile.jl' - 'Project.toml' @@ -19,6 +20,7 @@ on: - '.github/workflows/Compile.yml' - '.github/workflows/CompileOrRun.yml' - 'sharding/sharded_baroclinic_instability_simulation_compile.jl' + - 'sharding/sharded_ocean_climate_simulation_compile.jl' - 'simulations/baroclinic_instability_simulation_compile.jl' - 'simulations/ocean_climate_simulation_compile.jl' - 'Project.toml' @@ -63,16 +65,16 @@ jobs: julia_optlevel: 0 compile_sharded: - name: Sharded - Julia ${{ matrix.julia_version }} - ${{ matrix.grid_type }} - ${{ matrix.os }} + name: Sharded - Julia 1.11 - ${{ matrix.sim_type }} - ${{ matrix.os }} strategy: fail-fast: false matrix: os: ['ubuntu-24.04'] - grid_type: ['simple_lat_lon'] + sim_type: ['sharded_baroclinic_instability', 'sharded_ocean_climate'] uses: ./.github/workflows/CompileOrRun.yml with: - sim_type: 'sharded_baroclinic_instability' - grid_type: ${{ matrix.grid_type }} + sim_type: ${{ matrix.sim_type }} + grid_type: ${{ matrix.sim_type == 'sharded_baroclinic_instability' && 'simple_lat_lon' || 'simple_tripolar' }} sharded: true run_dir: 'sharding' julia_version: '1.11' diff --git a/.github/workflows/Run.yml b/.github/workflows/Run.yml index 6378abfd..ef9e2d58 100644 --- a/.github/workflows/Run.yml +++ b/.github/workflows/Run.yml @@ -6,6 +6,7 @@ on: - '.github/workflows/Run.yml' - '.github/workflows/CompileOrRun.yml' - 'sharding/sharded_baroclinic_instability_simulation_run.jl' + - 'sharding/sharded_ocean_climate_simulation_run.jl' - 'simulations/baroclinic_instability_simulation_run.jl' - 'simulations/ocean_climate_simulation_run.jl' - 'Project.toml' @@ -19,6 +20,7 @@ on: - '.github/workflows/Run.yml' - '.github/workflows/CompileOrRun.yml' - 'sharding/sharded_baroclinic_instability_simulation_run.jl' + - 'sharding/sharded_ocean_climate_simulation_run.jl' - 'simulations/baroclinic_instability_simulation_run.jl' - 'simulations/ocean_climate_simulation_run.jl' - 'Project.toml' @@ -61,16 +63,16 @@ jobs: julia_optlevel: 0 run_sharded: - name: Sharded - Julia ${{ matrix.julia_version }} - ${{ matrix.grid_type }} - ${{ matrix.os }} + name: Sharded - Julia 1.11 - ${{ matrix.sim_type }} - ${{ matrix.os }} strategy: fail-fast: false matrix: os: ['ubuntu-24.04', 'ubuntu-22.04-arm'] - grid_type: ['simple_lat_lon'] + sim_type: ['sharded_baroclinic_instability', 'sharded_ocean_climate'] uses: ./.github/workflows/CompileOrRun.yml with: - sim_type: 'sharded_baroclinic_instability' - grid_type: ${{ matrix.grid_type }} + sim_type: ${{ matrix.sim_type }} + grid_type: ${{ matrix.sim_type == 'sharded_baroclinic_instability' && 'simple_lat_lon' || 'simple_tripolar' }} sharded: true run_dir: 'sharding' julia_version: '1.11' diff --git a/sharding/sharded_ocean_climate_simulation_compile.jl b/sharding/sharded_ocean_climate_simulation_compile.jl new file mode 100644 index 00000000..9a1f46f0 --- /dev/null +++ b/sharding/sharded_ocean_climate_simulation_compile.jl @@ -0,0 +1,96 @@ +using ArgParse + +const args_settings = ArgParseSettings() +@add_arg_table! args_settings begin + "--grid-x" + help = "Base factor for number of grid points on the x axis." + default = 64 + arg_type = Int + "--grid-y" + help = "Base factor for number of grid points on the y axis." + default = 64 + arg_type = Int + "--grid-z" + help = "Base factor for number of grid points on the z axis." + default = 4 + arg_type = Int +end +const parsed_args = parse_args(ARGS, args_settings) + +using GordonBell25: first_time_step!, loop!, try_compile_code, preamble, TRY_COMPILE_FAILED +using GordonBell25: data_free_ocean_climate_model_init, PROFILE, GordonBell25 +using Reactant +using Oceananigans +using Oceananigans.Architectures: ReactantState +Reactant.Compiler.WHILE_CONCAT[] = true + +PROFILE[] = true + +preamble() + +GordonBell25.initialize(; single_gpu_per_process=false) +@show Ndev = length(Reactant.devices()) + +Rx, Ry = GordonBell25.factors(Ndev) +if Ndev == 1 + rank = 0 + arch = Oceananigans.ReactantState() +else + arch = Oceananigans.Distributed( + Oceananigans.ReactantState(); + partition = Partition(Rx, Ry, 1) + ) + rank = Reactant.Distributed.local_rank() +end + +H = 8 +Tx = parsed_args["grid-x"] * Rx +Ty = parsed_args["grid-y"] * Ry +Nz = parsed_args["grid-z"] + +Nx = Tx - 2H +Ny = Ty - 2H + +grid_type = Symbol(get(ENV, "grid_type", "simple_lat_lon")) +@info "Generating model (grid_type=$grid_type)..." +model = data_free_ocean_climate_model_init(arch, Nx, Ny, Nz; halo=(H, H, H), grid_type, + set_initial_conditions=false) +@show model + +GC.gc(true); GC.gc(false); GC.gc(true) + +TRY_COMPILE_FAILED[] = false +Ninner = ConcreteRNumber(2) + +for optimize in (:before_raise, false, :before_jit), code_type in (:hlo, :xla) + # We only want the optimised XLA code + optimize in (:before_raise, false) && code_type === :xla && continue + kernel_type = optimize === :before_raise ? "before_raise" : (optimize === false ? "unoptimised" : "optimised") + @info "Compiling $(kernel_type) $(code_type) kernels..." + if code_type === :hlo + first_code = try_compile_code() do + @code_hlo optimize=optimize raise=true shardy_passes=:post_sdy_propagation first_time_step!(model) + end + loop_code = try_compile_code() do + @code_hlo optimize=optimize raise=true shardy_passes=:post_sdy_propagation loop!(model, Ninner) + end + elseif code_type === :xla + first_code = try_compile_code() do + @code_xla raise=true first_time_step!(model) + end + loop_code = try_compile_code() do + @code_xla raise=true loop!(model, Ninner) + end + end + for name in ("first", "loop"), debug in (true, false) + # No debug info for `@code_xla` + code_type === :xla && debug && continue + open("$(kernel_type)_sharded_ocean_climate_simulation_$(name)$(debug ? "_debug" : "").$(code_type == :xla ? "xla" : "mlir")", "w") do io + show(IOContext(io, :debug => debug), (Base.@locals())[Symbol(name, "_code")]) + end + end +end + +if TRY_COMPILE_FAILED[] + error("compilation failed") +end diff --git a/sharding/sharded_ocean_climate_simulation_run.jl b/sharding/sharded_ocean_climate_simulation_run.jl new file mode 100644 index 00000000..3371108a --- /dev/null +++ b/sharding/sharded_ocean_climate_simulation_run.jl @@ -0,0 +1,159 @@ +using Dates +@info "This is when the fun begins" now(UTC) + +using ArgParse + +const args_settings = ArgParseSettings() +@add_arg_table! args_settings begin + "--grid-x" + help = "Base factor for number of grid points on the x axis." + default = 64 + arg_type = Int + "--grid-y" + help = "Base factor for number of grid points on the y axis." + default = 64 + arg_type = Int + "--grid-z" + help = "Base factor for number of grid points on the z axis." + default = 4 + arg_type = Int +end +const parsed_args = parse_args(ARGS, args_settings) + +ENV["JULIA_DEBUG"] = "Reactant_jll,Reactant" + +using GordonBell25 +using GordonBell25: first_time_step!, time_step!, loop!, factors, is_distributed_env_present +using GordonBell25: data_free_ocean_climate_model_init +using Oceananigans +using Oceananigans.Units +using Oceananigans.Architectures: ReactantState +using Random +using Printf +using Reactant + +if !is_distributed_env_present() + using MPI + MPI.Init() +end + +jobid_procid = GordonBell25.get_jobid_procid() + +# This must be called before `GordonBell25.initialize`! +GordonBell25.preamble() + +using Libdl: dllist +@show filter(contains("nccl"), dllist()) + +Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = true +Reactant.MLIR.IR.DUMP_MLIR_DIR[] = joinpath(@__DIR__, "mlir_dumps", jobid_procid) +Reactant.Compiler.DEBUG_DISABLE_RESHARDING[] = true +Reactant.Compiler.WHILE_CONCAT[] = true + +GordonBell25.initialize(; single_gpu_per_process=false) + +devarch = Oceananigans.ReactantState() +arch = devarch + +Ndev = if arch isa Oceananigans.ReactantState + length(Reactant.devices()) +else + comm = MPI.COMM_WORLD + MPI.Comm_size(comm) +end + +@show Ndev + +Rx, Ry = factors(Ndev) +if Ndev == 1 + rank = 0 +else + arch = Oceananigans.Distributed( + arch; + partition = Partition(Rx, Ry, 1) + ) + rank = if devarch isa Oceananigans.ReactantState + Reactant.Distributed.local_rank() + else + comm = MPI.COMM_WORLD + MPI.Comm_rank(comm) + end +end + +@info "[$rank] allocations" GordonBell25.allocatorstats() +H = 8 +Tx = parsed_args["grid-x"] * Rx +Ty = parsed_args["grid-y"] * Ry +Nz = parsed_args["grid-z"] + +Nx = Tx - 2H +Ny = Ty - 2H + +grid_type = Symbol(get(ENV, "grid_type", "simple_lat_lon")) +@info "[$rank] Generating model (Nx=$Nx, Ny=$Ny, grid_type=$grid_type)..." now(UTC) +model = data_free_ocean_climate_model_init(arch, Nx, Ny, Nz; halo=(H, H, H), grid_type, + set_initial_conditions=false) +@info "[$rank] allocations" GordonBell25.allocatorstats() + +@show model + +Ninner = 256 + +if devarch isa Oceananigans.ReactantState + Ninner = if Ndev == 1 + ConcreteRNumber(Ninner) + else + ConcreteRNumber(Ninner; sharding=Sharding.NamedSharding(arch.connectivity, ())) + end +end + +@info "[$rank] Compiling first_time_step!..." now(UTC) +compile_options = CompileOptions(; sync=true, raise=true, strip_llvm_debuginfo=true, strip=["enzymexla.kernel_call", "(::Reactant.Compiler.LLVMFunc", "ka_with_reactant", "(::KernelAbstractions.Kernel", "var\"#_launch!;_launch!"]) +rfirst! = if devarch isa Oceananigans.ReactantState + @compile compile_options=compile_options first_time_step!(model) +else + first_time_step! +end + +@info "[$rank] allocations" GordonBell25.allocatorstats() +@info "[$rank] Compiling loop..." now(UTC) + +compiled_loop! = if devarch isa Oceananigans.ReactantState + @compile compile_options=compile_options loop!(model, Ninner) +else + loop! +end + +@info "[$rank] allocations" GordonBell25.allocatorstats() + +profile_dir = joinpath(@__DIR__, "profiling", jobid_procid) +mkpath(joinpath(profile_dir, "first_time_step")) +@info "[$rank] allocations" GordonBell25.allocatorstats() +@info "[$rank] Running first_time_step!..." now(UTC) +Reactant.with_profiler(joinpath(profile_dir, "first_time_step")) do + Reactant.Profiler.annotate("bench"; metadata=Dict("step_num" => 1, "_r" => 1)) do + @time "[$rank] first time step" rfirst!(model) + end +end +@info "[$rank] allocations" GordonBell25.allocatorstats() + +mkpath(joinpath(profile_dir, "loop")) +@info "[$rank] allocations" GordonBell25.allocatorstats() +@info "[$rank] running loop" now(UTC) +Reactant.with_profiler(joinpath(profile_dir, "loop")) do + Reactant.Profiler.annotate("bench"; metadata=Dict("step_num" => 1, "_r" => 1)) do + @time "[$rank] loop" compiled_loop!(model, Ninner) + end +end + +mkpath(joinpath(profile_dir, "loop2")) +@info "[$rank] allocations" GordonBell25.allocatorstats() +@info "[$rank] running second loop" now(UTC) +Reactant.with_profiler(joinpath(profile_dir, "loop2")) do + Reactant.Profiler.annotate("bench"; metadata=Dict("step_num" => 1, "_r" => 1)) do + @time "[$rank] second loop" compiled_loop!(model, Ninner) + end +end +@info "[$rank] allocations" GordonBell25.allocatorstats() + +@info "[$rank] Done!" now(UTC) diff --git a/src/data_free_ocean_climate_model.jl b/src/data_free_ocean_climate_model.jl index bd27d5c5..4c76f751 100644 --- a/src/data_free_ocean_climate_model.jl +++ b/src/data_free_ocean_climate_model.jl @@ -11,20 +11,40 @@ end function data_free_ocean_climate_model_init( arch::Architectures.AbstractArchitecture=Architectures.ReactantState(); - # Horizontal resolution resolution::Real = 2, # 1/4 for quarter degree - # Vertical resolution Nz::Int = 20, # eventually we want to increase this to between 100-600 ) + Nx, Ny = resolution_to_points(resolution) + return data_free_ocean_climate_model_init(arch, Nx, Ny, Nz) +end - grid = gaussian_islands_tripolar_grid(arch, resolution, Nz) +function data_free_ocean_climate_model_init( + arch::Architectures.AbstractArchitecture, + Nx::Int, Ny::Int, Nz::Int; + Δt = 30, + halo = (8, 8, 8), + grid_type = :simple_lat_lon, + set_initial_conditions = true, + free_surface = SplitExplicitFreeSurface(substeps=30), + ) + + grid = if grid_type === :gaussian_islands + gaussian_islands_tripolar_grid(arch, Nx, Ny, Nz; halo) + elseif grid_type === :simple_tripolar + simple_tripolar_grid(arch, Nx, Ny, Nz; halo) + elseif grid_type === :simple_lat_lon + simple_latitude_longitude_grid(arch, Nx, Ny, Nz; halo) + else + error("grid_type=$grid_type must be :gaussian_islands, :simple_tripolar, or :simple_lat_lon.") + end # See visualize_ocean_climate_simulation.jl for information about how to # visualize the results of this run. - Δt = 30seconds - free_surface = SplitExplicitFreeSurface(substeps=30) ocean = @gbprofile "ocean_simulation" ocean_simulation(grid; free_surface, Δt) - @gbprofile "set_ocean_model" set!(ocean.model, T=Tᵢ, S=Sᵢ) + + if set_initial_conditions + @gbprofile "set_ocean_model" set!(ocean.model, T=Tᵢ, S=Sᵢ) + end # Set up an atmosphere atmos_times = range(0, 1days, length=24) @@ -37,22 +57,24 @@ function data_free_ocean_climate_model_init( atmosphere = PrescribedAtmosphere(atmos_grid, atmos_times) - Ta = Field{Center, Center, Nothing}(atmos_grid) - ua = Field{Center, Center, Nothing}(atmos_grid) - Qs = Field{Center, Center, Nothing}(atmos_grid) + if set_initial_conditions + Ta = Field{Center, Center, Nothing}(atmos_grid) + ua = Field{Center, Center, Nothing}(atmos_grid) + Qs = Field{Center, Center, Nothing}(atmos_grid) - set!(Ta, Tatm) - set!(ua, zonal_wind) - set!(Qs, sunlight) + set!(Ta, Tatm) + set!(ua, zonal_wind) + set!(Qs, sunlight) - if arch isa Architectures.ReactantState - if Reactant.precompiling() - @code_hlo set_tracers(parent(atmosphere.tracers.T), parent(Ta), parent(atmosphere.velocities.u), parent(ua), parent(atmosphere.downwelling_radiation.shortwave), parent(Qs)) + if parent(atmosphere.tracers.T) isa Reactant.ConcreteRArray + if Reactant.precompiling() + @code_hlo set_tracers(parent(atmosphere.tracers.T), parent(Ta), parent(atmosphere.velocities.u), parent(ua), parent(atmosphere.downwelling_radiation.shortwave), parent(Qs)) + else + @jit set_tracers(parent(atmosphere.tracers.T), parent(Ta), parent(atmosphere.velocities.u), parent(ua), parent(atmosphere.downwelling_radiation.shortwave), parent(Qs)) + end else - @jit set_tracers(parent(atmosphere.tracers.T), parent(Ta), parent(atmosphere.velocities.u), parent(ua), parent(atmosphere.downwelling_radiation.shortwave), parent(Qs)) + set_tracers(parent(atmosphere.tracers.T), parent(Ta), parent(atmosphere.velocities.u), parent(ua), parent(atmosphere.downwelling_radiation.shortwave), parent(Qs)) end - else - set_tracers(parent(atmosphere.tracers.T), parent(Ta), parent(atmosphere.velocities.u), parent(ua), parent(atmosphere.downwelling_radiation.shortwave), parent(Qs)) end parent(atmosphere.tracers.q) .= 0 diff --git a/src/model_utils.jl b/src/model_utils.jl index 24423c48..ef87a483 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -123,6 +123,16 @@ function set_baroclinic_instability!(model) end end +function simple_tripolar_grid(arch::Architectures.AbstractArchitecture, resolution, Nz) + Nx, Ny = resolution_to_points(resolution) + return simple_tripolar_grid(arch, Nx, Ny, Nz) +end + +function simple_tripolar_grid(arch::Architectures.AbstractArchitecture, Nx, Ny, Nz; halo=(8, 8, 8)) + z = exponential_z_faces(; Nz, depth=4000, h=30) + return TripolarGrid(arch; size=(Nx, Ny, Nz), halo, z) +end + function gaussian_islands_tripolar_grid(arch::Architectures.AbstractArchitecture, resolution, Nz) Nx, Ny = resolution_to_points(resolution) return gaussian_islands_tripolar_grid(arch, Nx, Ny, Nz)