@@ -49,18 +49,18 @@ reconstruct(t::AbstractDataTransform, y::AbstractVector{<:Real}) =
4949"""
5050Standardization (Z-score transformation)
5151"""
52- struct ZScoreTransform{T<: Real } <: AbstractDataTransform
52+ struct ZScoreTransform{T<: Real , U <: AbstractVector{T} } <: AbstractDataTransform
5353 len:: Int
5454 dims:: Int
55- mean:: Vector{T}
56- scale:: Vector{T}
55+ mean:: U
56+ scale:: U
5757
58- function ZScoreTransform (l:: Int , dims:: Int , m:: Vector{T} , s:: Vector{T} ) where T
58+ function ZScoreTransform (l:: Int , dims:: Int , m:: U , s:: U ) where {T <: Real , U <: AbstractVector{T} }
5959 lenm = length (m)
6060 lens = length (s)
6161 lenm == l || lenm == 0 || throw (DimensionMismatch (" Inconsistent dimensions." ))
6262 lens == l || lens == 0 || throw (DimensionMismatch (" Inconsistent dimensions." ))
63- new {T} (l, dims, m, s)
63+ new {T, U } (l, dims, m, s)
6464 end
6565end
6666
@@ -123,9 +123,8 @@ function fit(::Type{ZScoreTransform}, X::AbstractMatrix{<:Real};
123123 else
124124 throw (DomainError (dims, " fit only accept dims to be 1 or 2." ))
125125 end
126- T = eltype (X)
127- return ZScoreTransform (l, dims, (center ? vec (m) : zeros (T, 0 )),
128- (scale ? vec (s) : zeros (T, 0 )))
126+ return ZScoreTransform (l, dims, (center ? vec (m) : similar (m, 0 )),
127+ (scale ? vec (s) : similar (s, 0 )))
129128end
130129
131130function fit (:: Type{ZScoreTransform} , X:: AbstractVector{<:Real} ;
@@ -134,10 +133,7 @@ function fit(::Type{ZScoreTransform}, X::AbstractVector{<:Real};
134133 throw (DomainError (dims, " fit only accepts dims=1 over a vector. Try fit(t, x, dims=1)." ))
135134 end
136135
137- T = eltype (X)
138- m, s = mean_and_std (X)
139- return ZScoreTransform (1 , dims, (center ? [m] : zeros (T, 0 )),
140- (scale ? [s] : zeros (T, 0 )))
136+ return fit (ZScoreTransform, reshape (X, :, 1 ); dims= dims, center= center, scale= scale)
141137end
142138
143139function transform! (y:: AbstractMatrix{<:Real} , t:: ZScoreTransform , x:: AbstractMatrix{<:Real} )
@@ -207,19 +203,19 @@ end
207203"""
208204Unit range normalization
209205"""
210- struct UnitRangeTransform{T<: Real } <: AbstractDataTransform
206+ struct UnitRangeTransform{T<: Real , U <: AbstractVector } <: AbstractDataTransform
211207 len:: Int
212208 dims:: Int
213209 unit:: Bool
214- min:: Vector{T}
215- scale:: Vector{T}
210+ min:: U
211+ scale:: U
216212
217- function UnitRangeTransform (l:: Int , dims:: Int , unit:: Bool , min:: Vector{T} , max:: Vector{T} ) where {T}
213+ function UnitRangeTransform (l:: Int , dims:: Int , unit:: Bool , min:: U , max:: U ) where {T, U <: AbstractVector{T} }
218214 lenmin = length (min)
219215 lenmax = length (max)
220216 lenmin == l || lenmin == 0 || throw (DimensionMismatch (" Inconsistent dimensions." ))
221217 lenmax == l || lenmax == 0 || throw (DimensionMismatch (" Inconsistent dimensions." ))
222- new {T} (l, dims, unit, min, max)
218+ new {T, U } (l, dims, unit, min, max)
223219 end
224220end
225221
@@ -270,45 +266,32 @@ function fit(::Type{UnitRangeTransform}, X::AbstractMatrix{<:Real};
270266 Base. depwarn (" fit(t, x) is deprecated: use fit(t, x, dims=2) instead" , :fit )
271267 dims = 2
272268 end
273- if dims == 1
274- l, tmin, tmax = _compute_extrema (X)
275- elseif dims == 2
276- l, tmin, tmax = _compute_extrema (X' )
277- else
278- throw (DomainError (dims, " fit only accept dims to be 1 or 2." ))
279- end
280-
281- for i = 1 : l
282- @inbounds tmax[i] = 1 / (tmax[i] - tmin[i])
283- end
269+ dims ∈ (1 , 2 ) || throw (DomainError (dims, " fit only accept dims to be 1 or 2." ))
270+ tmin, tmax = _compute_extrema (X, dims)
271+ @. tmax = 1 / (tmax - tmin)
272+ l = length (tmin)
284273 return UnitRangeTransform (l, dims, unit, tmin, tmax)
285274end
286275
287- function _compute_extrema (X:: AbstractMatrix{<:Real} )
288- n, l = size (X)
289- tmin = X[1 , :]
290- tmax = X[1 , :]
291- for j = 1 : l
292- @inbounds for i = 2 : n
293- if X[i, j] < tmin[j]
294- tmin[j] = X[i, j]
295- elseif X[i, j] > tmax[j]
296- tmax[j] = X[i, j]
297- end
298- end
276+ function _compute_extrema (X:: AbstractMatrix , dims:: Integer )
277+ dims == 2 && return _compute_extrema (X' , 1 )
278+ l = size (X, 2 )
279+ tmin = similar (X, l)
280+ tmax = similar (X, l)
281+ for i in 1 : l
282+ @inbounds tmin[i], tmax[i] = extrema (@view (X[:, i]))
299283 end
300- return l, tmin, tmax
284+ return tmin, tmax
301285end
302286
303287function fit (:: Type{UnitRangeTransform} , X:: AbstractVector{<:Real} ;
304288 dims:: Integer = 1 , unit:: Bool = true )
305289 if dims != 1
306290 throw (DomainError (dims, " fit only accept dims=1 over a vector. Try fit(t, x, dims=1)." ))
307291 end
308-
309- l, tmin, tmax = _compute_extrema (reshape (X, :, 1 ))
292+ tmin, tmax = extrema (X)
310293 tmax = 1 / (tmax - tmin)
311- return UnitRangeTransform (1 , dims, unit, vec ( tmin), vec ( tmax) )
294+ return UnitRangeTransform (1 , dims, unit, [ tmin], [ tmax] )
312295end
313296
314297function transform! (y:: AbstractMatrix{<:Real} , t:: UnitRangeTransform , x:: AbstractMatrix{<:Real} )
0 commit comments