Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions .github/workflows/Compile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ on:
- '.github/workflows/Compile.yml'
- '.github/workflows/CompileOrRun.yml'
- 'sharding/sharded_baroclinic_instability_simulation_compile.jl'
- 'sharding/sharded_ocean_climate_simulation_compile.jl'
- 'simulations/baroclinic_instability_simulation_compile.jl'
- 'simulations/ocean_climate_simulation_compile.jl'
- 'Project.toml'
Expand All @@ -19,6 +20,7 @@ on:
- '.github/workflows/Compile.yml'
- '.github/workflows/CompileOrRun.yml'
- 'sharding/sharded_baroclinic_instability_simulation_compile.jl'
- 'sharding/sharded_ocean_climate_simulation_compile.jl'
- 'simulations/baroclinic_instability_simulation_compile.jl'
- 'simulations/ocean_climate_simulation_compile.jl'
- 'Project.toml'
Expand Down Expand Up @@ -63,16 +65,16 @@ jobs:
julia_optlevel: 0

compile_sharded:
name: Sharded - Julia ${{ matrix.julia_version }} - ${{ matrix.grid_type }} - ${{ matrix.os }}
name: Sharded - Julia 1.11 - ${{ matrix.sim_type }} - ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: ['ubuntu-24.04']
grid_type: ['simple_lat_lon']
sim_type: ['sharded_baroclinic_instability', 'sharded_ocean_climate']
uses: ./.github/workflows/CompileOrRun.yml
with:
sim_type: 'sharded_baroclinic_instability'
grid_type: ${{ matrix.grid_type }}
sim_type: ${{ matrix.sim_type }}
grid_type: ${{ matrix.sim_type == 'sharded_baroclinic_instability' && 'simple_lat_lon' || 'simple_tripolar' }}
sharded: true
run_dir: 'sharding'
julia_version: '1.11'
Expand Down
10 changes: 6 additions & 4 deletions .github/workflows/Run.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ on:
- '.github/workflows/Run.yml'
- '.github/workflows/CompileOrRun.yml'
- 'sharding/sharded_baroclinic_instability_simulation_run.jl'
- 'sharding/sharded_ocean_climate_simulation_run.jl'
- 'simulations/baroclinic_instability_simulation_run.jl'
- 'simulations/ocean_climate_simulation_run.jl'
- 'Project.toml'
Expand All @@ -19,6 +20,7 @@ on:
- '.github/workflows/Run.yml'
- '.github/workflows/CompileOrRun.yml'
- 'sharding/sharded_baroclinic_instability_simulation_run.jl'
- 'sharding/sharded_ocean_climate_simulation_run.jl'
- 'simulations/baroclinic_instability_simulation_run.jl'
- 'simulations/ocean_climate_simulation_run.jl'
- 'Project.toml'
Expand Down Expand Up @@ -61,16 +63,16 @@ jobs:
julia_optlevel: 0

run_sharded:
name: Sharded - Julia ${{ matrix.julia_version }} - ${{ matrix.grid_type }} - ${{ matrix.os }}
name: Sharded - Julia 1.11 - ${{ matrix.sim_type }} - ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: ['ubuntu-24.04', 'ubuntu-22.04-arm']
grid_type: ['simple_lat_lon']
sim_type: ['sharded_baroclinic_instability', 'sharded_ocean_climate']
uses: ./.github/workflows/CompileOrRun.yml
with:
sim_type: 'sharded_baroclinic_instability'
grid_type: ${{ matrix.grid_type }}
sim_type: ${{ matrix.sim_type }}
grid_type: ${{ matrix.sim_type == 'sharded_baroclinic_instability' && 'simple_lat_lon' || 'simple_tripolar' }}
sharded: true
run_dir: 'sharding'
julia_version: '1.11'
Expand Down
96 changes: 96 additions & 0 deletions sharding/sharded_ocean_climate_simulation_compile.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
using ArgParse

const args_settings = ArgParseSettings()
@add_arg_table! args_settings begin
"--grid-x"
help = "Base factor for number of grid points on the x axis."
default = 64
arg_type = Int
"--grid-y"
help = "Base factor for number of grid points on the y axis."
default = 64
arg_type = Int
"--grid-z"
help = "Base factor for number of grid points on the z axis."
default = 4
arg_type = Int
end
const parsed_args = parse_args(ARGS, args_settings)

