Skip to content

Commit 814e9c0

Browse files
authored
Merge pull request EnzymeAD#342 from EnzymeAD/ap/upsampling
* fix: manually zero out the lower triangular and upper triangular values
* fix: only do it in tests
* revert: change in Ops.cholesky
* revert: remove unnecessary changes
* fix: preserve parent array tracking for reshape
* test: writing to a reshaped array
* test: upsample_nearest
* fix: test failures due to wrappers
* fix: handle lazy transpose/adjoint correctly
* fix: handle wrappers in NNlibExt correctly
* fix: more reshaped wrappers handling
* fix: dispatches to avoid ambiguity
* fix: handle diagonal wrapper gracefully
* fix: compile wrapped concrete array conversion to arrays
* feat: more wrapped ConcreteRArray handling
* chore: apply suggestions from code review
* refactor: rearrange the tests
* test: add test that fails on incorrect reshape dims ordering
2 parents 816e789 + 107040d commit 814e9c0

File tree

12 files changed

+388
-241
lines changed

12 files changed

+388
-241
lines changed

ext/ReactantArrayInterfaceExt.jl

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@ module ReactantArrayInterfaceExt
22

33
using ArrayInterface: ArrayInterface
44
using Reactant:
5-
Reactant, RArray, ConcreteRArray, ConcreteRNumber, TracedRNumber, TracedRArray
5+
Reactant, RArray, ConcreteRArray, ConcreteRNumber, TracedRNumber, TracedRArray, Ops
66

77
ArrayInterface.can_setindex(::Type{<:RArray}) = false
88
ArrayInterface.fast_scalar_indexing(::Type{<:RArray}) = false
@@ -14,7 +14,7 @@ function ArrayInterface.aos_to_soa(x::AbstractArray{<:ConcreteRNumber{T}}) where
1414
end
1515

1616
function ArrayInterface.aos_to_soa(x::AbstractArray{<:TracedRNumber{T}}) where {T}
17-
return reshape(vcat(x...), size(x))
17+
return Ops.reshape(vcat(x...), size(x)...)
1818
end
1919

2020
end

ext/ReactantNNlibExt.jl

Lines changed: 62 additions & 83 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,15 @@ module ReactantNNlibExt
33
using NNlib
44
using GPUArraysCore: @allowscalar
55
using Reactant:
6-
Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber
6+
Reactant,
7+
Ops,
8+
TracedRArray,
9+
AnyTracedRArray,
10+
materialize_traced_array,
11+
MLIR,
12+
TracedRNumber,
13+
get_mlir_data,
14+
set_mlir_data!
715
using ReactantCore: @trace
816
using LinearAlgebra: LinearAlgebra, triu
917

@@ -12,14 +20,7 @@ for (jlop, hloop) in (
1220
(:(NNlib.sigmoid_fast), :logistic),
1321
(:(NNlib.sigmoid), :logistic),
1422
)
15-
@eval function $(jlop)(x::TracedRNumber{T}) where {T}
16-
return TracedRNumber{T}(
17-
(),
18-
Reactant.MLIR.IR.result(
19-
Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1
20-
),
21-
)
22-
end
23+
@eval $(jlop)(x::TracedRNumber) = Ops.$(hloop)(x)
2324
end
2425

