Skip to content

Commit dae09ef

Browse files
authored
fix: replace one with oneunit for basis computation (#826)
* fix: replace `one` with `oneunit` for basis computation * Replace one with oneunit everywhere * Changelog * cov
1 parent ea73473 commit dae09ef

File tree

13 files changed

+74
-55
lines changed

13 files changed

+74
-55
lines changed

DifferentiationInterface/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Fixed
11+
12+
- Replace `one` with `oneunit` in basis computation ([#826])
13+
1014
## [0.7.3]
1115

1216
### Fixed
@@ -62,6 +66,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
6266
[0.6.54]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.53...DifferentiationInterface-v0.6.54
6367
[0.6.53]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.52...DifferentiationInterface-v0.6.53
6468

69+
[#826]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/826
6570
[#823]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/823
6671
[#818]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/818
6772
[#812]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/812

DifferentiationInterface/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ Aqua = "0.8.12"
6161
ChainRulesCore = "1.23.0"
6262
ComponentArrays = "0.15.27"
6363
DataFrames = "1.7.0"
64+
Dates = "1"
6465
DiffResults = "1.1.0"
6566
Diffractor = "=0.2.6"
6667
Enzyme = "0.13.39"
@@ -98,6 +99,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
9899
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
99100
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
100101
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
102+
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
101103
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
102104
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
103105
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
@@ -130,6 +132,7 @@ test = [
130132
"Aqua",
131133
"ComponentArrays",
132134
"DataFrames",
135+
"Dates",
133136
"ExplicitImports",
134137
"JET",
135138
"JLArrays",

DifferentiationInterface/docs/src/explanation/operators.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,4 +152,4 @@ For same-point preparation, the same rules hold with two modifications:
152152

153153
!!! warning
154154
These rules hold for the majority of backends, but there are some exceptions.
155-
The most important exception is [ReverseDiff](@ref) and its taping mechanism, which is sensitive to control flow inside the function.
155+
The most important exception is [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) and its taping mechanism, which is sensitive to control flow inside the function.

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -189,27 +189,27 @@ end
189189
function DI.value_and_derivative(
190190
f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}
191191
) where {F,C}
192-
y, ty = DI.value_and_pushforward(f, backend, x, (one(x),), contexts...)
192+
y, ty = DI.value_and_pushforward(f, backend, x, (oneunit(x),), contexts...)
193193
return y, only(ty)
194194
end
195195

196196
function DI.value_and_derivative!(
197197
f::F, der, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}
198198
) where {F,C}
199-
y, _ = DI.value_and_pushforward!(f, (der,), backend, x, (one(x),), contexts...)
199+
y, _ = DI.value_and_pushforward!(f, (der,), backend, x, (oneunit(x),), contexts...)
200200
return y, der
201201
end
202202

203203
function DI.derivative(
204204
f::F, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}
205205
) where {F,C}
206-
return only(DI.pushforward(f, backend, x, (one(x),), contexts...))
206+
return only(DI.pushforward(f, backend, x, (oneunit(x),), contexts...))
207207
end
208208

209209
function DI.derivative!(
210210
f::F, der, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context,C}
211211
) where {F,C}
212-
DI.pushforward!(f, (der,), backend, x, (one(x),), contexts...)
212+
DI.pushforward!(f, (der,), backend, x, (oneunit(x),), contexts...)
213213
return der
214214
end
215215

@@ -220,7 +220,7 @@ function DI.prepare_derivative_nokwarg(
220220
) where {F,C}
221221
_sig = DI.signature(f, backend, x, contexts...; strict)
222222
pushforward_prep = DI.prepare_pushforward_nokwarg(
223-
strict, f, backend, x, (one(x),), contexts...
223+
strict, f, backend, x, (oneunit(x),), contexts...
224224
)
225225
return ForwardDiffOneArgDerivativePrep(_sig, pushforward_prep)
226226
end
@@ -234,7 +234,7 @@ function DI.value_and_derivative(
234234
) where {F,C}
235235
DI.check_prep(f, prep, backend, x, contexts...)
236236
y, ty = DI.value_and_pushforward(
237-
f, prep.pushforward_prep, backend, x, (one(x),), contexts...
237+
f, prep.pushforward_prep, backend, x, (oneunit(x),), contexts...
238238
)
239239
return y, only(ty)
240240
end
@@ -249,7 +249,7 @@ function DI.value_and_derivative!(
249249
) where {F,C}
250250
DI.check_prep(f, prep, backend, x, contexts...)
251251
y, _ = DI.value_and_pushforward!(
252-
f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...
252+
f, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts...
253253
)
254254
return y, der
255255
end
@@ -263,7 +263,7 @@ function DI.derivative(
263263
) where {F,C}
264264
DI.check_prep(f, prep, backend, x, contexts...)
265265
return only(
266-
DI.pushforward(f, prep.pushforward_prep, backend, x, (one(x),), contexts...)
266+
DI.pushforward(f, prep.pushforward_prep, backend, x, (oneunit(x),), contexts...)
267267
)
268268
end
269269

@@ -276,7 +276,9 @@ function DI.derivative!(
276276
contexts::Vararg{DI.Context,C},
277277
) where {F,C}
278278
DI.check_prep(f, prep, backend, x, contexts...)
279-
DI.pushforward!(f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...)
279+
DI.pushforward!(
280+
f, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts...
281+
)
280282
return der
281283
end
282284

@@ -638,9 +640,9 @@ function DI.second_derivative(
638640
) where {F,C}
639641
DI.check_prep(f, prep, backend, x, contexts...)
640642
T = tag_type(f, backend, x)
641-
xdual = make_dual(T, x, one(x))
643+
xdual = make_dual(T, x, oneunit(x))
642644
T2 = tag_type(f, backend, xdual)
643-
xdual2 = make_dual(T2, xdual, one(xdual))
645+
xdual2 = make_dual(T2, xdual, oneunit(xdual))
644646
contexts_dual = translate(typeof(xdual2), contexts)
645647
ydual = f(xdual2, contexts_dual...)
646648
return myderivative(T, myderivative(T2, ydual))
@@ -656,9 +658,9 @@ function DI.second_derivative!(
656658
) where {F,C}
657659
DI.check_prep(f, prep, backend, x, contexts...)
658660
T = tag_type(f, backend, x)
659-
xdual = make_dual(T, x, one(x))
661+
xdual = make_dual(T, x, oneunit(x))
660662
T2 = tag_type(f, backend, xdual)
661-
xdual2 = make_dual(T2, xdual, one(xdual))
663+
xdual2 = make_dual(T2, xdual, oneunit(xdual))
662664
contexts_dual = translate(typeof(xdual2), contexts)
663665
ydual = f(xdual2, contexts_dual...)
664666
return myderivative!(T, der2, myderivative(T2, ydual))
@@ -673,9 +675,9 @@ function DI.value_derivative_and_second_derivative(
673675
) where {F,C}
674676
DI.check_prep(f, prep, backend, x, contexts...)
675677
T = tag_type(f, backend, x)
676-
xdual = make_dual(T, x, one(x))
678+
xdual = make_dual(T, x, oneunit(x))
677679
T2 = tag_type(f, backend, xdual)
678-
xdual2 = make_dual(T2, xdual, one(xdual))
680+
xdual2 = make_dual(T2, xdual, oneunit(xdual))
679681
contexts_dual = translate(typeof(xdual2), contexts)
680682
ydual = f(xdual2, contexts_dual...)
681683
y = myvalue(T, myvalue(T2, ydual))
@@ -695,9 +697,9 @@ function DI.value_derivative_and_second_derivative!(
695697
) where {F,C}
696698
DI.check_prep(f, prep, backend, x, contexts...)
697699
T = tag_type(f, backend, x)
698-
xdual = make_dual(T, x, one(x))
700+
xdual = make_dual(T, x, oneunit(x))
699701
T2 = tag_type(f, backend, xdual)
700-
xdual2 = make_dual(T2, xdual, one(xdual))
702+
xdual2 = make_dual(T2, xdual, oneunit(xdual))
701703
contexts_dual = translate(typeof(xdual2), contexts)
702704
ydual = f(xdual2, contexts_dual...)
703705
y = myvalue(T, myvalue(T2, ydual))
@@ -756,7 +758,7 @@ function DI.value_gradient_and_hessian!(
756758
contexts isa NTuple{C,DI.GeneralizedConstant}
757759
)
758760
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
759-
result = DiffResult(one(eltype(x)), (grad, hess))
761+
result = DiffResult(oneunit(eltype(x)), (grad, hess))
760762
result = hessian!(result, fc, x)
761763
y = DR.value(result)
762764
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
@@ -855,7 +857,7 @@ function DI.value_gradient_and_hessian!(
855857
DI.check_prep(f, prep, backend, x, contexts...)
856858
contexts_dual = translate_prepared(contexts, prep.contexts_dual)
857859
fc = DI.fix_tail(f, contexts_dual...)
858-
result = DiffResult(one(eltype(x)), (grad, hess))
860+
result = DiffResult(oneunit(eltype(x)), (grad, hess))
859861
CHK = tag_type(backend) === Nothing
860862
if CHK
861863
checktag(prep.result_config, f, x)

DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ using GPUArraysCore: @allowscalar, AbstractGPUArray
66
function DI.basis(a::AbstractGPUArray{T}, i) where {T}
77
b = similar(a)
88
fill!(b, zero(T))
9-
@allowscalar b[i] = one(T)
9+
@allowscalar b[i] = oneunit(T)
1010
return b
1111
end
1212

1313
function DI.multibasis(a::AbstractGPUArray{T}, inds) where {T}
1414
b = similar(a)
1515
fill!(b, zero(T))
16-
view(b, inds) .= one(T)
16+
view(b, inds) .= oneunit(T)
1717
return b
1818
end
1919

DifferentiationInterface/src/first_order/derivative.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ function prepare_derivative_nokwarg(
143143
) where {F,C}
144144
_sig = signature(f, backend, x, contexts...; strict)
145145
pushforward_prep = prepare_pushforward_nokwarg(
146-
strict, f, backend, x, (one(x),), contexts...
146+
strict, f, backend, x, (oneunit(x),), contexts...
147147
)
148148
return PushforwardDerivativePrep(_sig, pushforward_prep)
149149
end
@@ -153,7 +153,7 @@ function prepare_derivative_nokwarg(
153153
) where {F,C}
154154
_sig = signature(f!, y, backend, x, contexts...; strict)
155155
pushforward_prep = prepare_pushforward_nokwarg(
156-
strict, f!, y, backend, x, (one(x),), contexts...
156+
strict, f!, y, backend, x, (oneunit(x),), contexts...
157157
)
158158
return PushforwardDerivativePrep(_sig, pushforward_prep)
159159
end
@@ -169,7 +169,7 @@ function value_and_derivative(
169169
) where {F,C}
170170
check_prep(f, prep, backend, x, contexts...)
171171
y, ty = value_and_pushforward(
172-
f, prep.pushforward_prep, backend, x, (one(x),), contexts...
172+
f, prep.pushforward_prep, backend, x, (oneunit(x),), contexts...
173173
)
174174
return y, only(ty)
175175
end
@@ -184,7 +184,7 @@ function value_and_derivative!(
184184
) where {F,C}
185185
check_prep(f, prep, backend, x, contexts...)
186186
y, _ = value_and_pushforward!(
187-
f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...
187+
f, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts...
188188
)
189189
return y, der
190190
end
@@ -197,7 +197,7 @@ function derivative(
197197
contexts::Vararg{Context,C},
198198
) where {F,C}
199199
check_prep(f, prep, backend, x, contexts...)
200-
ty = pushforward(f, prep.pushforward_prep, backend, x, (one(x),), contexts...)
200+
ty = pushforward(f, prep.pushforward_prep, backend, x, (oneunit(x),), contexts...)
201201
return only(ty)
202202
end
203203

@@ -210,7 +210,7 @@ function derivative!(
210210
contexts::Vararg{Context,C},
211211
) where {F,C}
212212
check_prep(f, prep, backend, x, contexts...)
213-
pushforward!(f, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...)
213+
pushforward!(f, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts...)
214214
return der
215215
end
216216

@@ -226,7 +226,7 @@ function value_and_derivative(
226226
) where {F,C}
227227
check_prep(f!, y, prep, backend, x, contexts...)
228228
y, ty = value_and_pushforward(
229-
f!, y, prep.pushforward_prep, backend, x, (one(x),), contexts...
229+
f!, y, prep.pushforward_prep, backend, x, (oneunit(x),), contexts...
230230
)
231231
return y, only(ty)
232232
end
@@ -242,7 +242,7 @@ function value_and_derivative!(
242242
) where {F,C}
243243
check_prep(f!, y, prep, backend, x, contexts...)
244244
y, _ = value_and_pushforward!(
245-
f!, y, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...
245+
f!, y, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts...
246246
)
247247
return y, der
248248
end
@@ -256,7 +256,7 @@ function derivative(
256256
contexts::Vararg{Context,C},
257257
) where {F,C}
258258
check_prep(f!, y, prep, backend, x, contexts...)
259-
ty = pushforward(f!, y, prep.pushforward_prep, backend, x, (one(x),), contexts...)
259+
ty = pushforward(f!, y, prep.pushforward_prep, backend, x, (oneunit(x),), contexts...)
260260
return only(ty)
261261
end
262262

@@ -270,7 +270,9 @@ function derivative!(
270270
contexts::Vararg{Context,C},
271271
) where {F,C}
272272
check_prep(f!, y, prep, backend, x, contexts...)
273-
pushforward!(f!, y, (der,), prep.pushforward_prep, backend, x, (one(x),), contexts...)
273+
pushforward!(
274+
f!, y, (der,), prep.pushforward_prep, backend, x, (oneunit(x),), contexts...
275+
)
274276
return der
275277
end
276278

DifferentiationInterface/src/first_order/gradient.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ function prepare_gradient_nokwarg(
9191
_sig = signature(f, backend, x, contexts...; strict)
9292
y = f(x, map(unwrap, contexts)...) # TODO: replace with output type inference?
9393
pullback_prep = prepare_pullback_nokwarg(
94-
strict, f, backend, x, (one(typeof(y)),), contexts...
94+
strict, f, backend, x, (oneunit(typeof(y)),), contexts...
9595
)
9696
return PullbackGradientPrep(_sig, y, pullback_prep)
9797
end
@@ -106,7 +106,9 @@ function value_and_gradient(
106106
contexts::Vararg{Context,C},
107107
) where {F,SIG,Y,C}
108108
check_prep(f, prep, backend, x, contexts...)
109-
y, tx = value_and_pullback(f, prep.pullback_prep, backend, x, (one(Y),), contexts...)
109+
y, tx = value_and_pullback(
110+
f, prep.pullback_prep, backend, x, (oneunit(Y),), contexts...
111+
)
110112
return y, only(tx)
111113
end
112114

@@ -120,7 +122,7 @@ function value_and_gradient!(
120122
) where {F,SIG,Y,C}
121123
check_prep(f, prep, backend, x, contexts...)
122124
y, _ = value_and_pullback!(
123-
f, (grad,), prep.pullback_prep, backend, x, (one(Y),), contexts...
125+
f, (grad,), prep.pullback_prep, backend, x, (oneunit(Y),), contexts...
124126
)
125127
return y, grad
126128
end
@@ -133,7 +135,7 @@ function gradient(
133135
contexts::Vararg{Context,C},
134136
) where {F,SIG,Y,C}
135137
check_prep(f, prep, backend, x, contexts...)
136-
tx = pullback(f, prep.pullback_prep, backend, x, (one(Y),), contexts...)
138+
tx = pullback(f, prep.pullback_prep, backend, x, (oneunit(Y),), contexts...)
137139
return only(tx)
138140
end
139141

@@ -146,7 +148,7 @@ function gradient!(
146148
contexts::Vararg{Context,C},
147149
) where {F,SIG,Y,C}
148150
check_prep(f, prep, backend, x, contexts...)
149-
pullback!(f, (grad,), prep.pullback_prep, backend, x, (one(Y),), contexts...)
151+
pullback!(f, (grad,), prep.pullback_prep, backend, x, (oneunit(Y),), contexts...)
150152
return grad
151153
end
152154

0 commit comments

Comments
 (0)