Skip to content

Commit 2a9ae09

Browse files
committed
feat(rfimask): Add advanced RFI masking with frequency reversal support
- Introduced new `astrorfimask` command for batch RFI masking - Implemented Savitzky-Golay and Spectral Kurtosis algorithms - Added frequency reversal for descending order data - Enhanced multiprocessing support for large datasets - Improved error handling and progress tracking - Updated documentation with detailed usage examples - Added model download functionality for default weights - Removed deprecated preprocess configuration
1 parent 8d9c009 commit 2a9ae09

File tree

12 files changed

+719
-120
lines changed

12 files changed

+719
-120
lines changed

.github/workflows/ubuntu-22.04-build.yml

Lines changed: 0 additions & 76 deletions
This file was deleted.

README.md

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ Below are the commonly used YAML configuration options:
149149
| `confidence` | Detection confidence threshold | `0.4` | `0.372` |
150150
| `snrhold` | SNR threshold | `5` | `5` |
151151
| `modelname` | Detector model type | current only `yolov11n`
152-
| `modelpath` | Path to model weights | `/path/to/yolo11n_0816_v1.pt` | - |
152+
| `modelpath` | Path to customodel weights | `model.pt` | - |
153153
| `timedownfactor` | Time downsampling factor | `8` | `1` |
154154

155155
</details>
@@ -238,7 +238,7 @@ maskdir: /path/to/maskdir
238238

239239
</details>
240240

241-
<details>
241+
<!-- <details>
242242
<summary>Data Processing</summary>
243243

244244
| Option | Description | Example |
@@ -252,7 +252,7 @@ preprocess:
252252
- guassion: 1 5 # Gaussian filter parameters
253253
```
254254

255-
</details>
255+
</details> -->
256256

257257
<details>
258258
<summary>Plotting Configuration</summary>
@@ -300,7 +300,8 @@ cputhread: 32
300300
301301
snrhold: 5
302302
modelname: yolov11n
303-
modelpath: yolo11n_0816_v1.pt
303+
# use default
304+
# modelpath: yolo11n_0816_v1.pt
304305
305306
rfi: ai
306307
maskfile: file.bad_chans
@@ -349,7 +350,8 @@ detgpu: 1
349350
cputhread: 64
350351
351352
modelname: yolov11n
352-
modepath: yolo11n_0816_v1.pt
353+
# use default
354+
# modelpath: yolo11n_0816_v1.pt
353355
plotworker: 16
354356
355357
rfi: ai

docs/README_zh-CN.md

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@
141141
| `confidence` | 检测置信度阈值 | `0.4` | `0.372` |
142142
| `snrhold` | 信噪比阈值 | `5` | `5` |
143143
| `modelname` | 检测器模型类型 | 目前仅支持 `yolov11n` |
144-
| `modelpath` | 模型权重路径 | `/path/to/yolo11n_0816_v1.pt` | - |
144+
| `modelpath` | 自定义模型权重路径 | `/path/to/yolo11n_0816_v1.pt` | - |
145145
| `timedownfactor` | 时间降采样因子 | `8` | `1` |
146146

147147
</details>
@@ -230,7 +230,7 @@ maskdir: /path/to/maskdir
230230

231231
</details>
232232

233-
<details>
233+
<!-- <details>
234234
<summary>数据处理</summary>
235235

236236
| 选项 | 描述 | 示例 |
@@ -244,7 +244,7 @@ preprocess:
244244
- guassion: 1 5 # 高斯滤波参数
245245
```
246246

247-
</details>
247+
</details> -->
248248

249249
<details>
250250
<summary>绘图配置</summary>
@@ -291,7 +291,8 @@ cputhread: 32
291291
292292
snrhold: 5
293293
modelname: yolov11n
294-
modelpath: yolo11n_0816_v1.pt
294+
# use default
295+
# modelpath: yolo11n_0816_v1.pt
295296
296297
rfi: ai
297298
maskfile: file.bad_chans
@@ -340,7 +341,8 @@ detgpu: 1
340341
cputhread: 64
341342
342343
modelname: yolov11n
343-
modepath: yolo11n_0816_v1.pt
344+
# use default
345+
# modelpath: yolo11n_0816_v1.pt
344346
plotworker: 16
345347
346348
rfi: ai

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ build-backend = "scikit_build_core.build"
1212

1313
[project]
1414
name = "pulseflow"
15-
version = "0.1.0"
15+
version = "0.1.1"
1616
description = "High-performance radio astronomy single pulse Detect lib"
1717
requires-python = ">=3.10"
1818
readme = "README.md"
@@ -44,6 +44,7 @@ dev = ["pytest>=6.0", "pytest-cov", "black", "flake8"]
4444

4545
[project.scripts]
4646
astroflow = "astroflow.cli:main"
47+
astrorfimask = "astroflow.rfimask:astrorfimask"
4748

4849
[tool.setuptools]
4950
include-package-data = true

python/astroflow/cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,4 @@ def main():
132132
print(f"Error: {e}")
133133
exit(1)
134134
exit(0)
135+

python/astroflow/config/taskconfig.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import os
22
from sympy import preorder_traversal
33
import yaml
4+
import urllib.request
5+
6+
47

58
CENTERNET = 0
69
YOLOV11N = 1
710
DETECTNET = 2
811
COMBINENET = 3
912

13+
1014
class TaskConfig:
1115
_instance = None
1216

@@ -81,15 +85,34 @@ def _checker_freqrange(self, freqrange):
8185
else:
8286
raise ValueError("Invalid format for freqrange in config file.")
8387

