Skip to content

Commit 30cd92e

Browse files
committed
feat: fallback
1 parent dfc37a4 commit 30cd92e

File tree

6 files changed

+110
-12
lines changed

6 files changed

+110
-12
lines changed

src/Compiler.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3763,8 +3763,8 @@ function compile_xla(
37633763

37643764
exec = XLA.compile(
37653765
client,
3766-
mod;
3767-
compile_options=xla_compile_options,
3766+
mod,
3767+
xla_compile_options;
37683768
num_outputs=length(mlir_fn_res.linear_results),
37693769
num_parameters=length(mlir_fn_res.linear_args),
37703770
mlir_fn_res.is_sharded,

src/xla/CompileOptions.jl

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ end
2121
function get_debug_options(; kwargs...)
2222
debug_options = get_default_debug_options()
2323

24-
# default overrides. can we changed by the user by passing in kwargs
24+
# default overrides. can be changed by the user by passing in kwargs
2525
debug_options.xla_gpu_cuda_data_dir = CUDA_DATA_DIR[]
2626
debug_options.xla_enable_enzyme_comms_opt = true
2727
debug_options.xla_gpu_experimental_use_raft_select_k = true
@@ -48,6 +48,13 @@ function get_debug_options(; kwargs...)
4848
return debug_options
4949
end
5050

51+
"""
    CompileOptionsWithoutProto

Plain-field compile options, used in place of a `CompileOptionsProto`.

`make_compile_options` returns this type when no debug, executable-build, or
compile overrides are supplied (other than exactly the pair
`use_shardy_partitioner` / `use_spmd_partitioning`).

# Fields
- `device_id::Int64`: device identifier handed to the backend compile call.
- `global_device_ids::Vector{Int64}`: device ids handed to the backend as a
  pointer/length pair; `make_compile_options` stores an empty vector when no
  mesh ids are given.
- `use_shardy_partitioner::Bool`: forwarded as-is to the backend compile call;
  `make_compile_options` defaults it to `false`.
- `use_spmd_partitioning::Bool`: forwarded as-is to the backend compile call;
  `make_compile_options` defaults it to `false`.
"""
struct CompileOptionsWithoutProto
    device_id::Int64
    global_device_ids::Vector{Int64}
    use_shardy_partitioner::Bool
    use_spmd_partitioning::Bool
end
57+
5158
function make_compile_options(;
5259
device_id::Int64,
5360
num_replicas::Int64=1,
@@ -57,14 +64,33 @@ function make_compile_options(;
5764
xla_executable_build_options=(;),
5865
xla_compile_options=(;),
5966
)
67+
if (
68+
isempty(xla_debug_options) &&
69+
(
70+
isempty(xla_executable_build_options) || (
71+
length(xla_executable_build_options) == 2 &&
72+
haskey(xla_executable_build_options, :use_shardy_partitioner) &&
73+
haskey(xla_executable_build_options, :use_spmd_partitioning)
74+
)
75+
) &&
76+
isempty(xla_compile_options)
77+
)
78+
return CompileOptionsWithoutProto(
79+
device_id,
80+
mesh_ids === nothing ? Int64[] : mesh_ids,
81+
get(xla_executable_build_options, :use_shardy_partitioner, false),
82+
get(xla_executable_build_options, :use_spmd_partitioning, false),
83+
)
84+
end
85+
6086
compile_options = get_default_compile_options()
6187
executable_build_options = compile_options.executable_build_options
6288

6389
executable_build_options.debug_options = get_debug_options(; xla_debug_options...)
6490
executable_build_options.num_replicas = num_replicas
6591
executable_build_options.num_partitions = num_partitions
6692

67-
# default overrides. can we changed by the user by passing in kwargs
93+
# default overrides. can be changed by the user by passing in kwargs
6894
executable_build_options.allow_spmd_sharding_propagation_to_parameters = [false]
6995
executable_build_options.allow_spmd_sharding_propagation_to_output = [false]
7096

src/xla/IFRT/Array.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,8 +370,8 @@ function replicate_array_to_all_devices(array::Array, sharding, mesh, size_arr)
370370

371371
exec = XLA.compile(
372372
XLA.client(array),
373-
mod;
374-
compile_options,
373+
mod,
374+
compile_options;
375375
num_outputs=1, # unused
376376
num_parameters=1, # unused
377377
is_sharded=true,

src/xla/IFRT/LoadedExecutable.jl

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ end
7373

7474
function XLA.compile(
7575
client::Client,
76-
mod::MLIR.IR.Module;
77-
compile_options::Reactant.Proto.xla.CompileOptionsProto,
76+
mod::MLIR.IR.Module,
77+
compile_options::Reactant.Proto.xla.CompileOptionsProto;
7878
num_parameters::Int64,
7979
num_outputs::Int64,
8080
is_sharded::Bool,
@@ -97,6 +97,42 @@ function XLA.compile(
9797
)
9898
end
9999

100+
"""
    XLA.compile(client::Client, mod::MLIR.IR.Module,
                compile_options::Reactant.XLA.CompileOptionsWithoutProto; kwargs...)

Compile `mod` for `client` through the `ifrt_compile` C entry point, passing the
individual fields of `compile_options` instead of a `CompileOptionsProto`.

# Keywords
- `num_parameters`, `num_outputs`, `is_sharded`: recorded on the returned
  `LoadedExecutable`; they are not passed to the native call.
- `num_replicas`, `num_partitions`: passed to the native call and also recorded on
  the returned `LoadedExecutable`.

# Returns
A `LoadedExecutable` wrapping the pointer returned by the native call.
"""
function XLA.compile(
    client::Client,
    mod::MLIR.IR.Module,
    compile_options::Reactant.XLA.CompileOptionsWithoutProto;
    num_parameters::Int64,
    num_outputs::Int64,
    is_sharded::Bool,
    num_replicas::Int64,
    num_partitions::Int64,
)
    # Keep `client` and `mod` rooted while their raw handles are used by the native call.
    GC.@preserve client mod begin
        # NOTE(review): `try_compile_dump_mlir` is defined elsewhere; presumably it
        # dumps `mod` when the wrapped compile fails -- confirm.
        exec = MLIR.IR.try_compile_dump_mlir(mod) do
            @ccall MLIR.API.mlir_c.ifrt_compile(
                client.client::Ptr{Cvoid},
                mod.module_::MLIR.API.MlirModule,
                compile_options.device_id::Clong,
                # The device id vector is passed as a pointer plus element count.
                compile_options.global_device_ids::Ptr{Clong},
                length(compile_options.global_device_ids)::Clong,
                XLA.CUDA_DATA_DIR[]::Cstring,
                compile_options.use_shardy_partitioner::Bool,
                num_replicas::Int64,
                num_partitions::Int64,
                compile_options.use_spmd_partitioning::Bool,
                Reactant.PersistentCompileCache.kernel_cache_enabled()::Bool,
                Reactant.PersistentCompileCache.get_kernel_cache_path()::Cstring,
                Reactant.PersistentCompileCache.autotune_cache_enabled()::Bool,
                Reactant.PersistentCompileCache.get_autotune_cache_directory()::Cstring,
                Reactant.Distributed.local_rank()::Cint,
            )::Ptr{Cvoid}
        end
    end
    return LoadedExecutable(
        exec, num_outputs, num_parameters, is_sharded, num_replicas, num_partitions
    )
end
135+
100136
@inline function XLA.execute(
101137
exec::LoadedExecutable,
102138
inputs::NTuple{N,Ptr{Cvoid}},

src/xla/PJRT/LoadedExecutable.jl

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ end
6767

6868
function XLA.compile(
6969
client::Client,
70-
mod::MLIR.IR.Module;
71-
compile_options::Reactant.Proto.xla.CompileOptionsProto,
70+
mod::MLIR.IR.Module,
71+
compile_options::Reactant.Proto.xla.CompileOptionsProto;
7272
num_parameters::Int64,
7373
num_outputs::Int64,
7474
is_sharded::Bool,
@@ -91,6 +91,42 @@ function XLA.compile(
9191
)
9292
end
9393

94+
"""
    XLA.compile(client::Client, mod::MLIR.IR.Module,
                compile_options::Reactant.XLA.CompileOptionsWithoutProto; kwargs...)

Compile `mod` for `client` through the `ClientCompile` C entry point, passing the
individual fields of `compile_options` instead of a `CompileOptionsProto`.

# Keywords
- `num_parameters`, `num_outputs`, `is_sharded`: recorded on the returned
  `LoadedExecutable`; they are not passed to the native call.
- `num_replicas`, `num_partitions`: passed to the native call and also recorded on
  the returned `LoadedExecutable`.

# Returns
A `LoadedExecutable` wrapping the pointer returned by the native call.
"""
function XLA.compile(
    client::Client,
    mod::MLIR.IR.Module,
    compile_options::Reactant.XLA.CompileOptionsWithoutProto;
    num_parameters::Int64,
    num_outputs::Int64,
    is_sharded::Bool,
    num_replicas::Int64,
    num_partitions::Int64,
)
    # Keep `client` and `mod` rooted while their raw handles are used by the native call.
    GC.@preserve client mod begin
        # NOTE(review): `try_compile_dump_mlir` is defined elsewhere; presumably it
        # dumps `mod` when the wrapped compile fails -- confirm.
        exec = MLIR.IR.try_compile_dump_mlir(mod) do
            @ccall MLIR.API.mlir_c.ClientCompile(
                client.client::Ptr{Cvoid},
                mod.module_::MLIR.API.MlirModule,
                compile_options.device_id::Clong,
                # The device id vector is passed as a pointer plus element count.
                compile_options.global_device_ids::Ptr{Clong},
                length(compile_options.global_device_ids)::Clong,
                XLA.CUDA_DATA_DIR[]::Cstring,
                compile_options.use_shardy_partitioner::Bool,
                num_replicas::Int64,
                num_partitions::Int64,
                compile_options.use_spmd_partitioning::Bool,
                Reactant.PersistentCompileCache.kernel_cache_enabled()::Bool,
                Reactant.PersistentCompileCache.get_kernel_cache_path()::Cstring,
                Reactant.PersistentCompileCache.autotune_cache_enabled()::Bool,
                Reactant.PersistentCompileCache.get_autotune_cache_directory()::Cstring,
                Reactant.Distributed.local_rank()::Cint,
            )::Ptr{Cvoid}
        end
    end
    return LoadedExecutable(
        exec, num_outputs, num_parameters, is_sharded, num_replicas, num_partitions
    )
end
129+
94130
function execute_ir(N, M, n_outs, with_device::Bool, nmesh_ids::Int64)
95131
ptr = @static if VERSION < v"1.12"
96132
sizeof(Int) == sizeof(Int64) ? "i64" : "i32"

src/xla/XLA.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ function LLVMclopts(opts...)
2626
)::Cvoid
2727
end
2828

29+
include("CompileOptions.jl")
30+
2931
include("Distributed.jl")
3032
include("Client.jl")
3133
include("Device.jl")
@@ -46,8 +48,6 @@ include("PJRT/PJRT.jl")
4648

4749
include("IFRT/IFRT.jl")
4850

49-
include("CompileOptions.jl")
50-
5151
abstract type AbstractBackendState end
5252

5353
for runtime in (:PJRT, :IFRT)

0 commit comments

Comments
 (0)