2526
function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
@@ -82,13 +83,6 @@ function NNlib.conv!(
8283
kernel_input_dim = N - 1
8384
kernel_output_dim = N
8485

85-
output_spatial_shapes = map(input_spatial_dims) do i
86-
K = kernel_size[i]
87-
pl, pr = padding[2i - 1], padding[2i]
88-
d = dilation[i]
89-
s = stride[i]
90-
return (size(x, i) + pl + pr - d * (K - 1) - 1) ÷ s + 1
91-
end
9286
output_batch_dim = input_batch_dim
9387
output_feature_dim = input_feature_dim
9488
output_spatial_dims = input_spatial_dims
@@ -119,8 +113,8 @@ function NNlib.conv!(
119113
end
120114

121115
conv = Reactant.MLIR.Dialects.stablehlo.convolution(
122-
x.mlir_data,
123-
weight.mlir_data;
116+
get_mlir_data(x),
117+
get_mlir_data(weight);
124118
result_0=result_type,
125119
window_strides=collect(stride),
126120
padding,
@@ -130,7 +124,7 @@ function NNlib.conv!(
130124
feature_group_count,
131125
batch_group_count=1,
132126
)
133-
y.mlir_data = Reactant.MLIR.IR.result(conv)
127+
set_mlir_data!(y, Reactant.MLIR.IR.result(conv))
134128
return y
135129
end
136130

@@ -165,7 +159,9 @@ function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
165159
output_shape = (output_spatial_shapes..., size(x, N - 1), size(x, N))
166160
result_type = Reactant.MLIR.IR.TensorType(output_shape, Reactant.MLIR.IR.Type(T))
167161

168-
unranked = Reactant.MLIR.IR.TensorType((), eltype(Reactant.MLIR.IR.type(x.mlir_data)))
162+
unranked = Reactant.MLIR.IR.TensorType(
163+
(), eltype(Reactant.MLIR.IR.type(get_mlir_data(x)))
164+
)
169165
body =
170166
let body = Reactant.MLIR.IR.Region(),
171167
loc = Reactant.MLIR.IR.Location(),
@@ -189,7 +185,7 @@ function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
189185
Reactant.MLIR.Dialects.stablehlo.constant(; value=attr)
190186
)
191187
reduction = Reactant.MLIR.Dialects.stablehlo.reduce_window(
192-
[x.mlir_data],
188+
[get_mlir_data(x)],
193189
[init_value];
194190
result_0=[result_type],
195191
window_dimensions,
@@ -205,24 +201,24 @@ end
205201
function NNlib.maxpool!(
206202
y::TracedRArray{T}, x::AnyTracedRArray, pdims::NNlib.PoolDims
207203
) where {T}
208-
y.mlir_data =
209-
reduce_window(
210-
Reactant.MLIR.Dialects.stablehlo.maximum, T.(x), pdims; init=typemin(T)
211-
).mlir_data
204+
res = reduce_window(
205+
Reactant.MLIR.Dialects.stablehlo.maximum, T.(x), pdims; init=typemin(T)
206+
)
207+
set_mlir_data!(y, get_mlir_data(res))
212208
return y
213209
end
214210

215211
function NNlib.meanpool!(
216212
y::TracedRArray{T}, x::AnyTracedRArray, pdims::NNlib.PoolDims
217213
) where {T}
218214
res = reduce_window(Reactant.MLIR.Dialects.stablehlo.add, T.(x), pdims; init=zero(T))
219-
y.mlir_data = (res ./ T(prod(NNlib.kernel_size(pdims)))).mlir_data
215+
set_mlir_data!(y, get_mlir_data(res ./ T(prod(NNlib.kernel_size(pdims)))))
220216
return y
221217
end
222218

223-
NNlib.batched_transpose(x::AnyTracedRArray{T,3}) where {T} = permutedims(x, (2, 1, 3))
219+
NNlib.batched_transpose(x::AnyTracedRArray{T,3}) where {T} = PermutedDimsArray(x, (2, 1, 3))
224220
function NNlib.batched_adjoint(x::AnyTracedRArray{T,3}) where {T}
225-
y = permutedims(x, (2, 1, 3))
221+
y = NNlib.batched_transpose(x)
226222
conj!(y)
227223
return y
228224
end
@@ -238,64 +234,47 @@ function NNlib.batched_mul!(
238234
),
239235
)
240236
end
237+
238+
if size(x, 3) != size(y, 3)
239+
B = max(size(x, 3), size(y, 3))
240+
if size(x, 3) == 1
241+
x = Reactant.broadcast_to_size(x, (size(x, 1), size(x, 2), B))
242+
elseif size(y, 3) == 1
243+
y = Reactant.broadcast_to_size(y, (size(y, 1), size(y, 2), B))
244+
end
245+
end
246+
241247
x = permutedims(x, (3, 1, 2))
242248
y = permutedims(y, (3, 1, 2))
243249

244-
B = max(size(x, 1), size(y, 1))
245-
out_shape = (B, size(x, 2), size(y, 3))
246-
resty = MLIR.IR.TensorType(out_shape, eltype(MLIR.IR.type(res.mlir_data)))
247-
248250
if size(x, 1) != size(y, 1)
251+
B = max(size(x, 1), size(y, 1))
249252
if size(x, 1) == 1
250253
x = Reactant.broadcast_to_size(x, (B, size(x, 2), size(x, 3)))
251254
elseif size(y, 1) == 1
252255
y = Reactant.broadcast_to_size(y, (B, size(y, 2), size(y, 3)))
253256
end
254257
end
255258

256-
dot_dimension_numbers = MLIR.API.stablehloDotDimensionNumbersGet(
257-
MLIR.IR.context(), 1, [0], 1, [0], 1, [2], 1, [1]
259+
tmp = Ops.dot_general(
260+
T1.(materialize_traced_array(x)),
261+
T1.(materialize_traced_array(y));
262+
contracting_dimensions=([3], [2]),
263+
batching_dimensions=([1], [1]),
258264
)
265+
set_mlir_data!(res, get_mlir_data(permutedims(tmp, (2, 3, 1))))
259266

260-
prec = MLIR.IR.Attribute(
261-
MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT")
262-
)
263-
tmp = TracedRArray{T1,3}(
264-
(),
265-
MLIR.IR.result(
266-
MLIR.Dialects.stablehlo.dot_general(
267-
x.mlir_data,
268-
y.mlir_data;
269-
result_0=resty,
270-
dot_dimension_numbers=dot_dimension_numbers,
271-
precision_config=prec,
272-
),
273-
1,
274-
),
275-
size(resty),
276-
)
277-
res.mlir_data = permutedims(tmp, (2, 3, 1)).mlir_data
278267
return res
279268
end
280269

281270
function NNlib.pad_constant(
282-
x::TracedRArray{T,N}, pad::NTuple{N,Tuple{Int,Int}}, value
271+
x::AnyTracedRArray{T,N}, pad::NTuple{N,Tuple{Int,Int}}, value
283272
) where {T,N}
284273
value = Reactant.promote_to(TracedRNumber{T}, value)
285-
edge_padding_low = [i[1] for i in pad]
286-
edge_padding_high = [i[2] for i in pad]
287-
interior_padding = [0 for i in pad]
288-
res = MLIR.IR.result(
289-
MLIR.Dialects.stablehlo.pad(
290-
x.mlir_data,
291-
value.mlir_data;
292-
edge_padding_low,
293-
edge_padding_high,
294-
interior_padding,
295-
),
296-
1,
297-
)
298-
return TracedRArray{T,N}((), res, size(MLIR.IR.type(res)))
274+
low = [i[1] for i in pad]
275+
high = [i[2] for i in pad]
276+
interior = [0 for i in pad]
277+
return Ops.pad(materialize_traced_array(x), value; low, high, interior)
299278
end
300279

301280
# XXX: reevaluate this manual optimization once
@@ -305,7 +284,7 @@ function NNlib.gather!(
305284
src::AnyTracedRArray{T2,2},
306285
idxs::Union{AbstractUnitRange{<:Number}},
307286
) where {T1,T2}
308-
dst.mlir_data = src[:, idxs].mlir_data
287+
set_mlir_data!(dst, get_mlir_data(src[:, idxs]))
309288
return dst
310289
end
311290

@@ -314,8 +293,8 @@ function NNlib.gather!(
314293
) where {T1,T2}
315294
dims = NNlib.scatter_dims(src, dst, idxs)
316295
@assert dims == 1 # scatter_dims lets us do some size checks so we call that function
317-
idxs = (Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1).mlir_data
318-
slice_sizes = Reactant.promote_to(TracedRArray{Int,1}, [size(src, 1), 1]).mlir_data
296+
idxs = get_mlir_data(Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1)
297+
slice_sizes = get_mlir_data(Reactant.promote_to(TracedRArray{Int,1}, [size(src, 1), 1]))
319298

320299
#! format: off
321300
dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet(
@@ -331,11 +310,11 @@ function NNlib.gather!(
331310

332311
res = MLIR.IR.result(
333312
Reactant.MLIR.Dialects.stablehlo.dynamic_gather(
334-
src.mlir_data, idxs, slice_sizes; dimension_numbers
313+
get_mlir_data(src), idxs, slice_sizes; dimension_numbers
335314
),
336315
1,
337316
)
338-
dst.mlir_data = res
317+
set_mlir_data!(dst, res)
339318
return dst
340319
end
341320

@@ -354,7 +333,7 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr
354333
return reshape(res, start_sizes..., :)
355334
end
356335
res = reshape(cat(results...; dims=(dims + 1)), size(dst))
357-
dst.mlir_data = res.mlir_data
336+
set_mlir_data!(dst, get_mlir_data(res))
358337
return dst
359338
end
360339

@@ -363,7 +342,7 @@ dilate_shape(s, d) = max(0, 1 + d * (s - 1))
363342
# see lax._conv_general_dilated_transpose_rhs
364343
# https://github.com/jax-ml/jax/blob/a1dfdc1d6164ad49afb337da9effd269d430d68b/jax/_src/lax/convolution.py#L495
365344
function NNlib.∇conv_filter!(
366-
dw::Reactant.TracedRArray{T,N},
345+
dw::TracedRArray{T,N},
367346
x::AnyTracedRArray,
368347
dy::AnyTracedRArray,
369348
cdims::NNlib.DenseConvDims,
@@ -437,8 +416,8 @@ function NNlib.∇conv_filter!(
437416

438417
result_type = Reactant.MLIR.IR.TensorType(size(dw), Reactant.MLIR.IR.Type(T))
439418
conv = MLIR.Dialects.stablehlo.convolution(
440-
x.mlir_data,
441-
dy.mlir_data;
419+
get_mlir_data(x),
420+
get_mlir_data(dy);
442421
result_0=result_type,
443422
window_strides=collect(dilation),
444423
padding,
@@ -447,11 +426,12 @@ function NNlib.∇conv_filter!(
447426
feature_group_count,
448427
batch_group_count,
449428
)
450-
451-
dw.mlir_data = MLIR.IR.result(conv)
429+
set_mlir_data!(dw, MLIR.IR.result(conv))
452430

453431
if !NNlib.flipkernel(cdims)
454-
dw.mlir_data = Reactant.Ops.reverse(dw; dimensions=output_spatial_dims).mlir_data
432+
set_mlir_data!(
433+
dw, get_mlir_data(Reactant.Ops.reverse(dw; dimensions=output_spatial_dims))
434+
)
455435
end
456436

457437
return dw
@@ -553,8 +533,8 @@ function NNlib.∇conv_data!(
553533
end
554534

555535
conv = MLIR.Dialects.stablehlo.convolution(
556-
dy.mlir_data,
557-
w.mlir_data;
536+
get_mlir_data(dy),
537+
get_mlir_data(w);
558538
result_0=result_type,
559539
window_strides=1,
560540
padding,
@@ -564,8 +544,7 @@ function NNlib.∇conv_data!(
564544
feature_group_count,
565545
batch_group_count=1,
566546
)
567-
568-
dx.mlir_data = MLIR.IR.result(conv)
547+
set_mlir_data!(dx, MLIR.IR.result(conv))
569548

570549
return dx
571550
end

0 commit comments

Comments
 (0)