84-
def _checker_preprocess(self, preprocess):
85-
if isinstance(preprocess, list):
86-
for item in preprocess:
87-
if not isinstance(item, dict) or len(item) != 1:
88-
raise ValueError(
89-
"Invalid format for preprocess in config file. Each item should be a single key-value pair."
90-
)
91-
else:
92-
raise ValueError("Invalid format for preprocess in config file.")
88+
# def _checker_preprocess(self, preprocess):
89+
# if isinstance(preprocess, list):
90+
# for item in preprocess:
91+
# if not isinstance(item, dict) or len(item) != 1:
92+
# raise ValueError(
93+
# "Invalid format for preprocess in config file. Each item should be a single key-value pair."
94+
# )
95+
# else:
96+
# raise ValueError("Invalid format for preprocess in config file.")
97+
98+
99+
100+
def get_model(self):
101+
model_url_path = "https://github.com/lintian233/astroflow/releases/download/v0.1.1/yolo11n_0816_v1.pt"
102+
config_dir = "~/.config/astroflow"
103+
config_dir = os.path.expanduser(config_dir)
104+
os.makedirs(config_dir, exist_ok=True)
105+
106+
model_filename = "yolov1n_0816_v1.pt"
107+
local_model_path = os.path.join(config_dir, model_filename)
108+
109+
if not os.path.exists(local_model_path):
110+
print(f"Downloading model from {model_url_path}...")
111+
urllib.request.urlretrieve(model_url_path, local_model_path)
112+
print(f"Model downloaded to {local_model_path}")
113+
114+
return local_model_path
115+
93116

94117
def _checker_dm_limt(self, dm_limt):
95118
if isinstance(dm_limt, list):
@@ -177,18 +200,16 @@ def candpath(self):
177200

178201
@property
179202
def modelpath(self):
180-
model_url_path = "" # yolo11n_pulsedetect.pt github release path
181203
# https get
182204
modelpath = self._config_data.get("modelpath")
183205
if modelpath is None:
184-
raise ValueError("modelpath not found in config file.")
206+
modelpath = self.get_model()
185207
if not isinstance(modelpath, str):
186208
raise ValueError("modelpath must be a string.")
187209
if not os.path.exists(modelpath):
188210
raise FileNotFoundError(f"Model path {modelpath} does not exist.")
189211
return modelpath
190212

191-
192213
@property
193214
def dedgpu(self):
194215
return self._config_data.get("dedgpu", 0)

python/astroflow/dataset/generate.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -182,16 +182,16 @@ def _create_detector_and_plotter(task_config: TaskConfig) -> Tuple[Union[CenterN
182182

183183
if task_config.modelname == CENTERNET:
184184
detector = CenterNetFrbDetector(
185-
task_config.dm_limt, task_config.preprocess, task_config.confidence
185+
task_config.dm_limt, None, task_config.confidence
186186
)
187187
elif task_config.modelname == YOLOV11N:
188188
detector = Yolo11nFrbDetector(
189-
task_config.dm_limt, task_config.preprocess, task_config.confidence
189+
task_config.dm_limt, None, task_config.confidence
190190
)
191191
else:
192192
logger.warning(f"Unknown model name {task_config.modelname}, using CenterNet")
193193
detector = CenterNetFrbDetector(
194-
task_config.dm_limt, task_config.preprocess, task_config.confidence
194+
task_config.dm_limt, None, task_config.confidence
195195
)
196196

197197
return detector, plotter
@@ -229,14 +229,14 @@ def _process_single_file(
229229
base_dir += f"-{config.freq_start}MHz-{config.freq_end}MHz"
230230
base_dir += f"-{config.dm_step}DM-{config.t_sample}s"
231231

232-
cached_dir = os.path.join(output_dir, "cached").lower()
233-
file_dir = os.path.join(cached_dir, base_dir, file_basename).lower()
232+
cached_dir = os.path.join(output_dir, "cached")
233+
file_dir = os.path.join(cached_dir, base_dir, file_basename)
234234

235235

236236
if os.path.exists(file_dir):
237237
logger.info(f"Skipping already processed file: {file_basename}")
238238
# 检查 candidate_detect_dir 目录下是否有文件,如果没有则 detection_flag = 0
239-
candidate_detect_dir = os.path.join(output_dir, "candidate", file_basename).lower()
239+
candidate_detect_dir = os.path.join(output_dir, "candidate", file_basename)
240240
if any(os.scandir(candidate_detect_dir)):
241241
file_detected = True
242242
continue

python/astroflow/plotter.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -992,9 +992,25 @@ def plot_candidate(
992992
spec_data, noise_range=None, threshold_sigma=5, toa_sample_idx=int((peak_time - spec_tstart) / header.tsamp)
993993
)
994994

995+
# Check SNR against threshold
996+
snrhold = taskconfig.snrhold
997+
if snr < snrhold:
998+
print(f"Warning: SNR {snr:.2f} is below threshold {snrhold}. Skipping spectrum plot.")
999+
plt.close('all')
1000+
if origin_data is not None:
1001+
if hasattr(origin_data, "close"):
1002+
try:
1003+
origin_data.close()
1004+
except Exception:
1005+
pass
1006+
del origin_data
1007+
1008+
del spectrum, spec_data, initial_spectrum, initial_spec_data
1009+
gc.collect()
1010+
return
1011+
9951012
peak_time = spec_tstart + (peak_idx + 0.5) * header.tsamp
9961013
pulse_width_ms = pulse_width * header.tsamp * 1e3 if pulse_width > 0 else -1 # Convert to milliseconds
997-
9981014

9991015
# Create time and frequency axes
10001016
spec_time_axis = np.linspace(spec_tstart, spec_tend, spectrum.ntimes)

0 commit comments

Comments
 (0)