Skip to content

Commit 699fded

Browse files
committed
feat: also use pretty printing for regular show
1 parent c0da506 commit 699fded

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

src/ExtensionInterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ function Base.show(io::IO, g::ZygoteGradient{F,degree,arg}) where {F,degree,arg}
2323
print(io, g.op)
2424
return nothing
2525
end
26+
Base.show(io::IO, ::MIME"text/plain", g::ZygoteGradient) = show(io, g)
2627

2728
function _zygote_gradient(args...)
2829
return error("Please load the Zygote.jl package.")

test/test_zygote_gradient_wrapper.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
first_partial = _zygote_gradient(log, Val(2), Val(1))
2626
nested = _zygote_gradient(first_partial, Val(1))
2727
@test repr(nested) == "∂∂₁log"
28+
29+
# Also should work with text/plain
30+
@test repr(MIME"text/plain", nested) == "∂∂₁log"
2831
end
2932

3033
@testitem "ZygoteGradient evaluation" begin

0 commit comments

Comments
 (0)