1-
21#ifndef _RFIMARKER_H
32#define _RFIMARKER_H
43
54#include < vector>
65#include < string>
7- #include < iostream>
8- #include < fstream>
9- #include < sstream>
10- #include < omp.h>
11-
12-
6+ #include < cuda_runtime.h>
7+ #include < cstdint>
8+
9+ /* *
10+ * GPU RFI marker:
11+ * - load_mask(): 在 CPU 读入坏道列表,并上传到 GPU
12+ * - mark_rfi(): 传入 device 指针 (d_data),在 kernel 中把坏道通道清零
13+ *
14+ * 数据布局假设为: data[sample * num_channels + chan] (row-major: time-major)
15+ */
1316template <typename T>
1417class RfiMarker {
1518public:
1619 RfiMarker ();
17- RfiMarker (const char * mask_file); // Constructor that takes a mask file
18- RfiMarker (std::string mask_file) : RfiMarker(mask_file.c_str()) {} // Constructor that takes a string
19- ~RfiMarker () = default ;
20+ explicit RfiMarker (const char * mask_file);
21+ explicit RfiMarker (const std::string& mask_file) : RfiMarker(mask_file.c_str()) {}
22+ ~RfiMarker ();
2023
21- std::vector<int > bad_channels; // Vector to store bad channels
24+ // 将坏道置零(在 GPU 上执行)。d_data 必须是 device 指针
25+ void mark_rfi (T* d_data,
26+ unsigned int num_channels,
27+ unsigned int num_samples,
28+ cudaStream_t stream = 0 );
2229
23- // Method to load the RFI mask
30+ // 重新加载掩码文件(会同步上传到 GPU);文件不存在或为空则视为无坏道
2431 void load_mask (const char * mask_file);
2532
26- void mark_rfi (T* data, uint num_channels, uint num_samples);
27-
33+ // Host 侧只读坏道列表
34+ const std::vector< int >& get_bad_channels () const { return bad_channels_; }
2835
29- const std::vector<int >& get_bad_channels () const {
30- return bad_channels;
31- }
32- };
36+ private:
37+ void upload_bad_channels_to_device ();
3338
39+ std::vector<int > bad_channels_; // host: 坏道索引
40+ int * d_bad_channels_ = nullptr ; // device: 坏道索引
41+ size_t n_bad_ = 0 ;
42+ };
3443
44+ // ------------ 显式实例化声明(由 .cu 文件提供定义) ------------
45+ extern template class RfiMarker <uint8_t >;
46+ extern template class RfiMarker <uint16_t >;
47+ extern template class RfiMarker <uint32_t >;
3548
36- template <typename T>
37- RfiMarker<T>::RfiMarker() {
38- load_mask (" mask.txt" ); // Default mask file
39- }
40-
41- template <typename T>
42- RfiMarker<T>::RfiMarker(const char * mask_file) {
43- load_mask (mask_file); // Load the mask from the provided file
44- }
45-
46- template <typename T>
47- void RfiMarker<T>::mark_rfi(T* data, uint num_channels, uint num_samples) {
48- // Iterate through the bad channels and mark them in the data
49- #pragma omp parallel for
50- for (int chan : bad_channels) {
51- if (chan >= 0 && chan < num_channels) {
52- #pragma omp simd
53- for (uint sample = 0 ; sample < num_samples; ++sample) {
54- // Set the data for the bad channel to zero
55- data[sample * num_channels + chan] = 0 ;
56- }
57- } else {
58- std::cerr << " Warning: Bad channel index " << chan << " out of range." << std::endl;
59- }
60- }
61- std::cout << " RFI marking completed. Bad channels: " << bad_channels.size () << std::endl;
62- }
63-
64- template <typename T>
65- void RfiMarker<T>::load_mask(const char * mask_file) {
66- // open the mask file
67- std::ifstream file (mask_file);
68- if (!file.is_open ()) {
69- std::cerr << " Error opening mask file: " << mask_file << std::endl
70- << " Please check the file path and try again." << std::endl;
71- return ;
72- }
73- // if tempty continue
74- if (file.peek () == std::ifstream::traits_type::eof ()) {
75- return ;
76- }
77-
78- std::string line;
79- while (std::getline (file, line)) {
80- std::stringstream ss (line);
81- int chan;
82- while (ss >> chan) {
83- bad_channels.push_back (chan);
84- }
85- }
86- file.close ();
87- }
88- #endif // _RFIMARKER_H
49+ #endif // _RFIMARKER_H
0 commit comments