Skip to content

Commit 2c3e48d

Browse files
committed
add subtract rule
1 parent b3f9f9b commit 2c3e48d

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

src/rulesets/Base/arraymath.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,16 @@ function rrule(::typeof(-), x::AbstractArray)
434434
return -x, negation_pullback
435435
end
436436

437+
#####
438+
##### Subtraction
439+
#####
440+
441+
frule((_, Δx, Δy), ::typeof(-), x::AbstractArray, y::AbstractArray) = x-y, Δx-Δy
442+
443+
function rrule(::typeof(-), x::AbstractArray, y::AbstractArray)
444+
subtract_pullback(dy) = (NoTangent(), dy, -dy)
445+
return x-y, subtract_pullback
446+
end
437447

438448
#####
439449
##### Addition (Multiarg `+`)

test/rulesets/Base/arraymath.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,4 +217,12 @@
217217
@gpu test_rrule(+, randn(4, 4), randn(4, 4), randn(4, 4))
218218
@gpu test_rrule(+, randn(3), randn(3,1), randn(3,1,1))
219219
end
220+
221+
@testset "subtraction" begin
222+
# fwd
223+
@gpu test_frule(-, randn(2), randn(2))
224+
# rev
225+
@gpu test_rrule(-, randn(4, 4), randn(4, 4))
226+
@gpu test_rrule(-, randn(3), randn(3,1))
227+
end
220228
end

0 commit comments

Comments
 (0)