using GordonBell25: first_time_step!, loop!, try_compile_code, preamble, TRY_COMPILE_FAILED
using GordonBell25: data_free_ocean_climate_model_init, PROFILE, GordonBell25
using Reactant
using Oceananigans
using Oceananigans.Architectures: ReactantState
Reactant.Compiler.WHILE_CONCAT[] = true

PROFILE[] = true

preamble()

GordonBell25.initialize(; single_gpu_per_process=false)
@show Ndev = length(Reactant.devices())

Rx, Ry = GordonBell25.factors(Ndev)
if Ndev == 1
rank = 0
arch = Oceananigans.ReactantState()
else
arch = Oceananigans.Distributed(
Oceananigans.ReactantState();
partition = Partition(Rx, Ry, 1)
)
rank = Reactant.Distributed.local_rank()
end

H = 8
Tx = parsed_args["grid-x"] * Rx
Ty = parsed_args["grid-y"] * Ry
Nz = parsed_args["grid-z"]

Nx = Tx - 2H
Ny = Ty - 2H

grid_type = Symbol(get(ENV, "grid_type", "simple_lat_lon"))
@info "Generating model (grid_type=$grid_type)..."
model = data_free_ocean_climate_model_init(arch, Nx, Ny, Nz; halo=(H, H, H), grid_type,
set_initial_conditions=false)
@show model

GC.gc(true); GC.gc(false); GC.gc(true)

TRY_COMPILE_FAILED[] = false
Ninner = ConcreteRNumber(2)

for optimize in (:before_raise, false, :before_jit), code_type in (:hlo, :xla)
# We only want the optimised XLA code
optimize in (:before_raise, false) && code_type === :xla && continue
kernel_type = optimize === :before_raise ? "before_raise" : (optimize === false ? "unoptimised" : "optimised")
@info "Compiling $(kernel_type) $(code_type) kernels..."
if code_type === :hlo
first_code = try_compile_code() do
@code_hlo optimize=optimize raise=true shardy_passes=:post_sdy_propagation first_time_step!(model)
end
loop_code = try_compile_code() do
@code_hlo optimize=optimize raise=true shardy_passes=:post_sdy_propagation loop!(model, Ninner)
end
elseif code_type === :xla
first_code = try_compile_code() do
@code_xla raise=true first_time_step!(model)
end
loop_code = try_compile_code() do
@code_xla raise=true loop!(model, Ninner)
end
end
for name in ("first", "loop"), debug in (true, false)
# No debug info for `@code_xla`
code_type === :xla && debug && continue
open("$(kernel_type)_sharded_ocean_climate_simulation_$(name)$(debug ? "_debug" : "").$(code_type == :xla ? "xla" : "mlir")", "w") do io
show(IOContext(io, :debug => debug), (Base.@locals())[Symbol(name, "_code")])
end
end
end

if TRY_COMPILE_FAILED[]
error("compilation failed")
end
159 changes: 159 additions & 0 deletions sharding/sharded_ocean_climate_simulation_run.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
using Dates
@info "This is when the fun begins" now(UTC)

using ArgParse

const args_settings = ArgParseSettings()
@add_arg_table! args_settings begin
"--grid-x"
help = "Base factor for number of grid points on the x axis."
default = 64
arg_type = Int
"--grid-y"
help = "Base factor for number of grid points on the y axis."
default = 64
arg_type = Int
"--grid-z"
help = "Base factor for number of grid points on the z axis."
default = 4
arg_type = Int
end
const parsed_args = parse_args(ARGS, args_settings)

ENV["JULIA_DEBUG"] = "Reactant_jll,Reactant"

using GordonBell25
using GordonBell25: first_time_step!, time_step!, loop!, factors, is_distributed_env_present
using GordonBell25: data_free_ocean_climate_model_init
using Oceananigans
using Oceananigans.Units
using Oceananigans.Architectures: ReactantState
using Random
using Printf
using Reactant

if !is_distributed_env_present()
using MPI
MPI.Init()
end

jobid_procid = GordonBell25.get_jobid_procid()

# This must be called before `GordonBell25.initialize`!
GordonBell25.preamble()

using Libdl: dllist
@show filter(contains("nccl"), dllist())

