Skip to content

Commit dbe9a89

Browse files
committed
fix failing test
1 parent f35c1fe commit dbe9a89

File tree

1 file changed

+40
-5
lines changed

1 file changed

+40
-5
lines changed

causalpy/reporting.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -219,13 +219,48 @@ def _extract_window(result, window, treated_unit=None):
219219
windowed_impact = post_impact.sel(obs_ind=window_coords)
220220
else:
221221
# Integer index
222-
window_coords = result.datapost.index[
223-
(result.datapost.index >= start) & (result.datapost.index <= end)
224-
]
222+
# Ensure start and end are comparable with the index
223+
# Convert to native Python int to avoid type issues
224+
start_val = int(start)
225+
end_val = int(end)
226+
# Use result.datapost.index for filtering, then match with post_impact coordinates
227+
mask = (result.datapost.index >= start_val) & (
228+
result.datapost.index <= end_val
229+
)
230+
window_coords = result.datapost.index[mask]
225231
windowed_impact = post_impact.sel(obs_ind=window_coords)
226232
elif isinstance(window, slice):
227-
# Integer slice
228-
window_coords = result.datapost.index[window]
233+
# Slice window - handle differently for datetime vs integer indices
234+
if isinstance(result.datapost.index, pd.DatetimeIndex):
235+
# For datetime, slice works directly
236+
window_coords = result.datapost.index[window]
237+
else:
238+
# For integer indices, convert slice to value-based filtering
239+
# slice(start, stop, step) -> get all values in [start, stop)
240+
start_val = (
241+
int(window.start)
242+
if window.start is not None
243+
else result.datapost.index.min()
244+
)
245+
stop_val = (
246+
int(window.stop)
247+
if window.stop is not None
248+
else result.datapost.index.max() + 1
249+
)
250+
step = int(window.step) if window.step is not None else 1
251+
# Create boolean mask for values in range
252+
if step == 1:
253+
mask = (result.datapost.index >= start_val) & (
254+
result.datapost.index < stop_val
255+
)
256+
window_coords = result.datapost.index[mask]
257+
else:
258+
# For non-unit step, filter then apply step
259+
mask = (result.datapost.index >= start_val) & (
260+
result.datapost.index < stop_val
261+
)
262+
filtered = result.datapost.index[mask]
263+
window_coords = filtered[::step]
229264
windowed_impact = post_impact.sel(obs_ind=window_coords)
230265
else:
231266
raise ValueError(

0 commit comments

Comments
 (0)