Skip to content

Commit 52a1bd9

Browse files
authored
Allow Zygote v0.7 (#188)
* Update TullioChainRulesCoreExt.jl * Update Project.toml
1 parent c3cf714 commit 52a1bd9

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Tullio"
22
uuid = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
33
authors = ["Michael Abbott"]
4-
version = "0.3.8"
4+
version = "0.3.9"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -29,13 +29,13 @@ FillArrays = "0.11, 0.12, 0.13, 1"
2929
ForwardDiff = "0.10, 1.0"
3030
KernelAbstractions = "0.9"
3131
LoopVectorization = "0.12.101"
32-
NamedDims = "0.2"
32+
NamedDims = "0.2, 1"
3333
OffsetArrays = "1"
3434
Requires = "1"
3535
TensorOperations = "4, 5"
3636
Tracker = "0.2"
3737
VectorizationBase = "0.21.23"
38-
Zygote = "0.6.33"
38+
Zygote = "0.6.33, 0.7"
3939
julia = "1.10"
4040

4141
[extras]

ext/TullioChainRulesCoreExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ function ChainRulesCore.rrule(ev::Tullio.Eval, args...)
66
Z = ev.fwd(args...)
77
Z, function tullio_back(Δ)
88
isnothing(ev.rev) && error("no gradient definition here!")
9-
dxs = map(ev.rev(Δ, Z, args...)) do dx
9+
dxs = map(ev.rev(unthunk(Δ), Z, args...)) do dx
1010
dx === nothing ? ChainRulesCore.ZeroTangent() : dx
1111
end
1212
tuple(ChainRulesCore.ZeroTangent(), dxs...)

0 commit comments

Comments
 (0)