Skip to content

Commit 1be04f7

Browse files
committed
Fix no inputs xla call
1 parent 666c9b6 commit 1be04f7

File tree

2 files changed

+36
-33
lines changed

2 files changed

+36
-33
lines changed

src/XLA.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -285,14 +285,19 @@ end
285285
function execute_ir(N, n_outs, fn)
286286
ptr = sizeof(Int) == sizeof(Int64) ? "i64" : "i32"
287287
cint = sizeof(Cint) == sizeof(Int64) ? "i64" : "i32"
288-
res = """define { [$n_outs x $ptr], [$n_outs x $ptr], i8 } @f($ptr %exec, [$N x $ptr] %inps, [$N x i8] %donated) alwaysinline {
288+
args = N > 0 ? ", [$N x $ptr] %inps, [$N x i8] %donated" : ""
289+
stores = N > 0 ? """
290+
store [$N x $ptr] %inps, [$N x $ptr]* %inpa
291+
store [$N x i8] %donated, [$N x i8]* %dona
292+
""" : ""
293+
294+
res = """define { [$n_outs x $ptr], [$n_outs x $ptr], i8 } @f($ptr %exec $args) alwaysinline {
289295
entry:
290296
%inpa = alloca [$N x $ptr]
297+
%dona = alloca [$N x i8]
291298
%outa = alloca [$n_outs x $ptr]
292299
%futpa = alloca [$n_outs x $ptr]
293-
store [$N x $ptr] %inps, [$N x $ptr]* %inpa
294-
%dona = alloca [$N x i8]
295-
store [$N x i8] %donated, [$N x i8]* %dona
300+
$stores
296301
%futa = alloca i8
297302
call void inttoptr ($ptr $fn to void ($ptr, $cint, [$N x $ptr]*, [$N x i8]*, $cint, [$n_outs x $ptr]*, i8*, [$n_outs x $ptr]*)*)($ptr %exec, $cint $N, [$N x $ptr]* nocapture readonly %inpa, [$N x i8]* nocapture readonly %dona, $cint $n_outs, [$n_outs x $ptr]* nocapture writeonly %outa, i8* nocapture writeonly %futa, [$n_outs x $ptr]* nocapture writeonly %futpa)
298303
%out = load [$n_outs x $ptr], [$n_outs x $ptr]* %outa
@@ -323,17 +328,19 @@ end
323328
:(AsyncBuffer(Buffer(outputs[$i]), future ? Future(future_res[$i]) : nothing)),
324329
)
325330
end
331+
332+
args_type = N > 0 ? (Ptr{Cvoid}, NTuple{N,Ptr{Cvoid}}, NTuple{N,UInt8}) : (Ptr{Cvoid},)
333+
args = N > 0 ? (:inputs, :donated_args) : ()
326334
return quote
327335
Base.@_inline_meta
328336
exec = exec.exec
329337
GC.@preserve exec begin
330338
outputs, future_res, future = Base.llvmcall(
331339
($ir, "f"),
332340
Tuple{NTuple{n_outs,Ptr{Cvoid}},NTuple{n_outs,Ptr{Cvoid}},Bool},
333-
Tuple{Ptr{Cvoid},NTuple{N,Ptr{Cvoid}},NTuple{N,UInt8}},
341+
Tuple{$args_type...},
334342
exec,
335-
inputs,
336-
donated_args,
343+
$(args...),
337344
)
338345
end
339346
return ($(results...),)

test/ops.jl

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,12 @@ end
140140
end
141141

142142
@testset "constant" begin
143-
# TODO currently crashes due to #196
144-
# for x in [[1, 2, 3], [1.1, 2.2, 3.3], [1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im]]
145-
# @test x ≈ @jit Ops.constant(x)
143+
for x in [[1, 2, 3], [1.1, 2.2, 3.3], [1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im]]
144+
@test x @jit Ops.constant(x)
146145

147-
# xscalar = x[1]
148-
# @test xscalar ≈ @jit Ops.constant(xscalar)
149-
# end
146+
xscalar = x[1]
147+
@test xscalar @jit Ops.constant(xscalar)
148+
end
150149
end
151150

152151
@testset "cosine" begin
@@ -281,22 +280,21 @@ end
281280
end
282281

283282
@testset "iota" begin
284-
# TODO this crashes. seems like the same error as #196
285-
# g1(shape) = Ops.iota(Int, shape; iota_dimension=1)
286-
# @test [
287-
# 0 0 0 0 0
288-
# 1 1 1 1 1
289-
# 2 2 2 2 2
290-
# 3 3 3 3 3
291-
# ] ≈ @jit g1([4, 5])
292-
293-
# g2(shape) = Ops.iota(Int, shape; iota_dimension=2)
294-
# @test [
295-
# 0 1 2 3 4
296-
# 0 1 2 3 4
297-
# 0 1 2 3 4
298-
# 0 1 2 3 4
299-
# ] ≈ @jit g2([4, 5])
283+
g1(shape) = Ops.iota(Int, shape; iota_dimension=1)
284+
@test [
285+
0 0 0 0 0
286+
1 1 1 1 1
287+
2 2 2 2 2
288+
3 3 3 3 3
289+
] @jit g1([4, 5])
290+
291+
g2(shape) = Ops.iota(Int, shape; iota_dimension=2)
292+
@test [
293+
0 1 2 3 4
294+
0 1 2 3 4
295+
0 1 2 3 4
296+
0 1 2 3 4
297+
] @jit g2([4, 5])
300298
end
301299

302300
@testset "is_finite" begin
@@ -443,8 +441,7 @@ end
443441
end
444442

445443
@testset "partition_id" begin
446-
# TODO this crashes. seems like the same error as #196
447-
# @test @jit(Ops.partition_id()) isa ConcreteRNumber{UInt32}
444+
@test @jit(Ops.partition_id()) isa ConcreteRNumber{UInt32}
448445
end
449446

450447
@testset "popcnt" begin
@@ -481,8 +478,7 @@ end
481478
end
482479

483480
@testset "replica_id" begin
484-
# TODO this crashes. seems like the same error as #196
485-
# @test @jit(Ops.partition_id()) isa ConcreteRNumber{UInt32}
481+
@test @jit(Ops.partition_id()) isa ConcreteRNumber{UInt32}
486482
end
487483

488484
@testset "reshape" begin

0 commit comments

Comments
 (0)