|
157 | 157 | end
|
158 | 158 | end
|
159 | 159 |
|
160 |
| -@generated function ifelse( |
161 |
| - m::AbstractMask, |
| 160 | +@generated function _ifelse( |
| 161 | + m::Union{AbstractMask,VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}}, |
162 | 162 | x::ForwardDiff.Dual{TAG,V,P},
|
163 | 163 | y::ForwardDiff.Dual{TAG,V,P}
|
164 | 164 | ) where {TAG,V,P}
|
|
171 | 171 | ForwardDiff.Dual{$TAG}(z, ForwardDiff.Partials(p))
|
172 | 172 | end
|
173 | 173 | end
|
174 |
| -@generated function ifelse( |
175 |
| - m::AbstractMask, |
| 174 | +@generated function _ifelse( |
| 175 | + m::Union{AbstractMask,VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}}, |
176 | 176 | x::Number,
|
177 | 177 | y::ForwardDiff.Dual{TAG,V,P}
|
178 | 178 | ) where {TAG,V,P}
|
|
184 | 184 | ForwardDiff.Dual{$TAG}(z, ForwardDiff.Partials(p))
|
185 | 185 | end
|
186 | 186 | end
|
187 |
| -@generated function ifelse( |
188 |
| - m::AbstractMask, |
| 187 | +@generated function _ifelse( |
| 188 | + m::Union{AbstractMask,VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}}, |
189 | 189 | x::ForwardDiff.Dual{TAG,V,P},
|
190 | 190 | y::Number
|
191 | 191 | ) where {TAG,V,P}
|
|
197 | 197 | ForwardDiff.Dual{$TAG}(z, ForwardDiff.Partials(p))
|
198 | 198 | end
|
199 | 199 | end
|
| 200 | +@inline ifelse(m::AbstractMask, x::ForwardDiff.Dual, y::Number) = _ifelse(m, x, y) |
| 201 | +@inline ifelse(m::AbstractMask, x::ForwardDiff.Dual, y::ForwardDiff.Dual) = _ifelse(m, x, y) |
| 202 | +@inline ifelse(m::AbstractMask, y::Number, x::ForwardDiff.Dual) = _ifelse(m, y, x) |
| 203 | + |
| 204 | +@inline ifelse(m::VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}, x::ForwardDiff.Dual, y::Number) = _ifelse(m, x, y) |
| 205 | +@inline ifelse(m::VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}, x::ForwardDiff.Dual, y::ForwardDiff.Dual) = _ifelse(m, x, y) |
| 206 | +@inline ifelse(m::VecUnroll{<:Any,<:Any,Bit,<:AbstractMask}, y::Number, x::ForwardDiff.Dual) = _ifelse(m, y, x) |
| 207 | + |
200 | 208 | @inline function SLEEFPirates.softplus(x::ForwardDiff.Dual{TAG}) where {TAG}
|
201 | 209 | val = ForwardDiff.value(x)
|
202 | 210 | expx = exp(val)
|
|
0 commit comments