Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 38 additions & 26 deletions src/multitaper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -436,9 +436,20 @@ struct MTCrossSpectraConfig{T,T1,T2,T3,T4,F,T5,T6,C<:MTConfig{T}}
function MTCrossSpectraConfig{T,T1,T2,T3,T4,F,T5,T6,C}(
n_channels::Int, normalization_weights::T1, x_mt::T2,
demean::Bool, mean_per_channel::T3, demeaned_signal::T4, freq::F, freq_range::T5,
freq_inds::T6, ensure_aligned::Bool, mt_config::C
freq_inds::T6, ensure_aligned::Bool, mt_config::C, override_warn::Bool
) where {T,T1,T2,T3,T4,F,T5,T6,C<:MTConfig{T}}
check_onesided_real(mt_config) # this restriction is artificial; the code needs to be generalized
if n_channels > mt_config.n_samples && !override_warn
Base.depwarn(
"""
n_channels > n_samples; this is likely a mistake.
From v0.9 onwards, each column is interpreted as a separate signal,
whereas in v0.8 and earlier, each row was interpreted as a separate signal.
To suppress this warning, pass `override_warn=true` as a keyword argument to
`MTCoherenceConfig`, `MTCrossSpectraConfig`, `mt_coherence`, or `mt_cross_power_spectra`.
""",
:mt_cross_power_spectra; force=true)
end
return new{T,T1,T2,T3,T4,F,T5,T6,C}(
n_channels, normalization_weights, x_mt,
demean, mean_per_channel, demeaned_signal, freq, freq_range,
Expand All @@ -447,12 +458,12 @@ struct MTCrossSpectraConfig{T,T1,T2,T3,T4,F,T5,T6,C<:MTConfig{T}}
end
function MTCrossSpectraConfig(n_channels::Int, normalization_weights::T1, x_mt::T2,
demean::Bool, mean_per_channel::T3, demeaned_signal::T4, freq::F, freq_range::T5,
freq_inds::T6, ensure_aligned::Bool, mt_config::C
freq_inds::T6, ensure_aligned::Bool, mt_config::C, override_warn::Bool=false
) where {T,T1,T2,T3,T4,F,T5,T6,C<:MTConfig{T}}
MTCrossSpectraConfig{T,T1,T2,T3,T4,F,T5,T6,C}(
n_channels, normalization_weights, x_mt,
demean, mean_per_channel, demeaned_signal, freq, freq_range,
freq_inds, ensure_aligned, mt_config
freq_inds, ensure_aligned, mt_config, override_warn
)
end
end
Expand All @@ -477,22 +488,23 @@ Returns a `CrossPowerSpectra` object.
function MTCrossSpectraConfig{T}(n_channels, n_samples; fs=1, demean=false,
freq_range=nothing,
ensure_aligned = T == Float32 || T == Complex{Float32},
override_warn=false,
kwargs...) where {T}
mt_config = MTConfig{T}(n_samples; fs, kwargs...)
return MTCrossSpectraConfig{T}(n_channels, mt_config; demean, freq_range, ensure_aligned)
return MTCrossSpectraConfig{T}(n_channels, mt_config; demean, freq_range, ensure_aligned, override_warn)
end

# extra method to ensure it's ok to pass the redundant type parameter {T}
MTCrossSpectraConfig{T}(n_channels, mt_config::MTConfig{T}; kwargs...) where {T} =
MTCrossSpectraConfig(n_channels, mt_config; kwargs...)

function MTCrossSpectraConfig(n_channels, mt_config::MTConfig{T}; demean=false,
freq_range=nothing, ensure_aligned = T == Float32 || T == Complex{Float32}) where {T}
function MTCrossSpectraConfig(n_channels, mt_config::MTConfig{T}; demean=false, freq_range=nothing,
ensure_aligned = T == Float32 || T == Complex{Float32}, override_warn=false) where {T}

n_samples = mt_config.n_samples
if demean
mean_per_channel = Vector{T}(undef, n_channels)
demeaned_signal = Matrix{T}(undef, n_channels, n_samples)
mean_per_channel = Matrix{T}(undef, 1, n_channels)
demeaned_signal = Matrix{T}(undef, n_samples, n_channels)
else
mean_per_channel = nothing
demeaned_signal = nothing
Expand All @@ -512,7 +524,7 @@ function MTCrossSpectraConfig(n_channels, mt_config::MTConfig{T}; demean=false,
end
return MTCrossSpectraConfig(n_channels, normalization_weights, x_mt, demean,
mean_per_channel, demeaned_signal, freq,
freq_range, freq_inds, ensure_aligned, mt_config)
freq_range, freq_inds, ensure_aligned, mt_config, override_warn)
end

function allocate_output(config::MTCrossSpectraConfig{T}) where {T}
Expand All @@ -529,7 +541,7 @@ end
Computes multitapered cross power spectra between channels of a signal. Arguments:

* `output`: `n_channels` x `n_channels` x `length(config.freq)`. Can be created by `DSP.allocate_output(config)`.
* `signal`: `n_channels` x `n_samples`
* `signal`: `n_samples` x `n_channels`
* `config`: `MTCrossSpectraConfig{T}`: optionally pass a [`MTCrossSpectraConfig`](@ref) to
preallocate temporary and choose configuration settings.
Otherwise, one may pass any keyword arguments accepted by this object.
Expand All @@ -542,19 +554,18 @@ See also [`mt_cross_power_spectra`](@ref) and [`MTCrossSpectraConfig`](@ref).
mt_cross_power_spectra!

function mt_cross_power_spectra!(output, signal::AbstractMatrix{T}; fs=1, kwargs...) where {T}
n_channels, n_samples = size(signal)
n_samples, n_channels = size(signal)
config = MTCrossSpectraConfig{T}(n_channels, n_samples; fs, fft_flags=FFTW.ESTIMATE,
kwargs...)
return mt_cross_power_spectra!(output, signal, config)
end

@views function mt_cross_power_spectra!(output, signal::AbstractMatrix,
config::MTCrossSpectraConfig)
function mt_cross_power_spectra!(output, signal::AbstractMatrix, config::MTCrossSpectraConfig)
n_chan = config.n_channels
n_samples = config.mt_config.n_samples
n_freqi = length(config.freq_inds)

if size(signal) != (n_chan, n_samples)
if size(signal) != (n_samples, n_chan)
throw(DimensionMismatch(lazy"Size of `signal` does not match `(config.n_channels, config.mt_config.n_samples)`;
got `size(signal)`=$(size(signal)) but `(config.n_channels, config.mt_config.n_samples)`=$((n_chan, n_samples))"))
end
Expand All @@ -576,26 +587,27 @@ end
else
mt_fft_tapered_multichannel!(x_mt, signal, config)
end
x_mt[1, :, :] ./= sqrt(2)
@views x_mt[1, :, :] ./= sqrt(2)
if iseven(config.mt_config.nfft)
x_mt[end, :, :] ./= sqrt(2)
@views x_mt[end, :, :] ./= sqrt(2)
end
cs_inner!(output, config.normalization_weights, x_mt, config)
return CrossPowerSpectra(output, config.freq)
end

function mt_fft_tapered_multichannel_ensure_aligned!(x_mt, signal, config)
fft_output = config.mt_config.fft_output_tmp
for k in 1:(config.n_channels), taper in 1:config.mt_config.ntapers
# we do this in two steps so that we are sure `fft_output` has the memory alignment FFTW expects (without needing the `FFTW.UNALIGNED` flag)
mt_fft_tapered!(fft_output, signal[k, :], taper, config.mt_config)
for k in 1:config.n_channels, taper in 1:config.mt_config.ntapers
# we do this in two steps to ensure `fft_output` has the memory alignment FFTW expects
# (without needing the `FFTW.UNALIGNED` flag)
mt_fft_tapered!(fft_output, view(signal, :, k), taper, config.mt_config)
x_mt[:, taper, k] .= fft_output
end
end

@views function mt_fft_tapered_multichannel!(x_mt, signal, config)
for k in 1:(config.n_channels), taper in 1:config.mt_config.ntapers
mt_fft_tapered!(x_mt[:, taper, k], signal[k, :], taper, config.mt_config)
mt_fft_tapered!(x_mt[:, taper, k], signal[:, k], taper, config.mt_config)
end
end

Expand Down Expand Up @@ -626,7 +638,7 @@ end

Computes multitapered cross power spectra between channels of a signal. Arguments:

* `signal`: `n_channels` x `n_samples`
* `signal`: `n_samples` x `n_channels`
* Optionally pass an [`MTCrossSpectraConfig`](@ref) object to preallocate temporary variables
and choose configuration settings. Otherwise, any keyword arguments accepted by [`MTCrossSpectraConfig`](@ref) may be passed here.

Expand All @@ -638,7 +650,7 @@ See also [`mt_cross_power_spectra!`](@ref) and [`MTCrossSpectraConfig`](@ref).
mt_cross_power_spectra

function mt_cross_power_spectra(signal::AbstractMatrix{T}; fs=1, kwargs...) where {T}
n_channels, n_samples = size(signal)
n_samples, n_channels = size(signal)
config = MTCrossSpectraConfig{T}(n_channels, n_samples; fs, fft_flags=FFTW.ESTIMATE,
kwargs...)
return mt_cross_power_spectra(signal, config)
Expand Down Expand Up @@ -768,7 +780,7 @@ function mt_coherence!(output, signal::AbstractMatrix,
n_samples = config.cs_config.mt_config.n_samples
n_freqs = length(config.cs_config.freq)

if size(signal) != (n_chan, n_samples)
if size(signal) != (n_samples, n_chan)
throw(DimensionMismatch(lazy"Size of `signal` does not match `(config.cs_config.n_channels, config.cs_config.mt_config.n_samples)`;
got `size(signal)`=$(size(signal)) but `(config.cs_config.n_channels, config.cs_config.mt_config.n_samples)`=$((n_chan, n_samples))"))
end
Expand All @@ -783,7 +795,7 @@ function mt_coherence!(output, signal::AbstractMatrix,
end

function mt_coherence!(output, signal::AbstractMatrix{T}; kwargs...) where {T}
n_channels, n_samples = size(signal)
n_samples, n_channels = size(signal)
config = MTCoherenceConfig{T}(n_channels, n_samples; fft_flags=FFTW.ESTIMATE, kwargs...)
return mt_coherence!(output, signal, config)
end
Expand All @@ -795,7 +807,7 @@ end

Arguments:

* `signal`: `n_channels` x `n_samples` matrix
* `signal`: `n_samples` x `n_channels` matrix
* Optionally pass an `MTCoherenceConfig` to pre-allocate temporary variables and choose configuration settings, otherwise, see [`MTCrossSpectraConfig`](@ref) for the meaning of the keyword arguments.

Returns a `Coherence` object.
Expand All @@ -805,7 +817,7 @@ See also [`mt_coherence`](@ref) and [`MTCoherenceConfig`](@ref).
mt_coherence

function mt_coherence(signal::AbstractMatrix{T}; kwargs...) where {T}
n_channels, n_samples = size(signal)
n_samples, n_channels = size(signal)
config = MTCoherenceConfig{T}(n_channels, n_samples; fft_flags=FFTW.ESTIMATE, kwargs...)
return mt_coherence(signal, config)
end
Expand Down
Loading
Loading