Skip to content

Commit b2c80db

Browse files
authored
use FillArrays.OneElement in basisfunction (#444)
* use FillArrays.OneElement in basisfunction * ensure Float eltype
1 parent 2b7474f commit b2c80db

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

src/onehotvector.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# TODO: remove once FillArrays compat for <v1 is dropped
12
struct OneHotVector{T} <: AbstractVector{T}
23
n :: Int
34
len :: Int
@@ -9,13 +10,18 @@ struct OneHotVector{T} <: AbstractVector{T}
910
end
1011
end
1112
OneHotVector(n, len = n) = OneHotVector{Float64}(n, len)
12-
Base.size(v::OneHotVector) = (v.len,)
13-
Base.length(v::OneHotVector) = v.len
14-
function Base.getindex(v::OneHotVector{T}, i::Int) where {T}
13+
size(v::OneHotVector) = (v.len,)
14+
length(v::OneHotVector) = v.len
15+
function getindex(v::OneHotVector{T}, i::Int) where {T}
1516
i == v.n ? one(T) : zero(T)
1617
end
18+
@static if isdefined(FillArrays, :OneElement)
19+
const _OneElement = FillArrays.OneElement
20+
else
21+
const _OneElement = OneHotVector
22+
end
1723
# assume that the basis label starts at zero
1824
function basisfunction(sp, oneindex)
1925
oneindex >= 0 || throw(ArgumentError("index to set to one must be non-negative, received $oneindex"))
20-
Fun(sp, OneHotVector(oneindex))
26+
Fun(sp, _OneElement{Float64}(oneindex, oneindex))
2127
end

0 commit comments

Comments
 (0)