Commit a3be565

feat: define a @opcall macro for better debug info (#1574)
* feat: define a @opcall macro for better debug info
* fix: missing location in Ops.select
* chore: run formatter
* fix: missing location in Ops.clamp
1 parent 7e21eca commit a3be565

22 files changed: +573 −396 lines changed
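Note on the change: every `Ops.foo(...)` call site in the touched files becomes `@opcall foo(...)` (or `@opcall(foo(...))` when the call is nested inside another expression). The point of the macro is to record the caller's source location on the emitted op, so the MLIR debug info points at user code rather than at Ops.jl. The macro itself is defined in src/Ops.jl, which is not part of this excerpt; the sketch below only illustrates the idea, and the `Ops.mlir_stacktrace` helper and `location` keyword it uses are assumptions, not a quote of the actual definition.

    # Illustrative sketch only -- not the definition from this commit.
    # Assumed behavior: `@opcall f(args...; kw...)` forwards to `Ops.f` and
    # attaches the macro call site as the op's MLIR location.
    macro opcall(call)
        Meta.isexpr(call, :call) || error("@opcall expects a call, e.g. @opcall fft(x)")
        fname = call.args[1]
        rest = call.args[2:end]
        # __source__ is the LineNumberNode of the macro call site
        file, line = string(__source__.file), __source__.line
        # Hypothetical helper that builds an MLIR Location from file/line
        loc = :(Ops.mlir_stacktrace($(string(fname)), $file, $line))
        return esc(Expr(:call, :(Ops.$fname), rest..., Expr(:kw, :location, loc)))
    end

With something like that in place, `@opcall constant(xs)` behaves like `Ops.constant(xs)` but the resulting op carries the file and line of the caller.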

ext/ReactantAbstractFFTsExt.jl

Lines changed: 15 additions & 10 deletions

@@ -2,6 +2,7 @@ module ReactantAbstractFFTsExt
 
 using AbstractFFTs: AbstractFFTs
 using Reactant: Reactant, MLIR, Ops, AnyTracedRArray, TracedRArray, TracedUtils
+using Reactant.Ops: @opcall
 
 function __permutation_to_move_dims_to_end(dims, N::Integer)
     perm = [i for i in 1:N if i ∉ Set(dims)]
@@ -21,18 +22,20 @@ for op in (:rfft, :fft, :ifft)
 
         fft_lengths = Int64[size(x, dim) for dim in reverse(dims)]
         if __is_valid_stablehlo_fft_dims(dims, ndims(x))
-            return Ops.fft(
+            return @opcall fft(
                 TracedUtils.materialize_traced_array(x);
                 type=$(uppercase(string(op))),
                 length=fft_lengths,
             )
         end
         perm = __permutation_to_move_dims_to_end(dims, ndims(x))
         return permutedims(
-            Ops.fft(
-                TracedUtils.materialize_traced_array(permutedims(x, perm));
-                type=$(uppercase(string(op))),
-                length=fft_lengths,
+            @opcall(
+                fft(
+                    TracedUtils.materialize_traced_array(permutedims(x, perm));
+                    type=$(uppercase(string(op))),
+                    length=fft_lengths,
+                )
             ),
             invperm(perm),
         )
@@ -48,7 +51,7 @@ for op in (:irfft,)
         fft_lengths = vcat(Int64[size(x, dim) for dim in reverse(dims[2:end])], d)
 
         if __is_valid_stablehlo_fft_dims(dims, ndims(x))
-            return Ops.fft(
+            return @opcall fft(
                 TracedUtils.materialize_traced_array(x);
                 type=$(uppercase(string(op))),
                 length=fft_lengths,
@@ -57,10 +60,12 @@ for op in (:irfft,)
 
         perm = __permutation_to_move_dims_to_end(dims, ndims(x))
         return permutedims(
-            Ops.fft(
-                TracedUtils.materialize_traced_array(permutedims(x, perm));
-                type=$(uppercase(string(op))),
-                length=fft_lengths,
+            @opcall(
+                fft(
+                    TracedUtils.materialize_traced_array(permutedims(x, perm));
+                    type=$(uppercase(string(op))),
+                    length=fft_lengths,
+                )
             ),
             invperm(perm),
         )
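As the hunks above show, the plain form `@opcall fft(...)` is used when the op call stands on its own, while the parenthesized form `@opcall(fft(...))` is used when the call is an argument to another function (permutedims here), so the macro consumes only the op call and not the surrounding expression. A small illustration with placeholder variables (`x`, `perm`, `fft_lengths` stand in for the values computed in the real code):

    using Reactant.Ops: @opcall

    # Plain form: the macro applies to the single call that follows it.
    y = @opcall fft(x; type="FFT", length=fft_lengths)

    # Parenthesized form: keeps the macro scoped to the fft call when it is
    # nested inside another expression such as permutedims.
    z = permutedims(
        @opcall(fft(permutedims(x, perm); type="FFT", length=fft_lengths)),
        invperm(perm),
    )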

ext/ReactantCUDAExt.jl

Lines changed: 3 additions & 1 deletion

@@ -10,6 +10,8 @@ import KernelAbstractions as KA
 using LLVM: LLVM
 using Libdl
 
+using Reactant.Ops: @opcall
+
 const ReactantKernelAbstractionsExt = Base.get_extension(
     Reactant, :ReactantKernelAbstractionsExt
 )
@@ -469,7 +471,7 @@ function Adapt.adapt_storage(ka::ReactantKernelAdaptor, xs::DenseCuArray)
     return Adapt.adapt_storage(ka, Array(xs))
 end
 function Adapt.adapt_storage(ka::ReactantKernelAdaptor, xs::Array)
-    return Adapt.adapt_storage(ka, Reactant.Ops.constant(xs))
+    return Adapt.adapt_storage(ka, @opcall(constant(xs)))
 end
 function Adapt.adapt_structure(
     to::ReactantKernelAdaptor, bc::Broadcast.Broadcasted{Style,<:Any,Type{T}}

ext/ReactantNNlibExt/Implementations.jl

Lines changed: 23 additions & 19 deletions

@@ -3,7 +3,7 @@ for (jlop, hloop) in (
     (:(NNlib.sigmoid_fast), :logistic),
     (:(NNlib.sigmoid), :logistic),
 )
-    @eval $(jlop)(x::TracedRNumber) = Ops.$(hloop)(x)
+    @eval $(jlop)(x::TracedRNumber) = @opcall $(hloop)(x)
 end
 
 function NNlib.softmax!(out::AnyTracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
@@ -89,7 +89,7 @@ function overloaded_conv!(
 
     weight = W
     if !flipkernel
-        weight = Reactant.Ops.reverse(weight; dimensions=kernel_spatial_dims)
+        weight = @opcall reverse(weight; dimensions=kernel_spatial_dims)
     end
 
     conv = Reactant.MLIR.Dialects.stablehlo.convolution(
@@ -211,7 +211,7 @@ function overloaded_∇conv_filter!(
 
     if !NNlib.flipkernel(cdims)
         set_mlir_data!(
-            dw, get_mlir_data(Reactant.Ops.reverse(dw; dimensions=output_spatial_dims))
+            dw, get_mlir_data(@opcall(reverse(dw; dimensions=output_spatial_dims)))
         )
     end
 
@@ -312,7 +312,7 @@ function overloaded_∇conv_data!(
     )
 
     if NNlib.flipkernel(cdims)
-        w = Reactant.Ops.reverse(w; dimensions=kernel_spatial_dims)
+        w = @opcall reverse(w; dimensions=kernel_spatial_dims)
    end
 
     conv = MLIR.Dialects.stablehlo.convolution(
@@ -411,7 +411,7 @@ function NNlib.batched_mul!(
         end
     end
 
-    tmp = Ops.dot_general(
+    tmp = @opcall dot_general(
        T1.(materialize_traced_array(x)),
        T1.(materialize_traced_array(y));
        contracting_dimensions=([3], [2]),
@@ -430,7 +430,7 @@ function NNlib.pad_constant(
     low = [i[1] for i in pad]
     high = [i[2] for i in pad]
     interior = [0 for i in pad]
-    return Ops.pad(materialize_traced_array(x), value; low, high, interior)
+    return @opcall pad(materialize_traced_array(x), value; low, high, interior)
 end
 
 # Gather
@@ -462,7 +462,7 @@ end
 function _nnlib_gather_impl(src::AnyTracedRArray, idxs::AbstractArray, n_dims::Int)
     idxs = TracedUtils.promote_to(TracedRArray{Int,ndims(idxs)}, idxs)
     n_idxs = size(idxs, 1)
-    return Ops.gather(
+    return @opcall gather(
         src,
         idxs;
         offset_dims=collect(Int64, 1:n_dims),
@@ -506,7 +506,7 @@ function NNlib.scatter(
         )
     end
     xinit = isnothing(init) ? NNlib.scatter_empty(op, T) : init
-    dst = Ops.fill(xinit, dstsz)
+    dst = @opcall fill(xinit, dstsz)
 
     NNlib.scatter!(op, dst, src, idx)
     return dst
@@ -551,17 +551,21 @@ function _nnlib_scatter_impl(
 ) where {OP,T}
     scatter_indices = TracedUtils.promote_to(TracedRArray{Int,ndims(idx)}, idx)
     n_idxs = size(scatter_indices, 1)
-    return Ops.scatter(
-        op,
-        [dst],
-        scatter_indices,
-        [src];
-        update_window_dims=collect(Int64, 1:n_dims),
-        inserted_window_dims=collect(Int64, (n_dims + 1):ndims(dst)),
-        input_batching_dims=Int64[],
-        scatter_indices_batching_dims=Int64[],
-        scatter_dims_to_operand_dims=collect(Int64, (ndims(dst) - n_idxs + 1):ndims(dst)),
-        index_vector_dim=Int64(1),
+    return @opcall(
+        scatter(
+            op,
+            [dst],
+            scatter_indices,
+            [src];
+            update_window_dims=collect(Int64, 1:n_dims),
+            inserted_window_dims=collect(Int64, (n_dims + 1):ndims(dst)),
+            input_batching_dims=Int64[],
+            scatter_indices_batching_dims=Int64[],
+            scatter_dims_to_operand_dims=collect(
+                Int64, (ndims(dst) - n_idxs + 1):ndims(dst)
+            ),
+            index_vector_dim=Int64(1),
+        )
     )[1]
 end

ext/ReactantNNlibExt/Ops.jl

Lines changed: 19 additions & 17 deletions

@@ -12,17 +12,19 @@ function reduce_window(
 
     padding = collect(Int64, reshape([padding..., 0, 0, 0, 0], (2, N))')
 
-    return Ops.reduce_window(
-        f,
-        [materialize_traced_array(x)],
-        [Ops.constant(T(init))];
-        window_dimensions=[kernel_size..., 1, 1],
-        window_strides=[stride..., 1, 1],
-        window_dilations=[dilation..., 1, 1],
-        padding_low=padding[:, 1],
-        padding_high=padding[:, 2],
-        output_shape=Int[output_spatial_shapes..., size(x, N - 1), size(x, N)],
-        base_dilations=ones(Int, N),
+    return @opcall(
+        reduce_window(
+            f,
+            [materialize_traced_array(x)],
+            [@opcall(constant(T(init)))];
+            window_dimensions=[kernel_size..., 1, 1],
+            window_strides=[stride..., 1, 1],
+            window_dilations=[dilation..., 1, 1],
+            padding_low=padding[:, 1],
+            padding_high=padding[:, 2],
+            output_shape=Int[output_spatial_shapes..., size(x, N - 1), size(x, N)],
+            base_dilations=ones(Int, N),
+        )
     )[1]
 end
 
@@ -31,7 +33,7 @@ function upsample_linear(
 ) where {T}
     W, _, _ = size(x)
 
-    out_idxs = Ops.iota(Int32, [out_size[1]]; iota_dimension=1)
+    out_idxs = @opcall iota(Int32, [out_size[1]]; iota_dimension=1)
     iw0, iw1, w0_λ, w1_λ = source_idx_and_λ(rwidth, out_idxs, align_corners, W)
 
     x0 = x[iw0, :, :]
@@ -45,8 +47,8 @@ function upsample_linear(
 ) where {T}
     W, H, _, _ = size(x)
 
-    out_width = Ops.iota(Int32, [out_size[1]]; iota_dimension=1)
-    out_height = Ops.iota(Int32, [out_size[2]]; iota_dimension=1)
+    out_width = @opcall iota(Int32, [out_size[1]]; iota_dimension=1)
+    out_height = @opcall iota(Int32, [out_size[2]]; iota_dimension=1)
 
     iw0, iw1, w0_λ, w1_λ = source_idx_and_λ(rwidth, out_width, align_corners, W)
     ih0, ih1, h0_λ, h1_λ = source_idx_and_λ(rheight, out_height, align_corners, H)
@@ -72,9 +74,9 @@ function upsample_linear(
 ) where {T}
     W, H, D, _, _ = size(x)
 
-    out_width = Ops.iota(Int32, [out_size[1]]; iota_dimension=1)
-    out_height = Ops.iota(Int32, [out_size[2]]; iota_dimension=1)
-    out_depth = Ops.iota(Int32, [out_size[3]]; iota_dimension=1)
+    out_width = @opcall iota(Int32, [out_size[1]]; iota_dimension=1)
+    out_height = @opcall iota(Int32, [out_size[2]]; iota_dimension=1)
+    out_depth = @opcall iota(Int32, [out_size[3]]; iota_dimension=1)
 
     iw0, iw1, w0_λ, w1_λ = source_idx_and_λ(rwidth, out_width, align_corners, W)
     ih0, ih1, h0_λ, h1_λ = source_idx_and_λ(rheight, out_height, align_corners, H)

ext/ReactantNNlibExt/ReactantNNlibExt.jl

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@ using Reactant:
 
 using Reactant.TracedUtils:
     TracedUtils, materialize_traced_array, get_mlir_data, set_mlir_data!
+using Reactant.Ops: @opcall
 
 using ReactantCore: @trace
 using LinearAlgebra: LinearAlgebra, triu

ext/ReactantOneHotArraysExt.jl

Lines changed: 3 additions & 2 deletions

@@ -3,6 +3,7 @@ module ReactantOneHotArraysExt
 using OneHotArrays
 using Reactant
 using Reactant: TracedRArray, TracedRNumber, TracedUtils, Ops
+using Reactant.Ops: @opcall
 
 function Reactant.traced_type_inner(
     @nospecialize(_::Type{OneHotArrays.OneHotArray{T,N,Np1,I}}),
@@ -35,9 +36,9 @@ function TracedUtils.materialize_traced_array(r::OneHotArrays.OneHotArray)
 
     linear_indices =
         TracedUtils.promote_to(TracedRArray{Int64,ndims(r.indices)}, indices) .+
-        Ops.iota(Int64, [B]; iota_dimension=1) .* N
+        @opcall(iota(Int64, [B]; iota_dimension=1)) .* N
 
-    z = Ops.fill(false, (N, B))
+    z = @opcall(fill(false, (N, B)))
     z[linear_indices] = fill(true, length(linear_indices))
     return reshape(z, size(r))
 end

ext/ReactantPythonCallExt/ReactantPythonCallExt.jl

Lines changed: 1 addition & 0 deletions

@@ -2,6 +2,7 @@ module ReactantPythonCallExt
 
 using PythonCall
 using Reactant: Reactant, TracedRArray
+using Reactant.Ops: @opcall
 
 const jaxptr = Ref{Py}()
 const jnpptr = Ref{Py}()

ext/ReactantPythonCallExt/pycall.jl

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ function PythonCall.pycall(f::Py, arg0::TracedRArray, argNs::TracedRArray...; kw
     end
 
     lowered = jax.jit(f).lower(inputs...)
-    res = Reactant.Ops.hlo_call(pyconvert(String, lowered.as_text()), arg0, argNs...)
+    res = @opcall hlo_call(pyconvert(String, lowered.as_text()), arg0, argNs...)
 
     return length(res) == 0 ? nothing : res[1]
 end

ext/ReactantSpecialFunctionsExt.jl

Lines changed: 7 additions & 5 deletions

@@ -2,16 +2,17 @@ module ReactantSpecialFunctionsExt
 using SpecialFunctions
 using Reactant: Ops, Reactant, TracedRNumber, ReactantFloat, ReactantInt, ReactantFloatInt
 using Reactant.TracedRNumberOverrides: float
+using Reactant.Ops: @opcall
 
 for fn in [:digamma, :erf, :erfc, (:loggamma, :lgamma)]
     (fns, fno) = fn isa Tuple ? fn : (fn, fn)
     @eval(function SpecialFunctions.$fns(x::TracedRNumber{<:ReactantFloatInt})
-        return Ops.$fno(float(x))
+        return @opcall $fno(float(x))
     end)
 end
 
 function SpecialFunctions.gamma(x::TracedRNumber{<:ReactantFloat})
-    return exp(Ops.lgamma(float(x)))
+    return exp(@opcall(lgamma(float(x))))
 end
 
 function SpecialFunctions.gamma(n::TracedRNumber{<:ReactantInt})
@@ -29,13 +30,14 @@ end
 # SpecialFunctions.invdigamma
 
 function SpecialFunctions.trigamma(x::TracedRNumber{<:ReactantFloatInt})
-    return Ops.polygamma(Ops.constant(Float64(1)), float(x))#TODO: change Ops definition
+    #TODO: change Ops definition
+    return @opcall(polygamma(@opcall(constant(Float64(1))), float(x)))
 end
 
 function SpecialFunctions.polygamma(
     n::TracedRNumber{<:ReactantFloatInt}, x::TracedRNumber{<:ReactantFloatInt}
 )
-    return Ops.polygamma(float(n), float(x))
+    return @opcall polygamma(float(n), float(x))
 end
 
 # SpecialFunctions.gamma_inc
@@ -112,7 +114,7 @@ end
 function SpecialFunctions.zeta(
     z::TracedRNumber{T}, s::TracedRNumber{T}
 ) where {T<:ReactantFloatInt}
-    return Ops.zeta(z, s)
+    return @opcall zeta(z, s)
 end
 
 end # module ReactantSpecialFunctionsExt

src/ConcreteRArray.jl

Lines changed: 1 addition & 1 deletion

@@ -750,7 +750,7 @@ function Base.fill(
 )
     output_shardings = Sharding.is_sharded(sharding) ? Dict(1 => sharding) : nothing
     fn = Reactant.compile((); output_shardings) do
-        return Ops.fill(val, collect(Int64, dims))
+        return @opcall fill(val, collect(Int64, dims))
     end
     return fn()
 end
