Skip to content

Commit 9c86068

Browse files
committed
fix(rfi): Optimize RFI processing slice logic to ensure coverage of all samples required for dedispersion
- Extend RFI processing slice range to avoid boundary issues - Optimize comments in the downsampling function for better readability - Add dtrend configuration option to allow users to choose whether to detrend the subband matrix - Optimize resource cleanup logic to ensure proper release of origin_data
1 parent 2899a88 commit 9c86068

File tree

2 files changed

+43
-31
lines changed

2 files changed

+43
-31
lines changed

include/cpucal.hpp

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -284,25 +284,27 @@ Spectrum<T> dedispered_fil_with_dm(
284284
throw std::invalid_argument("Time window too short for this DM and band (no valid samples after dedispersion).");
285285
}
286286

287-
// ---- RFI:只在“局部输入切片”上运行(长度 = t_len_eff + delay_max_idx)----
288-
// 这样解色散访问到的所有原始样本都被一致地标注/掩膜。
287+
// ---- RFI:只在“局部输入切片”上运行 ----
288+
// 扩展窗口以覆盖解色散所需的所有样本
289289
T* origin_data = static_cast<T*>(fil->data);
290-
const size_t slice_len_for_rfi = t_len_eff + (size_t)delay_max_idx;
291-
T* slice_ptr = origin_data + t_start_idx * fil->nchans;
290+
size_t rfi_t_start_idx = (t_start_idx > (size_t)delay_max_idx) ? (t_start_idx - delay_max_idx) : 0;
291+
size_t rfi_offset_from_t_start = t_start_idx - rfi_t_start_idx;
292+
T* slice_ptr_for_rfi = origin_data + rfi_t_start_idx * fil->nchans;
293+
size_t slice_len_for_rfi = std::min(t_len_eff + rfi_offset_from_t_start + 2 * delay_max_idx, (size_t)fil->ndata - rfi_t_start_idx);
292294

293295
RfiMarkerCPU<T> rfi_marker(maskfile);
294296
if (rficfg.use_iqrm) {
295297
auto win_masks = iqrm_cuda::rfi_iqrm_gpu_host<T>(
296-
slice_ptr, // 指向局部起点
298+
slice_ptr_for_rfi, // 指向扩展切片的起点
297299
chan_start, chan_end_excl,
298300
slice_len_for_rfi, // 覆盖局部 + 最大延时
299301
fil->nchans,
300302
fil->tsamp, rficfg);
301-
rfi_marker.mask(slice_ptr, fil->nchans, slice_len_for_rfi, win_masks);
303+
rfi_marker.mask(slice_ptr_for_rfi, fil->nchans, slice_len_for_rfi, win_masks);
302304
}
303305
if (rficfg.use_mask) {
304-
// 静态掩膜同样只作用在切片上(不必动全局)
305-
rfi_marker.mark_rfi(slice_ptr, fil->nchans, slice_len_for_rfi);
306+
// 静态掩膜同样只作用在扩展切片上
307+
rfi_marker.mark_rfi(slice_ptr_for_rfi, fil->nchans, slice_len_for_rfi);
306308
}
307309

308310
// ---- 输出光谱 ----
@@ -326,10 +328,9 @@ Spectrum<T> dedispered_fil_with_dm(
326328
#pragma omp simd
327329
for (ptrdiff_t ch = (ptrdiff_t)chan_start; ch < (ptrdiff_t)chan_end_excl; ++ch) {
328330
const int d = dm_delays[ch - chan_start];
329-
const size_t src_idx = (size_t)ti + (size_t)d; // 相对于 slice_ptr 的偏移
330-
// 由 t_len_eff 的定义,src_idx < slice_len_for_rfi 恒成立,无需额外边界判断
331+
const size_t src_idx = t_start_idx + (size_t)ti + (size_t)d; // 相对于 origin_data 的绝对偏移
331332
result.data[(size_t)ti * (size_t)result.nchans + (size_t)(ch - chan_start)]
332-
= slice_ptr[src_idx * fil->nchans + (size_t)ch];
333+
= origin_data[src_idx * fil->nchans + (size_t)ch];
333334
}
334335
}
335336

@@ -420,20 +421,23 @@ Spectrum<T> dedisperse_spec_with_dm(
420421
}
421422