Reactant.MLIR.IR.DUMP_MLIR_ALWAYS[] = true
Reactant.MLIR.IR.DUMP_MLIR_DIR[] = joinpath(@__DIR__, "mlir_dumps", jobid_procid)
Reactant.Compiler.DEBUG_DISABLE_RESHARDING[] = true
Reactant.Compiler.WHILE_CONCAT[] = true

GordonBell25.initialize(; single_gpu_per_process=false)

devarch = Oceananigans.ReactantState()
arch = devarch

Ndev = if arch isa Oceananigans.ReactantState
length(Reactant.devices())
else
comm = MPI.COMM_WORLD
MPI.Comm_size(comm)
end

@show Ndev

Rx, Ry = factors(Ndev)
if Ndev == 1
rank = 0
else
arch = Oceananigans.Distributed(
arch;
partition = Partition(Rx, Ry, 1)
)
rank = if devarch isa Oceananigans.ReactantState
Reactant.Distributed.local_rank()
else
comm = MPI.COMM_WORLD
MPI.Comm_rank(comm)
end
end

@info "[$rank] allocations" GordonBell25.allocatorstats()
H = 8
Tx = parsed_args["grid-x"] * Rx
Ty = parsed_args["grid-y"] * Ry
Nz = parsed_args["grid-z"]

Nx = Tx - 2H
Ny = Ty - 2H

grid_type = Symbol(get(ENV, "grid_type", "simple_lat_lon"))
@info "[$rank] Generating model (Nx=$Nx, Ny=$Ny, grid_type=$grid_type)..." now(UTC)
model = data_free_ocean_climate_model_init(arch, Nx, Ny, Nz; halo=(H, H, H), grid_type,
set_initial_conditions=false)
@info "[$rank] allocations" GordonBell25.allocatorstats()

@show model

Ninner = 256

if devarch isa Oceananigans.ReactantState
Ninner = if Ndev == 1
ConcreteRNumber(Ninner)
else
ConcreteRNumber(Ninner; sharding=Sharding.NamedSharding(arch.connectivity, ()))
end
end

@info "[$rank] Compiling first_time_step!..." now(UTC)
compile_options = CompileOptions(; sync=true, raise=true, strip_llvm_debuginfo=true, strip=["enzymexla.kernel_call", "(::Reactant.Compiler.LLVMFunc", "ka_with_reactant", "(::KernelAbstractions.Kernel", "var\"#_launch!;_launch!"])
rfirst! = if devarch isa Oceananigans.ReactantState
@compile compile_options=compile_options first_time_step!(model)
else
first_time_step!
end

@info "[$rank] allocations" GordonBell25.allocatorstats()
@info "[$rank] Compiling loop..." now(UTC)

compiled_loop! = if devarch isa Oceananigans.ReactantState
@compile compile_options=compile_options loop!(model, Ninner)
else
loop!
end

@info "[$rank] allocations" GordonBell25.allocatorstats()

profile_dir = joinpath(@__DIR__, "profiling", jobid_procid)
mkpath(joinpath(profile_dir, "first_time_step"))
@info "[$rank] allocations" GordonBell25.allocatorstats()
@info "[$rank] Running first_time_step!..." now(UTC)
Reactant.with_profiler(joinpath(profile_dir, "first_time_step")) do
Reactant.Profiler.annotate("bench"; metadata=Dict("step_num" => 1, "_r" => 1)) do
@time "[$rank] first time step" rfirst!(model)
end
end
@info "[$rank] allocations" GordonBell25.allocatorstats()

mkpath(joinpath(profile_dir, "loop"))
@info "[$rank] allocations" GordonBell25.allocatorstats()
@info "[$rank] running loop" now(UTC)
Reactant.with_profiler(joinpath(profile_dir, "loop")) do
Reactant.Profiler.annotate("bench"; metadata=Dict("step_num" => 1, "_r" => 1)) do
@time "[$rank] loop" compiled_loop!(model, Ninner)
end
end

mkpath(joinpath(profile_dir, "loop2"))
@info "[$rank] allocations" GordonBell25.allocatorstats()
@info "[$rank] running second loop" now(UTC)
Reactant.with_profiler(joinpath(profile_dir, "loop2")) do
Reactant.Profiler.annotate("bench"; metadata=Dict("step_num" => 1, "_r" => 1)) do
@time "[$rank] second loop" compiled_loop!(model, Ninner)
end
end
@info "[$rank] allocations" GordonBell25.allocatorstats()

@info "[$rank] Done!" now(UTC)
Loading
Loading