Skip to content

Commit 83ccc37

Browse files
authored
more generic ZScoreTransform, UnitRangeTranform to support CuArrays (#622)
1 parent 854a541 commit 83ccc37

File tree

1 file changed

+27
-44
lines changed

1 file changed

+27
-44
lines changed

src/transformations.jl

Lines changed: 27 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,18 @@ reconstruct(t::AbstractDataTransform, y::AbstractVector{<:Real}) =
4949
"""
5050
Standardization (Z-score transformation)
5151
"""
52-
struct ZScoreTransform{T<:Real} <: AbstractDataTransform
52+
struct ZScoreTransform{T<:Real, U<:AbstractVector{T}} <: AbstractDataTransform
5353
len::Int
5454
dims::Int
55-
mean::Vector{T}
56-
scale::Vector{T}
55+
mean::U
56+
scale::U
5757

58-
function ZScoreTransform(l::Int, dims::Int, m::Vector{T}, s::Vector{T}) where T
58+
function ZScoreTransform(l::Int, dims::Int, m::U, s::U) where {T<:Real, U<:AbstractVector{T}}
5959
lenm = length(m)
6060
lens = length(s)
6161
lenm == l || lenm == 0 || throw(DimensionMismatch("Inconsistent dimensions."))
6262
lens == l || lens == 0 || throw(DimensionMismatch("Inconsistent dimensions."))
63-
new{T}(l, dims, m, s)
63+
new{T, U}(l, dims, m, s)
6464
end
6565
end
6666

@@ -123,9 +123,8 @@ function fit(::Type{ZScoreTransform}, X::AbstractMatrix{<:Real};
123123
else
124124
throw(DomainError(dims, "fit only accept dims to be 1 or 2."))
125125
end
126-
T = eltype(X)
127-
return ZScoreTransform(l, dims, (center ? vec(m) : zeros(T, 0)),
128-
(scale ? vec(s) : zeros(T, 0)))
126+
return ZScoreTransform(l, dims, (center ? vec(m) : similar(m, 0)),
127+
(scale ? vec(s) : similar(s, 0)))
129128
end
130129

131130
function fit(::Type{ZScoreTransform}, X::AbstractVector{<:Real};
@@ -134,10 +133,7 @@ function fit(::Type{ZScoreTransform}, X::AbstractVector{<:Real};
134133
throw(DomainError(dims, "fit only accepts dims=1 over a vector. Try fit(t, x, dims=1)."))
135134
end
136135

137-
T = eltype(X)
138-
m, s = mean_and_std(X)
139-
return ZScoreTransform(1, dims, (center ? [m] : zeros(T, 0)),
140-
(scale ? [s] : zeros(T, 0)))
136+
return fit(ZScoreTransform, reshape(X, :, 1); dims=dims, center=center, scale=scale)
141137
end
142138

143139
function transform!(y::AbstractMatrix{<:Real}, t::ZScoreTransform, x::AbstractMatrix{<:Real})
@@ -207,19 +203,19 @@ end
207203
"""
208204
Unit range normalization
209205
"""
210-
struct UnitRangeTransform{T<:Real} <: AbstractDataTransform
206+
struct UnitRangeTransform{T<:Real, U<:AbstractVector} <: AbstractDataTransform
211207
len::Int
212208
dims::Int
213209
unit::Bool
214-
min::Vector{T}
215-
scale::Vector{T}
210+
min::U
211+
scale::U
216212

217-
function UnitRangeTransform(l::Int, dims::Int, unit::Bool, min::Vector{T}, max::Vector{T}) where {T}
213+
function UnitRangeTransform(l::Int, dims::Int, unit::Bool, min::U, max::U) where {T, U<:AbstractVector{T}}
218214
lenmin = length(min)
219215
lenmax = length(max)
220216
lenmin == l || lenmin == 0 || throw(DimensionMismatch("Inconsistent dimensions."))
221217
lenmax == l || lenmax == 0 || throw(DimensionMismatch("Inconsistent dimensions."))
222-
new{T}(l, dims, unit, min, max)
218+
new{T, U}(l, dims, unit, min, max)
223219
end
224220
end
225221

@@ -270,45 +266,32 @@ function fit(::Type{UnitRangeTransform}, X::AbstractMatrix{<:Real};
270266
Base.depwarn("fit(t, x) is deprecated: use fit(t, x, dims=2) instead", :fit)
271267
dims = 2
272268
end
273-
if dims == 1
274-
l, tmin, tmax = _compute_extrema(X)
275-
elseif dims == 2
276-
l, tmin, tmax = _compute_extrema(X')
277-
else
278-
throw(DomainError(dims, "fit only accept dims to be 1 or 2."))
279-
end
280-
281-
for i = 1:l
282-
@inbounds tmax[i] = 1 / (tmax[i] - tmin[i])
283-
end
269+
dims (1, 2) || throw(DomainError(dims, "fit only accept dims to be 1 or 2."))
270+
tmin, tmax = _compute_extrema(X, dims)
271+
@. tmax = 1 / (tmax - tmin)
272+
l = length(tmin)
284273
return UnitRangeTransform(l, dims, unit, tmin, tmax)
285274
end
286275

287-
function _compute_extrema(X::AbstractMatrix{<:Real})
288-
n, l = size(X)
289-
tmin = X[1, :]
290-
tmax = X[1, :]
291-
for j = 1:l
292-
@inbounds for i = 2:n
293-
if X[i, j] < tmin[j]
294-
tmin[j] = X[i, j]
295-
elseif X[i, j] > tmax[j]
296-
tmax[j] = X[i, j]
297-
end
298-
end
276+
function _compute_extrema(X::AbstractMatrix, dims::Integer)
277+
dims == 2 && return _compute_extrema(X', 1)
278+
l = size(X, 2)
279+
tmin = similar(X, l)
280+
tmax = similar(X, l)
281+
for i in 1:l
282+
@inbounds tmin[i], tmax[i] = extrema(@view(X[:, i]))
299283
end
300-
return l, tmin, tmax
284+
return tmin, tmax
301285
end
302286

303287
function fit(::Type{UnitRangeTransform}, X::AbstractVector{<:Real};
304288
dims::Integer=1, unit::Bool=true)
305289
if dims != 1
306290
throw(DomainError(dims, "fit only accept dims=1 over a vector. Try fit(t, x, dims=1)."))
307291
end
308-
309-
l, tmin, tmax = _compute_extrema(reshape(X, :, 1))
292+
tmin, tmax = extrema(X)
310293
tmax = 1 / (tmax - tmin)
311-
return UnitRangeTransform(1, dims, unit, vec(tmin), vec(tmax))
294+
return UnitRangeTransform(1, dims, unit, [tmin], [tmax])
312295
end
313296

314297
function transform!(y::AbstractMatrix{<:Real}, t::UnitRangeTransform, x::AbstractMatrix{<:Real})

0 commit comments

Comments
 (0)