Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions doc/architectural_decisions/005-variance-addition.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Why do we add 0.5 to variance?

## Status

Current

## Context

If we pass counts data to `scipp` as variance this is correct- but if one of the counts is 0 then its variance is 0, leading to a division by 0 error when calculating its weight for fitting `weight = 1 / doc["data"][self.yerr]`.

## Decision

Our solution was to add 0.5 (`VARIANCE_ADDITION`) to each count to calculate variance. The actual data should be unchanged, the +0.5 is only for uncertainty calculation.

## Justification

The above approach is both "smooth" and converges towards sqrt(N) in the limit with high counts, and should also mean that we never get an uncertainty of zero in the fitting side.
60 changes: 31 additions & 29 deletions manual_system_tests/dae_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@
import bluesky.plans as bp
import matplotlib
import matplotlib.pyplot as plt
from bluesky.callbacks import LiveTable
from bluesky.callbacks import LiveTable, LiveFitPlot
from bluesky.preprocessors import subs_decorator
from bluesky.utils import Msg
from ophyd_async.plan_stubs import ensure_connected

from ibex_bluesky_core.callbacks.file_logger import HumanReadableFileCallback
from ibex_bluesky_core.callbacks.fitting import LiveFit
from ibex_bluesky_core.callbacks.fitting.fitting_utils import Linear, Gaussian
from ibex_bluesky_core.callbacks.plotting import LivePlot
from ibex_bluesky_core.devices import get_pv_prefix
from ibex_bluesky_core.devices.block import block_rw_rbv
from ibex_bluesky_core.devices.block import block_rw_rbv, BlockWriteConfig
from ibex_bluesky_core.devices.simpledae import SimpleDae
from ibex_bluesky_core.devices.simpledae.controllers import (
RunPerPointController,
Expand Down Expand Up @@ -47,10 +49,12 @@ def dae_scan_plan() -> Generator[Msg, None, None]:
- The DAE waited for at least 500 good frames at each point
"""
prefix = get_pv_prefix()
block = block_rw_rbv(float, "mot")
block = block_rw_rbv(float, "mot", write_config=BlockWriteConfig(settle_time_s=0.5))
blocka = block_rw_rbv(float, "alice")
blockb = block_rw_rbv(float, "bob", write_config=BlockWriteConfig(settle_time_s=0.5))

controller = RunPerPointController(save_run=True)
waiter = GoodFramesWaiter(500)
waiter = GoodFramesWaiter(1)
reducer = GoodFramesNormalizer(
prefix=prefix,
detector_spectra=[i for i in range(1, 100)],
Expand All @@ -67,38 +71,36 @@ def dae_scan_plan() -> Generator[Msg, None, None]:
controller.run_number.set_name("run number")
reducer.intensity.set_name("normalized counts")

yield from ensure_connected(block, dae, force_reconnect=True)
yield from ensure_connected(block, blocka, blockb, force_reconnect=True)
print(reducer.intensity.name)

lf = LiveFit(Gaussian.fit(), y=blockb.name, x=block.name, yerr=blocka.name)
fig, ax = plt.subplots()

@subs_decorator(
[
HumanReadableFileCallback(
Path("C:\\") / "instrument" / "var" / "logs" / "bluesky" / "output_files",
[
block.name,
controller.run_number.name,
reducer.intensity.name,
reducer.det_counts.name,
dae.good_frames.name,
],
),
LivePlot(y=reducer.intensity.name, x=block.name, marker="x", linestyle="none"),
LiveTable(
[
block.name,
controller.run_number.name,
reducer.intensity.name,
reducer.intensity_stddev.name,
reducer.det_counts.name,
reducer.det_counts_stddev.name,
dae.good_frames.name,
]
# HumanReadableFileCallback(
# Path("C:\\") / "instrument" / "var" / "logs" / "bluesky" / "output_files",
# [
# block.name,
# controller.run_number.name,
# reducer.intensity.name,
# reducer.det_counts.name,
# dae.good_frames.name,
# ],
# ),
LiveFitPlot(lf, ax=ax, color="r", num_points=1000),
LivePlot(
y=blockb.name, x=block.name, yerr=blocka.name, marker="x", linestyle="none", ax=ax
),
LiveTable([block.name, blocka.name, blockb.name]),
]
)
def _inner() -> Generator[Msg, None, None]:
num_points = 3
yield from bps.mv(dae.number_of_periods, num_points)
yield from bp.scan([dae], block, 0, 10, num=num_points)
num_points = 10
# yield from bps.mv(dae.number_of_periods, num_points)
yield from bps.read(blocka)
yield from bp.scan([blockb, blocka], block, 0, 10, num=num_points)

yield from _inner()

Expand Down
68 changes: 50 additions & 18 deletions src/ibex_bluesky_core/callbacks/fitting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""For IBEX Bluesky scan fitting."""

import logging
import warnings
from typing import Callable

import lmfit
import numpy as np
import numpy.typing as npt
from bluesky.callbacks import LiveFit as _DefaultLiveFit
from bluesky.callbacks.core import make_class_safe
from event_model.documents.event import Event


class FitMethod:
Expand Down Expand Up @@ -50,31 +52,48 @@ class LiveFit(_DefaultLiveFit):
"""Live fit, customized for IBEX."""

def __init__(
self,
method: FitMethod,
y: str,
x: str,
*,
update_every: int = 1,
self, method: FitMethod, y: str, x: str, *, update_every: int = 1, yerr: str | None = None
) -> None:
"""Call Bluesky LiveFit with assumption that there is only one independant variable.

Args:
method (FitMethod): The FitMethod (Model & Guess) to use when fitting.
y (str): The name of the dependant variable.
x (str): The name of the independant variable.
update_every (int): How often to update the fit. (seconds)
update_every (int, optional): How often to update the fit. (seconds)
yerr (str or None, optional): Name of field in the Event document
that provides standard deviation for each Y value

"""
self.method = method
self.yerr = yerr
self.weight_data = []

super().__init__(
model=method.model,
y=y,
independent_vars={"x": x},
update_every=update_every,
model=method.model, y=y, independent_vars={"x": x}, update_every=update_every
)

def event(self, doc: Event) -> None:
"""When an event is received, update caches."""
weight = None
if self.yerr is not None:
try:
weight = 1 / doc["data"][self.yerr]
except ZeroDivisionError:
warnings.warn(
"standard deviation for y is 0, therefore applying weight of 0 on fit",
stacklevel=1,
)
weight = 0.0

self.update_weight(weight)
super().event(doc)

def update_weight(self, weight: float | None = 0.0) -> None:
"""Update uncertainties cache."""
if self.yerr is not None:
self.weight_data.append(weight)

def update_fit(self) -> None:
"""Use the provided guess function with the most recent x and y values after every update.

Expand All @@ -85,10 +104,23 @@ def update_fit(self) -> None:
None

"""
self.init_guess = self.method.guess(
np.array(next(iter(self.independent_vars_data.values()))),
np.array(self.ydata),
# Calls the guess function on the set of data already collected in the run
)

super().update_fit()
n = len(self.model.param_names)
if len(self.ydata) < n:
warnings.warn(
f"LiveFitPlot cannot update fit until there are at least {n} data points",
stacklevel=1,
)
else:
self.init_guess = self.method.guess(
np.array(next(iter(self.independent_vars_data.values()))),
np.array(self.ydata),
# Calls the guess function on the set of data already collected in the run
)

kwargs = {}
kwargs.update(self.independent_vars_data)
kwargs.update(self.init_guess)
self.result = self.model.fit(
self.ydata, weights=None if self.yerr is None else self.weight_data, **kwargs
)
self.__stale = False
1 change: 0 additions & 1 deletion src/ibex_bluesky_core/callbacks/fitting/fitting_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ def guess(
) -> Callable[[npt.NDArray[np.float64], npt.NDArray[np.float64]], dict[str, lmfit.Parameter]]:
"""Linear Guessing."""
return Polynomial.guess(1)
# Uses polynomial guessing with a degree of 1


class Polynomial(Fit):
Expand Down
31 changes: 25 additions & 6 deletions src/ibex_bluesky_core/callbacks/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import matplotlib
import matplotlib.pyplot as plt
from bluesky.callbacks import LivePlot as _DefaultLivePlot
from bluesky.callbacks.core import make_class_safe
from bluesky.callbacks.core import get_obj_fields, make_class_safe
from event_model.documents import Event, RunStart

logger = logging.getLogger(__name__)
Expand All @@ -15,18 +15,37 @@
class LivePlot(_DefaultLivePlot):
"""Live plot, customized for IBEX."""

def __init__(self, y, x=None, yerr=None, *args, **kwargs):
super().__init__(y=y, x=x, *args, **kwargs)
if yerr is not None:
self.yerr, *others = get_obj_fields([yerr])
else:
self.yerr = None
self.yerr_data = []

def _show_plot(self) -> None:
# Play nicely with the "normal" backends too - only force show if we're
# actually using our custom backend.
if "genie_python" in matplotlib.get_backend():
plt.show()

def event(self, doc: Event):
"""Process an event document (delegate to superclass, then show the plot)."""
new_yerr = None if self.yerr is None else doc["data"][self.yerr]
self.update_yerr(new_yerr)
super().event(doc)
self._show_plot()

def update_plot(self):
if self.yerr is not None:
self.ax.errorbar(x=self.x_data, y=self.y_data, yerr=self.yerr_data, fmt="none")
super().update_plot()

def update_yerr(self, y_err):
# super.update_caches(x, y)
self.yerr_data.append(y_err)

def start(self, doc: RunStart) -> None:
"""Process an start document (delegate to superclass, then show the plot)."""
super().start(doc)
self._show_plot()

def event(self, doc: Event) -> None:
"""Process an event document (delegate to superclass, then show the plot)."""
super().event(doc)
self._show_plot()
11 changes: 10 additions & 1 deletion src/ibex_bluesky_core/devices/dae/dae_spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ophyd_async.core import SignalR, StandardReadable
from ophyd_async.epics.signal import epics_signal_r

VARIANCE_ADDITION = 0.5

class DaeSpectra(StandardReadable):
"""Subdevice for a single DAE spectra."""
Expand Down Expand Up @@ -107,7 +108,15 @@ async def read_spectrum_dataarray(self) -> sc.DataArray:
if unit is None:
raise ValueError("Could not determine engineering units of tof edges.")

# TODO add reference to ADR


return sc.DataArray(
data=sc.Variable(dims=["tof"], values=counts, variances=counts, unit=sc.units.counts),
data=sc.Variable(
dims=["tof"],
values=counts,
variances=counts + VARIANCE_ADDITION,
unit=sc.units.counts,
),
coords={"tof": sc.array(dims=["tof"], values=tof_edges, unit=sc.Unit(unit))},
)
70 changes: 70 additions & 0 deletions tests/callbacks/fitting/test_fitting_callback.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import warnings
from unittest import mock
from unittest.mock import MagicMock

import lmfit
import numpy as np
import numpy.typing as npt

from ibex_bluesky_core.callbacks.fitting import FitMethod, LiveFit
from ibex_bluesky_core.callbacks.fitting.fitting_utils import Linear


def test_guess_called():
Expand Down Expand Up @@ -92,3 +95,70 @@ def model(x: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
)

model_mock.assert_called()


def test_model_called_with_weights_if_yerr_is_given():
def guess(x: npt.NDArray[np.float64], y: npt.NDArray[np.float64]) -> dict[str, lmfit.Parameter]:
return {}

model = lmfit.Model(lambda x: x)
model.fit = MagicMock()
method = FitMethod(model=model, guess=guess)
lf = LiveFit(method, y="y", x="x", yerr="yerr")

x = 1
y = 2
yerr = 3

lf.event(
{
"data": { # type: ignore
"y": y,
"x": x,
"yerr": yerr,
}
}
)

model.fit.assert_called_with([y], weights=[1 / yerr], x=[1])


def test_warning_given_if_yerr_is_0():
def guess(x: npt.NDArray[np.float64], y: npt.NDArray[np.float64]) -> dict[str, lmfit.Parameter]:
return {}

model = lmfit.Model(lambda x: x)
model.fit = MagicMock()
method = FitMethod(model=model, guess=guess)
lf = LiveFit(method, y="y", x="x", yerr="yerr")

x = 1
y = 2
yerr = 0

with warnings.catch_warnings(record=True) as w:
lf.event(
{
"data": { # type: ignore
"y": y,
"x": x,
"yerr": yerr,
}
}
)

model.fit.assert_called_with([y], weights=[0.0], x=[1])

assert len(w) == 1
assert "standard deviation for y is 0, therefore applying weight of 0 on fit" in str(
w[-1].message
)


def test_warning_if_no_y_data():
with warnings.catch_warnings(record=True) as w:
lf = LiveFit(Linear.fit(), y="y", x="x", yerr="yerr")
lf.update_fit()

assert len(w) == 1
assert "LiveFitPlot cannot update fit until there are at least" in str(w[-1].message)
Loading