422423
// ---- RFI(局部+最大延时)----
423-
T* slice_ptr = spec + t_start_idx * header.nchans;
424-
size_t slice_len_for_rfi = t_len_eff + (size_t)delay_max_idx;
424+
size_t rfi_t_start_idx = (t_start_idx > (size_t)delay_max_idx) ? (t_start_idx - delay_max_idx) : 0;
425+
size_t rfi_offset_from_t_start = t_start_idx - rfi_t_start_idx;
426+
T* slice_ptr_for_rfi = spec + rfi_t_start_idx * header.nchans;
427+
size_t slice_len_for_rfi = std::min(t_len_eff + rfi_offset_from_t_start + 2 * delay_max_idx, header.ndata - rfi_t_start_idx);
428+
425429
RfiMarkerCPU<T> rfi_marker(maskfile);
426430
if (rficfg.use_iqrm) {
427431
auto win_masks = iqrm_cuda::rfi_iqrm_gpu_host<T>(
428-
slice_ptr,
432+
slice_ptr_for_rfi,
429433
chan_start, chan_end_excl,
430434
slice_len_for_rfi,
431435
header.nchans,
432436
header.tsamp, rficfg);
433-
rfi_marker.mask(slice_ptr, header.nchans, slice_len_for_rfi, win_masks);
437+
rfi_marker.mask(slice_ptr_for_rfi, header.nchans, slice_len_for_rfi, win_masks);
434438
}
435439
if (rficfg.use_mask) {
436-
rfi_marker.mark_rfi(slice_ptr, header.nchans, slice_len_for_rfi);
440+
rfi_marker.mark_rfi(slice_ptr_for_rfi, header.nchans, slice_len_for_rfi);
437441
}
438442

439443
// ---- 构造输出 ----
@@ -457,9 +461,9 @@ Spectrum<T> dedisperse_spec_with_dm(
457461
#pragma omp simd
458462
for (ptrdiff_t ch = (ptrdiff_t)chan_start; ch < (ptrdiff_t)chan_end_excl; ++ch) {
459463
int d = dm_delays[ch - chan_start];
460-
size_t src_idx = (size_t)ti + (size_t)d; // 相对于 slice_ptr
464+
size_t src_idx = t_start_idx + (size_t)ti + (size_t)d; // 相对于 spec
461465
result.data[(size_t)ti * result.nchans + (size_t)(ch - chan_start)]
462-
= slice_ptr[src_idx * header.nchans + (size_t)ch];
466+
= spec[src_idx * header.nchans + (size_t)ch];
463467
}
464468
}
465469

python/astroflow/plotter.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -788,9 +788,15 @@ def downsample_freq_weighted_vec(spec_data, freq_axis, n_out):
788788
"""
789789
完全向量化的频率方向降采样。
790790
保证能量守恒 & 无跑频。
791-
spec_data: [ntime, nfreq_in]
792-
freq_axis: 频率中心(升序)
793-
n_out: 目标子带数
791+
792+
参数
793+
------
794+
spec_data : ndarray
795+
[ntime, nfreq_in] 的动态谱
796+
freq_axis : ndarray
797+
频率中心(升序)。
798+
n_out : int
799+
目标子带数。
794800
"""
795801
ntime, nfreq_in = spec_data.shape
796802

@@ -848,8 +854,8 @@ def _setup_subband_spectrum_plots(fig, gs, spec_data, spec_time_axis, spec_freq_
848854
)
849855
# c_end_time = time.time()s
850856
# print(f"Subband processing time: {c_end_time - curr_time:.3f} seconds")
851-
852-
subband_matrix = _detrend(subband_matrix, axis=0, type='linear')
857+
if specconfig.get("dtrend", False):
858+
subband_matrix = _detrend(subband_matrix, axis=0, type='linear')
853859
# subband_matrix = _detrend_frequency(subband_matrix.T, poly_order=6).T
854860

855861
if specconfig.get("norm", True):
@@ -973,6 +979,7 @@ def plot_candidate(
973979
ValueError: If file_path has unsupported extension
974980
Exception: If data loading or processing fails
975981
"""
982+
origin_data = None
976983
try:
977984
# Parse candidate information
978985
dm, toa, freq_start, freq_end, dmt_idx, ref_toa, bbox = _parse_candidate_info(candinfo)
@@ -1000,7 +1007,6 @@ def plot_candidate(
10001007
ax_time, ax_main, ax_dm = _setup_dm_plots(fig, gs, dm_data, time_axis, dm_axis, dm_vmin, dm_vmax, dm, toa)
10011008

10021009
# Load and process spectrum data
1003-
origin_data = None
10041010
try:
10051011
origin_data = _load_data_file(file_path)
10061012
header = origin_data.header()
@@ -1051,12 +1057,13 @@ def plot_candidate(
10511057
if snr < snrhold:
10521058
plt.close('all')
10531059
if origin_data is not None:
1054-
if hasattr(origin_data, "close"):
1060+
close_method = getattr(origin_data, "close", None)
1061+
if callable(close_method):
10551062
try:
1056-
origin_data.close()
1063+
close_method()
10571064
except Exception:
10581065
pass
1059-
del origin_data
1066+
origin_data = None
10601067

10611068
del spectrum, spec_data, initial_spectrum, initial_spec_data
10621069
gc.collect()
@@ -1137,11 +1144,12 @@ def plot_candidate(
11371144
finally:
11381145
# Cleanup
11391146
plt.close('all')
1140-
if 'origin_data' in locals() and origin_data is not None:
1141-
if hasattr(origin_data, "close"):
1147+
if origin_data is not None:
1148+
close_method = getattr(origin_data, "close", None)
1149+
if callable(close_method):
11421150
try:
1143-
origin_data.close()
1151+
close_method()
11441152
except Exception:
11451153
pass
1146-
del origin_data
1154+
origin_data = None
11471155
gc.collect()

0 commit comments

Comments
 (0)