Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions filterpy2/kalman/kalman_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,9 @@
from copy import deepcopy
from math import log, exp, sqrt
import sys
from typing import Callable
import numpy as np
from numpy import dot, zeros, eye, isscalar, shape
from numpy import dot, zeros, eye, isscalar, shape, subtract, typing as npt
import numpy.linalg as linalg
from filterpy2.stats import logpdf
from filterpy2.common import pretty_str, reshape_z
Expand Down Expand Up @@ -384,7 +385,7 @@ class KalmanFilter(object):

"""

def __init__(self, dim_x, dim_z, dim_u=0):
def __init__(self, dim_x: int, dim_z: int, dim_u: int=0, residual_z_fn: Callable[[npt.NDArray], npt.NDArray]=subtract):
if dim_x < 1:
raise ValueError("dim_x must be 1 or greater")
if dim_z < 1:
Expand All @@ -406,6 +407,7 @@ def __init__(self, dim_x, dim_z, dim_u=0):
self._alpha_sq = 1.0 # fading memory control
self.M = np.zeros((dim_x, dim_z)) # process-measurement cross correlation
self.z = np.array([[None] * self.dim_z]).T
self.residual_z_fn = residual_z_fn

# gain and residual are computed during the innovation step. We
# save them so that in case you want to inspect them for various
Expand Down Expand Up @@ -527,7 +529,7 @@ def update(self, z, R=None, H=None):

# y = z - Hx
# error (residual) between measurement and prediction
self.y = z - dot(H, self.x)
self.y = self.residual_z_fn(z, dot(H, self.x))

# common subexpression for speed
PHT = dot(self.P, H.T)
Expand Down