Skip to content

Commit 3442e62

Browse files
committed
#146 Refactor anomaly detection logic in series with NaN to improve clarity.
Simplified and clarified logical operations by introducing a mask for readability and maintainability. Added test cases to ensure proper handling of NaN values during anomaly detection, covering both interior and edge cases.
1 parent becb56e commit 3442e62

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

tests/test_TradeRoutines.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,22 @@ def test_HampelFilterPositive(self):
604604
# Serves as a stability test with noise.
605605
[False, False, False, False, False, False],
606606
),
607+
(
608+
{
609+
"series": pd.Series([1, 2, np.nan, 2, 1]),
610+
"window": 3, "sigma": 3, "scaleFactor": 1.4826,
611+
},
612+
# The presence of NaN in the middle should not cause anomalies. The method skips NaNs properly.
613+
[False, False, False, False, False],
614+
),
615+
(
616+
{
617+
"series": pd.Series([np.nan, 1, 1, 11111111, 1, 1, np.nan]),
618+
"window": 1, "sigma": 1.1, "scaleFactor": 1.4826,
619+
},
620+
# The presence of NaN at the edges should not cause anomalies. The method skips NaNs properly.
621+
[False, False, False, True, False, False, False],
622+
),
607623
]
608624

609625
for test in testData:

tksbrokerapi/TradeRoutines.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -714,7 +714,8 @@ def HampelFilter(series: Union[list, pd.Series], window: int = 5, sigma: float =
714714
# Step 3: Detect anomalies for valid central values:
715715
new = pd.Series(False, index=series.index)
716716
threshold = sigma * scaleFactor * rollingMAD
717-
new[rollingMAD.notna()] = delta[rollingMAD.notna()] > threshold[rollingMAD.notna()]
717+
mask = rollingMAD.notna()
718+
new.loc[mask] = (delta[mask] > threshold[mask])
718719

719720
# Step 4: Handle boundary values manually (first and last `window` points):
720721
for i in list(range(window)) + list(range(len(series) - window, len(series))):

0 commit comments

Comments
 (0)