diff --git a/Project.toml b/Project.toml index 4f19c399..aba30825 100644 --- a/Project.toml +++ b/Project.toml @@ -26,6 +26,7 @@ Sockets = "6462fe0b-24de-5631-8697-dd941f90decc" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StringEncodings = "69024149-9ee7-55f6-a4c4-859efe599b68" TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ThreadPools = "b189fb0b-2eb5-4ed4-bc0c-d34c51242431" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" diff --git a/example/CompareSaveMethods.jl b/example/CompareSaveMethods.jl new file mode 100644 index 00000000..9d38c032 --- /dev/null +++ b/example/CompareSaveMethods.jl @@ -0,0 +1,66 @@ +using MPIMeasurements +using FFTW +using Statistics + +println("="^70) +println("MPI Data Storage Methods Comparison") +println("="^70) + +config = (numSamples=1632, numChannels=3, numPeriods=500, numFrames=100, + numPeriodGrouping=5, selectedFrequencies=50) +testData = randn(Float32, config.numSamples, config.numChannels, config.numPeriods, config.numFrames) + +println("\nMethod 1: Traditional (Full Time Domain)") +println("-"^70) +traditionalSize = sizeof(testData) +println("Storage: $(round(traditionalSize / 1024^2, digits=2)) MB") +println("Post-processing: Period grouping, FFT, frequency filtering needed") + +println("\nMethod 2: Frequency Filtered During Acquisition") +println("-"^70) +mutable struct FilteredStorage <: StorageBuffer + frequencyData::Vector{Any} +end +FilteredStorage() = FilteredStorage(Any[]) +Base.push!(buffer::FilteredStorage, data) = (push!(buffer.frequencyData, data); (start=1, stop=size(data, 4))) +MPIMeasurements.sinks!(buffer::FilteredStorage, sinks::Vector{SinkBuffer}) = sinks + +filteredBuffer = FilteredStorage() +frequencyIndices = collect(1:config.selectedFrequencies) +rfftBuffer = RFFTBuffer(filteredBuffer, frequencyIndices) +periodGroupingBuffer = PeriodGroupingBuffer(rfftBuffer, config.numPeriodGrouping) +push!(periodGroupingBuffer, testData) + +filteredSize = sizeof(filteredBuffer.frequencyData[1]) +println("Storage: $(round(filteredSize / 1024^2, digits=2)) MB") +println("Post-processing: None, ready for reconstruction") + +savings = traditionalSize - filteredSize +savingsPercent = (1 - filteredSize / traditionalSize) * 100 +compressionRatio = traditionalSize / filteredSize + +println("\n"*"="^70) +println("Analysis") +println("="^70) +println("Space saved: $(round(savings / 1024^2, digits=2)) MB ($(round(savingsPercent, digits=2))%)") +println("Compression ratio: $(round(compressionRatio, digits=1))×") + +positionsInScan = 3000 +totalFrames = config.numFrames * positionsInScan +traditionalTotalSize = traditionalSize / config.numFrames * totalFrames +filteredTotalSize = filteredSize / config.numFrames * totalFrames +totalSavings = traditionalTotalSize - filteredTotalSize + +println("\nFull 3D Scan Projection ($positionsInScan positions):") +println(" Traditional: $(round(traditionalTotalSize / 1024^3, digits=2)) GB") +println(" Filtered: $(round(filteredTotalSize / 1024^3, digits=2)) GB") +println(" Savings: $(round(totalSavings / 1024^3, digits=2)) GB") + +archiveUploadSpeed = 100 * 1024^2 +traditionalUploadTime = traditionalTotalSize / archiveUploadSpeed +filteredUploadTime = filteredTotalSize / archiveUploadSpeed +println("\nNetwork Transfer (100 MB/s):") +println(" Traditional: $(round(traditionalUploadTime / 60, digits=1)) min") +println(" Filtered: $(round(filteredUploadTime / 60, digits=1)) min") +println(" Time saved: $(round((traditionalUploadTime - filteredUploadTime) / 60, digits=1)) min") +println("="^70) diff --git a/example/FrequencyFilteringDemo.jl b/example/FrequencyFilteringDemo.jl new file mode 100644 index 00000000..96374449 --- /dev/null +++ b/example/FrequencyFilteringDemo.jl @@ -0,0 +1,101 @@ +using MPIMeasurements +using FFTW +using Statistics + +println("="^70) +println("Frequency Filtering Demonstration") +println("="^70) + +numSamples, numChannels, numPeriods, numFrames = 1632, 3, 500, 100 +numPeriodGrouping = 5 + +function createSyntheticMPIData(samples, channels, periods, frames) + data = zeros(Float32, samples, channels, periods, frames) + for freq in [1, 2, 3, 5, 7, 11] + amplitude = 1.0 / freq + for t in 1:samples + data[t, :, :, :] .+= amplitude * sin(2π * freq * (t-1) / samples) + end + end + data .+= 0.1f0 * randn(Float32, size(data)...) + return data +end + +originalData = createSyntheticMPIData(numSamples, numChannels, numPeriods, numFrames) +originalSize = sizeof(originalData) +println("\nOriginal: $(size(originalData)), $(round(originalSize / 1024^2, digits=2)) MB") + +println("\nApplying frequency filtering...") + +selectedFrequencies = 1:50 +numFrequenciesAfterGrouping = div(numSamples * numPeriodGrouping, 2) + 1 +println(" Selected: $(length(selectedFrequencies)) / $numFrequenciesAfterGrouping frequencies") + +mutable struct CaptureBuffer <: StorageBuffer + data::Vector{Any} +end +CaptureBuffer() = CaptureBuffer(Any[]) +Base.push!(buffer::CaptureBuffer, data) = (push!(buffer.data, data); (start=1, stop=size(data, 4))) +MPIMeasurements.sinks!(buffer::CaptureBuffer, sinks::Vector{SinkBuffer}) = sinks + +captureBuffer = CaptureBuffer() +rfftBuffer = RFFTBuffer(captureBuffer, collect(selectedFrequencies)) +periodGroupingBuffer = PeriodGroupingBuffer(rfftBuffer, numPeriodGrouping) +push!(periodGroupingBuffer, originalData) + +filteredData = captureBuffer.data[1] +filteredSize = sizeof(filteredData) + +println("\nFiltered: $(size(filteredData)), $(round(filteredSize / 1024^2, digits=2)) MB") +reductionPercent = (1 - filteredSize / originalSize) * 100 +compressionRatio = originalSize / filteredSize + +println("\n"*"="^70) +println("Results") +println("="^70) +println("Data reduction: $(round(reductionPercent, digits=2))%") +println("Compression ratio: $(round(compressionRatio, digits=1))×") +println("Storage saved: $(round((originalSize - filteredSize) / 1024^2, digits=2)) MB") + +fullSpectrum = abs.(rfft(originalData[:, 1, 1, 1])) +filteredSpectrum = abs.(filteredData[:, 1, 1, 1]) +println("\nSpectrum analysis:") +println(" Full length: $(length(fullSpectrum)), strongest at $(argmax(fullSpectrum))") +println(" Filtered length: $(length(filteredSpectrum)), strongest at $(argmax(filteredSpectr um))") +println("="^70) + +println(" Filtered data shape: $(size(filteredData))") +println(" Filtered data size: $(round(filteredSize / 1024^2, digits=2)) MB") +println(" Data type: $(eltype(filteredData))") + +# ============================================================================ +# Part 3: Calculate Savings +# ============================================================================ + +println("\n[3/4] Analyzing storage efficiency...") + +reductionPercent = (1 - filteredSize / originalSize) * 100 +compressionRatio = originalSize / filteredSize + +println(" ✓ Data reduction: $(round(reductionPercent, digits=2))%") +println(" ✓ Compression ratio: $(round(compressionRatio, digits=1))×") +println(" ✓ Storage saved: $(round((originalSize - filteredSize) / 1024^2, digits=2)) MB") + +# ============================================================================ +# Part 4: Frequency Spectrum Analysis +# ============================================================================ + +println("\n[4/4] Frequency spectrum analysis...") + +# Compute full spectrum for comparison +fullSpectrum = rfft(originalData[:, 1, 1, 1]) # First channel, first period, first frame +fullSpectrumMagnitude = abs.(fullSpectrum) + +# Filtered spectrum (what we actually store) +filteredSpectrum = filteredData[:, 1, 1, 1] # Same selection +filteredSpectrumMagnitude = abs.(filteredSpectrum) + +println(" Full spectrum length: $(length(fullSpectrum))") +println(" Filtered spectrum length: $(length(filteredSpectrum))") +println(" Strongest at: $(argmax(filteredSpectrum))") +println("="^70) diff --git a/example/VisualizeFilteringPipeline.jl b/example/VisualizeFilteringPipeline.jl new file mode 100644 index 00000000..675309ce --- /dev/null +++ b/example/VisualizeFilteringPipeline.jl @@ -0,0 +1,73 @@ +using MPIMeasurements +using FFTW +using Statistics + +println("="^70) +println("Frequency Filtering Pipeline Visualization") +println("="^70) + +numSamples, numChannels, numPeriods, numFrames = 64, 3, 12, 4 +numPeriodGrouping, selectedFrequencies = 3, [1, 3, 5, 7, 9, 11, 13, 15] + +t = range(0, 2π, length=numSamples) +signal = sin.(t .* 3) .+ 0.5 .* sin.(t .* 7) .+ 0.2 .* randn(numSamples) +inputData = zeros(Float32, numSamples, numChannels, numPeriods, numFrames) +for c in 1:numChannels, p in 1:numPeriods, f in 1:numFrames + inputData[:, c, p, f] .= signal .* (1 + 0.1 * randn()) +end + +println("\n"*"="^70) +println("Stage 1: Time Domain Input") +println("="^70) +println("Shape: $(size(inputData))") +println("Memory: $(sizeof(inputData)) bytes") +println("Type: $(eltype(inputData))") + +println("\n"*"="^70) +println("Stage 2: Period Grouping (×$numPeriodGrouping)") +println("="^70) +tmp = permutedims(inputData, (1, 3, 2, 4)) +tmp2 = reshape(tmp, numSamples * numPeriodGrouping, div(numPeriods, numPeriodGrouping), numChannels, numFrames) +groupedData = permutedims(tmp2, (1, 3, 2, 4)) +println("Shape: $(size(inputData)) → $(size(groupedData))") +println("Memory: $(sizeof(groupedData)) bytes") + +println("\n"*"="^70) +println("Stage 3: Real FFT") +println("="^70) +fftData = rfft(groupedData, 1) +numFrequencies = size(fftData, 1) +println("Shape: $(size(groupedData)) → $(size(fftData))") +println("Memory: $(sizeof(fftData)) bytes") +println("Type: $(eltype(fftData))") + +spectrum = abs.(fftData[:, 1, 1, 1]) +topFreqs = sortperm(spectrum, rev=true)[1:3] +println("Top frequencies: $(topFreqs)") + +println("\n"*"="^70) +println("Stage 4: Frequency Selection") +println("="^70) +filteredData = fftData[selectedFrequencies, :, :, :] +println("Shape: $(size(fftData)) → $(size(filteredData))") +println("Memory: $(sizeof(fftData)) → $(sizeof(filteredData)) bytes") +println("Reduction: $(round((1 - sizeof(filteredData) / sizeof(fftData)) * 100, digits=1))%") + +println("\n"*"="^70) +println("Stage 5: MDF Storage") +println("="^70) +println("Shape: $(size(filteredData))") +println("Type: $(eltype(filteredData))") +println("MDF metadata: isFourierTransformed=true, isFrequencySelection=true") + +totalReduction = (1 - sizeof(filteredData) / sizeof(inputData)) * 100 +compressionRatio = sizeof(inputData) / sizeof(filteredData) + +println("\n"*"="^70) +println("Pipeline Summary") +println("="^70) +println("Original: $(sizeof(inputData)) bytes") +println("Filtered: $(sizeof(filteredData)) bytes") +println("Reduction: $(round(totalReduction, digits=2))%") +println("Compression: $(round(compressionRatio, digits=1))×") +println("="^70) diff --git a/src/Protocols/Storage/ChainableBuffer.jl b/src/Protocols/Storage/ChainableBuffer.jl index 4f9e34a8..6c838a2e 100644 --- a/src/Protocols/Storage/ChainableBuffer.jl +++ b/src/Protocols/Storage/ChainableBuffer.jl @@ -1,3 +1,6 @@ +# Export new frequency filtering buffers +export PeriodGroupingBuffer, RFFTBuffer + mutable struct AverageBuffer{T} <: IntermediateBuffer where {T<:Number} target::StorageBuffer buffer::Array{T,4} @@ -221,4 +224,45 @@ function TxDAQControllerBuffer(tx::TxDAQController, sequence::ControlSequence) end update!(buffer::TxDAQControllerBuffer, start, stop) = insert!(buffer, calcControlMatrix(buffer.tx.cont), start, stop) insert!(buffer::TxDAQControllerBuffer, applied::Matrix{ComplexF64}, start, stop) = buffer.applied[:, :, :, start:stop] .= applied -read(buffer::TxDAQControllerBuffer) = buffer.applied \ No newline at end of file +read(buffer::TxDAQControllerBuffer) = buffer.applied + +mutable struct PeriodGroupingBuffer{T} <: IntermediateBuffer where {T<:Number} + target::StorageBuffer + numGrouping::Int +end +PeriodGroupingBuffer(buffer::StorageBuffer, numGrouping::Int) = PeriodGroupingBuffer{Float32}(buffer, numGrouping) + +function push!(buffer::PeriodGroupingBuffer{T}, frames::AbstractArray{T,4}) where {T<:Number} + if buffer.numGrouping == 1 + return push!(buffer.target, frames) + end + + numSamples, numChannels, numPeriods, numFrames = size(frames) + + if mod(numPeriods, buffer.numGrouping) != 0 + error("Periods cannot be grouped: $numPeriods periods cannot be divided by $(buffer.numGrouping)") + end + + tmp = permutedims(frames, (1, 3, 2, 4)) + newNumPeriods = div(numPeriods, buffer.numGrouping) + tmp2 = reshape(tmp, numSamples * buffer.numGrouping, newNumPeriods, numChannels, numFrames) + result = permutedims(tmp2, (1, 3, 2, 4)) + + return push!(buffer.target, result) +end + +sinks!(buffer::PeriodGroupingBuffer, sinks::Vector{SinkBuffer}) = sinks!(buffer.target, sinks) + +mutable struct RFFTBuffer{T} <: IntermediateBuffer where {T<:Complex} + target::StorageBuffer + frequencyMask::Union{Vector{Int}, Nothing} +end +RFFTBuffer(buffer::StorageBuffer, frequencyMask::Union{Vector{Int}, Nothing} = nothing) = RFFTBuffer{ComplexF32}(buffer, frequencyMask) + +function push!(buffer::RFFTBuffer{T}, frames::AbstractArray{<:Real,4}) where {T<:Complex} + dataFD = rfft(frames, 1) + result = isnothing(buffer.frequencyMask) ? dataFD : dataFD[buffer.frequencyMask, :, :, :] + return push!(buffer.target, result) +end + +sinks!(buffer::RFFTBuffer, sinks::Vector{SinkBuffer}) = sinks!(buffer.target, sinks) \ No newline at end of file diff --git a/src/Protocols/Storage/MDF.jl b/src/Protocols/Storage/MDF.jl index c3b2271a..256d6b35 100644 --- a/src/Protocols/Storage/MDF.jl +++ b/src/Protocols/Storage/MDF.jl @@ -1,5 +1,5 @@ -function MPIFiles.saveasMDF(store::DatasetStore, scanner::MPIScanner, sequence::Sequence, data::Array{Float32,4}, isBackgroundFrame::Vector{Bool}, mdf::MDFv2InMemory;temperatures::Union{Array{Float32}, Nothing}=nothing, drivefield::Union{Array{ComplexF64}, Nothing}=nothing, applied::Union{Array{ComplexF64}, Nothing}=nothing) +function MPIFiles.saveasMDF(store::DatasetStore, scanner::MPIScanner, sequence::Sequence, data::Union{Array{Float32,4}, Array{ComplexF32,4}}, isBackgroundFrame::Vector{Bool}, mdf::MDFv2InMemory;temperatures::Union{Array{Float32}, Nothing}=nothing, drivefield::Union{Array{ComplexF64}, Nothing}=nothing, applied::Union{Array{ComplexF64}, Nothing}=nothing, frequencies::Union{Vector{Int}, Nothing}=nothing, isFourierTransformed::Bool=false) if !ismissing(studyName(mdf)) name = studyName(mdf) else @@ -19,7 +19,7 @@ function MPIFiles.saveasMDF(store::DatasetStore, scanner::MPIScanner, sequence:: fillMDFScanner(mdf, scanner) fillMDFTracer(mdf) - fillMDFMeasurement(mdf, data, isBackgroundFrame, temperatures = temperatures, drivefield = drivefield, applied = applied) + fillMDFMeasurement(mdf, data, isBackgroundFrame, temperatures = temperatures, drivefield = drivefield, applied = applied, frequencies = frequencies, isFourierTransformed = isFourierTransformed) fillMDFAcquisition(mdf, scanner, sequence) filename = getNewExperimentPath(study) @@ -27,7 +27,7 @@ function MPIFiles.saveasMDF(store::DatasetStore, scanner::MPIScanner, sequence:: return saveasMDF(filename, mdf) end -function MPIFiles.saveasMDF(store::DatasetStore, scanner::MPIScanner, sequence::Sequence, data::Array{Float32,4}, mdf::MDFv2InMemory; bgdata::Union{Array{Float32,4}, Nothing}=nothing, temperatures::Union{Array{Float32}, Nothing}=nothing, drivefield::Union{Array{ComplexF64}, Nothing}=nothing, applied::Union{Array{ComplexF64}, Nothing}=nothing) +function MPIFiles.saveasMDF(store::DatasetStore, scanner::MPIScanner, sequence::Sequence, data::Union{Array{Float32,4}, Array{ComplexF32,4}}, mdf::MDFv2InMemory; bgdata::Union{Array{Float32,4}, Array{ComplexF32,4}, Nothing}=nothing, temperatures::Union{Array{Float32}, Nothing}=nothing, drivefield::Union{Array{ComplexF64}, Nothing}=nothing, applied::Union{Array{ComplexF64}, Nothing}=nothing, frequencies::Union{Vector{Int}, Nothing}=nothing, isFourierTransformed::Bool=false) if !ismissing(studyName(mdf)) name = studyName(mdf) else @@ -47,7 +47,7 @@ function MPIFiles.saveasMDF(store::DatasetStore, scanner::MPIScanner, sequence:: fillMDFScanner(mdf, scanner) fillMDFTracer(mdf) - fillMDFMeasurement(mdf, sequence, data, bgdata, temperatures = temperatures, drivefield = drivefield, applied = applied) + fillMDFMeasurement(mdf, sequence, data, bgdata, temperatures = temperatures, drivefield = drivefield, applied = applied, frequencies = frequencies, isFourierTransformed = isFourierTransformed) fillMDFAcquisition(mdf, scanner, sequence) filename = getNewExperimentPath(study) @@ -57,8 +57,8 @@ end -function MPIFiles.saveasMDF(store::DatasetStore, scanner::MPIScanner, sequence::Sequence, data::Array{Float32,4}, - positions::Union{Positions, AbstractArray}, isBackgroundFrame::Vector{Bool}, mdf::MDFv2InMemory; storeAsSystemMatrix::Bool = false, deltaSampleSize::Union{Vector{typeof(1.0u"m")}, Nothing} = nothing, temperatures::Union{Array{Float32}, Nothing}=nothing, drivefield::Union{Array{ComplexF64}, Nothing}=nothing, applied::Union{Array{ComplexF64}, Nothing}=nothing) +function MPIFiles.saveasMDF(store::DatasetStore, scanner::MPIScanner, sequence::Sequence, data::Union{Array{Float32,4}, Array{ComplexF32,4}}, + positions::Union{Positions, AbstractArray}, isBackgroundFrame::Vector{Bool}, mdf::MDFv2InMemory; storeAsSystemMatrix::Bool = false, deltaSampleSize::Union{Vector{typeof(1.0u"m")}, Nothing} = nothing, temperatures::Union{Array{Float32}, Nothing}=nothing, drivefield::Union{Array{ComplexF64}, Nothing}=nothing, applied::Union{Array{ComplexF64}, Nothing}=nothing, frequencies::Union{Vector{Int}, Nothing}=nothing, isFourierTransformed::Bool=false) if storeAsSystemMatrix study = MPIFiles.getCalibStudy(store) @@ -85,7 +85,7 @@ function MPIFiles.saveasMDF(store::DatasetStore, scanner::MPIScanner, sequence:: @debug isBackgroundFrame - fillMDFMeasurement(mdf, data, isBackgroundFrame, temperatures = temperatures, drivefield = drivefield, applied = applied) + fillMDFMeasurement(mdf, data, isBackgroundFrame, temperatures = temperatures, drivefield = drivefield, applied = applied, frequencies = frequencies, isFourierTransformed = isFourierTransformed) fillMDFAcquisition(mdf, scanner, sequence) fillMDFCalibration(mdf, positions, deltaSampleSize = deltaSampleSize) @@ -94,8 +94,6 @@ function MPIFiles.saveasMDF(store::DatasetStore, scanner::MPIScanner, sequence:: return saveasMDF(filename, mdf) end - - function fillMDFCalibration(mdf::MDFv2InMemory, positions::GridPositions; deltaSampleSize::Union{Vector{typeof(1.0u"m")}, Nothing} = nothing) # /calibration/ subgroup @@ -245,14 +243,14 @@ function fillMDFTracer(mdf::MDFv2InMemory) return end -function fillMDFMeasurement(mdf::MDFv2InMemory, sequence::Sequence, data::Array{Float32,4}, - bgdata::Nothing; temperatures::Union{Array{Float32}, Nothing}=nothing, drivefield::Union{Array{ComplexF64}, Nothing}=nothing, applied::Union{Array{ComplexF64}, Nothing}=nothing, bgDriveField::Nothing=nothing, bgTransmit::Nothing=nothing) +function fillMDFMeasurement(mdf::MDFv2InMemory, sequence::Sequence, data::Union{Array{Float32,4}, Array{ComplexF32,4}}, + bgdata::Nothing; temperatures::Union{Array{Float32}, Nothing}=nothing, drivefield::Union{Array{ComplexF64}, Nothing}=nothing, applied::Union{Array{ComplexF64}, Nothing}=nothing, bgDriveField::Nothing=nothing, bgTransmit::Nothing=nothing, frequencies::Union{Vector{Int}, Nothing}=nothing, isFourierTransformed::Bool=false) numFrames = acqNumFrames(sequence) isBackgroundFrame = zeros(Bool, numFrames) - return fillMDFMeasurement(mdf, data, isBackgroundFrame, temperatures = temperatures, drivefield = drivefield, applied = applied) + return fillMDFMeasurement(mdf, data, isBackgroundFrame, temperatures = temperatures, drivefield = drivefield, applied = applied, frequencies = frequencies, isFourierTransformed = isFourierTransformed) end -function fillMDFMeasurement(mdf::MDFv2InMemory, sequence::Sequence, data::Array{Float32,4}, - bgdata::Union{Array{Float32}}; temperatures::Union{Array{Float32}, Nothing}=nothing, drivefield::Union{Array{ComplexF64}, Nothing}=nothing, applied::Union{Array{ComplexF64}, Nothing}=nothing, bgDriveField::Union{Array{ComplexF64}, Nothing}=nothing, bgTransmit::Union{Array{ComplexF64}, Nothing}=nothing) +function fillMDFMeasurement(mdf::MDFv2InMemory, sequence::Sequence, data::Union{Array{Float32,4}, Array{ComplexF32,4}}, + bgdata::Union{Array{Float32,4}, Array{ComplexF32,4}}; temperatures::Union{Array{Float32}, Nothing}=nothing, drivefield::Union{Array{ComplexF64}, Nothing}=nothing, applied::Union{Array{ComplexF64}, Nothing}=nothing, bgDriveField::Union{Array{ComplexF64}, Nothing}=nothing, bgTransmit::Union{Array{ComplexF64}, Nothing}=nothing, frequencies::Union{Vector{Int}, Nothing}=nothing, isFourierTransformed::Bool=false) # /measurement/ subgroup numFrames = acqNumFrames(sequence) numBGFrames = size(bgdata,4) @@ -267,11 +265,11 @@ function fillMDFMeasurement(mdf::MDFv2InMemory, sequence::Sequence, data::Array{ end isBackgroundFrame = cat(ones(Bool,numBGFrames), zeros(Bool,numFrames), dims=1) numFrames = numFrames + numBGFrames - return fillMDFMeasurement(mdf, data_, isBackgroundFrame, temperatures = temperatures, drivefield = drivefield_, applied = applied_) + return fillMDFMeasurement(mdf, data_, isBackgroundFrame, temperatures = temperatures, drivefield = drivefield_, applied = applied_, frequencies = frequencies, isFourierTransformed = isFourierTransformed) end -function fillMDFMeasurement(mdf::MDFv2InMemory, data::Array{Float32}, isBackgroundFrame::Vector{Bool}; temperatures::Union{Array{Float32}, Nothing}=nothing, drivefield::Union{Array{ComplexF64}, Nothing}=nothing, applied::Union{Array{ComplexF64}, Nothing}=nothing) +function fillMDFMeasurement(mdf::MDFv2InMemory, data::Array{Float32}, isBackgroundFrame::Vector{Bool}; temperatures::Union{Array{Float32}, Nothing}=nothing, drivefield::Union{Array{ComplexF64}, Nothing}=nothing, applied::Union{Array{ComplexF64}, Nothing}=nothing, frequencies::Union{Vector{Int}, Nothing}=nothing) # /measurement/ subgroup numFrames = size(data, 4) @@ -297,6 +295,47 @@ function fillMDFMeasurement(mdf::MDFv2InMemory, data::Array{Float32}, isBackgrou return end +# New version with frequency filtering support +function fillMDFMeasurement(mdf::MDFv2InMemory, data::Union{Array{Float32}, Array{ComplexF32}}, isBackgroundFrame::Vector{Bool}; + temperatures::Union{Array{Float32}, Nothing}=nothing, + drivefield::Union{Array{ComplexF64}, Nothing}=nothing, + applied::Union{Array{ComplexF64}, Nothing}=nothing, + frequencies::Union{Vector{Int}, Nothing}=nothing, + isFourierTransformed::Bool=false) + # /measurement/ subgroup + # Supports both time domain (Float32) and frequency domain (ComplexF32) data + numFrames = size(data, 4) + + measData(mdf, data) + measIsBackgroundCorrected(mdf, false) + measIsBackgroundFrame(mdf, isBackgroundFrame) + measIsFastFrameAxis(mdf, false) + measIsFourierTransformed(mdf, isFourierTransformed) + measIsFramePermutation(mdf, false) + measIsSparsityTransformed(mdf, false) + measIsSpectralLeakageCorrected(mdf, false) + measIsTransferFunctionCorrected(mdf, false) + + # Handle frequency selection + if !isnothing(frequencies) && isFourierTransformed + measIsFrequencySelection(mdf, true) + MPIFiles.measFrequencySelection(mdf, frequencies) + else + measIsFrequencySelection(mdf, false) + end + + if !isnothing(temperatures) + MPIFiles.measTemperatures(mdf, temperatures) + end + if !isnothing(drivefield) + MPIFiles.measObservedDriveField(mdf, drivefield) + end + if !isnothing(applied) + MPIFiles.measAppliedDriveField(mdf, applied) + end + return +end + function fillMDFAcquisition(mdf::MDFv2InMemory, scanner::MPIScanner, sequence::Sequence) # Needs to be filled after(!) measurement group diff --git a/src/Protocols/Storage/MeasurementState.jl b/src/Protocols/Storage/MeasurementState.jl index 8c569c2b..8b84e693 100644 --- a/src/Protocols/Storage/MeasurementState.jl +++ b/src/Protocols/Storage/MeasurementState.jl @@ -1,3 +1,6 @@ +# Export buffer abstract types for testing and extension +export StorageBuffer, IntermediateBuffer, SinkBuffer, SequenceBuffer, DeviceBuffer + abstract type MeasurementState end abstract type StorageBuffer end diff --git a/test/Protocols/BufferTests.jl b/test/Protocols/BufferTests.jl new file mode 100644 index 00000000..446a04b7 --- /dev/null +++ b/test/Protocols/BufferTests.jl @@ -0,0 +1,121 @@ +using Test +using FFTW +using Statistics +using MPIMeasurements + +mutable struct MockStorageBuffer <: StorageBuffer + data::Vector{Any} +end +MockStorageBuffer() = MockStorageBuffer(Any[]) +Base.push!(buffer::MockStorageBuffer, data) = (push!(buffer.data, data); return (start=1, stop=size(data, 4))) +MPIMeasurements.sinks!(buffer::MockStorageBuffer, sinks::Vector{SinkBuffer}) = sinks + +@testset "PeriodGroupingBuffer Tests" begin + @testset "Basic period grouping with numGrouping=2" begin + mockTarget = MockStorageBuffer() + buffer = PeriodGroupingBuffer(mockTarget, 2) + testData = randn(Float32, 8, 3, 4, 2) + push!(buffer, testData) + @test length(mockTarget.data) == 1 + @test size(mockTarget.data[1]) == (16, 3, 2, 2) + end + + @testset "Period grouping with numGrouping=1 (pass-through)" begin + mockTarget = MockStorageBuffer() + buffer = PeriodGroupingBuffer(mockTarget, 1) + testData = randn(Float32, 8, 3, 4, 2) + push!(buffer, testData) + @test size(mockTarget.data[1]) == size(testData) + @test mockTarget.data[1] ≈ testData + end + + @testset "Period grouping with non-divisible periods should error" begin + mockTarget = MockStorageBuffer() + buffer = PeriodGroupingBuffer(mockTarget, 3) + testData = randn(Float32, 8, 3, 5, 2) + @test_throws ErrorException push!(buffer, testData) + end + + @testset "Period grouping matches MPIFiles getMeasurements logic" begin + mockTarget = MockStorageBuffer() + numPeriodGrouping, numSamples, numChannels, numPeriods, numFrames = 3, 12, 3, 9, 4 + buffer = PeriodGroupingBuffer(mockTarget, numPeriodGrouping) + testData = randn(Float32, numSamples, numChannels, numPeriods, numFrames) + push!(buffer, testData) + result = mockTarget.data[1] + + @test size(result) == (numSamples * numPeriodGrouping, numChannels, div(numPeriods, numPeriodGrouping), numFrames) + + tmp = permutedims(testData, (1, 3, 2, 4)) + tmp2 = reshape(tmp, numSamples * numPeriodGrouping, div(numPeriods, numPeriodGrouping), numChannels, numFrames) + expected = permutedims(tmp2, (1, 3, 2, 4)) + @test result ≈ expected + end +end + +@testset "RFFTBuffer Tests" begin + @testset "Basic RFFT without frequency selection" begin + mockTarget = MockStorageBuffer() + buffer = RFFTBuffer(mockTarget, nothing) + testData = randn(Float32, 16, 3, 2, 4) + push!(buffer, testData) + @test length(mockTarget.data) == 1 + @test size(mockTarget.data[1]) == (9, 3, 2, 4) + @test eltype(mockTarget.data[1]) <: Complex + end + + @testset "RFFT with frequency selection" begin + mockTarget = MockStorageBuffer() + buffer = RFFTBuffer(mockTarget, [1, 3, 5, 7]) + testData = randn(Float32, 16, 3, 2, 4) + push!(buffer, testData) + @test size(mockTarget.data[1]) == (4, 3, 2, 4) + end + + @testset "RFFT matches FFTW.rfft behavior" begin + mockTarget = MockStorageBuffer() + buffer = RFFTBuffer(mockTarget, nothing) + testData = randn(Float32, 32, 2, 3, 5) + push!(buffer, testData) + @test mockTarget.data[1] ≈ rfft(testData, 1) + end + + @testset "RFFT with frequency selection matches indexing" begin + mockTarget = MockStorageBuffer() + frequencyMask = [2, 4, 6, 8, 10] + buffer = RFFTBuffer(mockTarget, frequencyMask) + testData = randn(Float32, 32, 2, 3, 5) + push!(buffer, testData) + @test mockTarget.data[1] ≈ rfft(testData, 1)[frequencyMask, :, :, :] + end +end + +@testset "Combined PeriodGrouping + RFFT Pipeline" begin + @testset "Period grouping followed by RFFT" begin + mockTarget = MockStorageBuffer() + rfftBuffer = RFFTBuffer(mockTarget, nothing) + periodBuffer = PeriodGroupingBuffer(rfftBuffer, 2) + testData = randn(Float32, 16, 3, 4, 2) + push!(periodBuffer, testData) + @test length(mockTarget.data) == 1 + @test size(mockTarget.data[1]) == (17, 3, 2, 2) + @test eltype(mockTarget.data[1]) <: Complex + end + + @testset "Period grouping + RFFT + frequency selection" begin + mockTarget = MockStorageBuffer() + frequencyMask = [1, 5, 9, 13] + rfftBuffer = RFFTBuffer(mockTarget, frequencyMask) + periodBuffer = PeriodGroupingBuffer(rfftBuffer, 3) + testData = randn(Float32, 24, 3, 9, 4) + push!(periodBuffer, testData) + result = mockTarget.data[1] + @test size(result) == (4, 3, 3, 4) + + tmp = permutedims(testData, (1, 3, 2, 4)) + tmp2 = reshape(tmp, 72, 3, 3, 4) + grouped = permutedims(tmp2, (1, 3, 2, 4)) + expected = rfft(grouped, 1)[frequencyMask, :, :, :] + @test result ≈ expected + end +end diff --git a/test/run_tests.jl b/test/run_tests.jl new file mode 100644 index 00000000..cfd91409 --- /dev/null +++ b/test/run_tests.jl @@ -0,0 +1,21 @@ +using Pkg + +println("="^70) +println("Running Frequency Filtering Tests") +println("="^70) + +Pkg.activate(".") +using MPIMeasurements + +required_packages = ["Test", "FFTW", "Statistics"] +for pkg in required_packages + if !haskey(Pkg.project().dependencies, pkg) + Pkg.add(pkg) + end +end + +include("Protocols/BufferTests.jl") + +println("="^70) +println("All tests passed!") +println("="^70)