Skip to content

Commit 868895e

Browse files
committed
1. Reduce the number of selected features. 2. Better feature format. 3. Add the normalized_rag_feature arg when creating/loading the RAG index DB
1 parent d78e83c commit 868895e

File tree

4 files changed

+273
-56
lines changed

4 files changed

+273
-56
lines changed

ecg_bench/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def get_args():
6060
mode_group.add_argument('--retrieved_information', type=str, default='combined', choices=['feature', 'report', 'combined'], help='Type of information to retrieve in output')
6161
mode_group.add_argument('--load_rag_db', type = str, default = None, help = 'Load a RAG database')
6262
mode_group.add_argument('--load_rag_db_idx', type = str, default = None, help = 'Load a RAG database index')
63+
mode_group.add_argument('--normalized_rag_feature', action='store_true', default=True, help='Enable normalization for RAG features and signals')
6364
mode_group.add_argument('--dev', action='store_true', default=None, help='Development mode')
6465
mode_group.add_argument('--log', action='store_true', default=None, help='Enable logging')
6566

ecg_bench/utils/data_loader_utils.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,11 @@ def setup_conversation_template(self, signal = None):
110110
if self.args.retrieval_base in ['feature', 'combined']:
111111
if self.args.dev:
112112
print("🔍 DEBUG: Extracting features")
113-
feature=self.rag_db.feature_extractor.extract_features(signal)
113+
original_feature=self.rag_db.feature_extractor.extract_rag_features(signal)
114+
feature=original_feature
115+
if self.args.normalized_rag_feature:
116+
feature = self.rag_db.query_feature_normalization(original_feature)
117+
signal = self.rag_db.query_signal_lead_normalization(signal)
114118
if self.args.dev:
115119
print("🔍 DEBUG: Features extracted, shape: ", feature.shape)
116120

@@ -161,13 +165,17 @@ def append_messages_to_conv(self, conv, altered_text, signal=None):
161165
message_value = message_value.replace('<ecg>', '')
162166
message_value = message_value.replace('image', 'signal').replace('Image', 'Signal')
163167
if self.args.retrieval_base in ['feature', 'combined'] or self.args.retrieved_information in ['feature','combined']:
164-
feature=self.rag_db.feature_extractor.extract_features(signal)
168+
original_feature=self.rag_db.feature_extractor.extract_rag_features(signal)
169+
feature=original_feature
170+
if self.args.normalized_rag_feature:
171+
feature = self.rag_db.query_feature_normalization(original_feature)
172+
signal = self.rag_db.query_signal_lead_normalization(signal)
165173
if is_human and count == 0:
166174
if self.args.rag and self.args.rag_prompt_mode == 'user_query':
167175
rag_results = self.rag_db.search_similar(query_features=feature, query_signal=signal, k=self.args.rag_k, mode=self.args.retrieval_base)
168176
filtered_rag_results = self.rag_db.format_search(rag_results,self.args.retrieved_information)
169177
if self.args.retrieved_information == 'combined':
170-
message_value = f"<signal>\nFeature Information:\n{feature}\n\n{filtered_rag_results}\n{message_value}"
178+
message_value = f"<signal>\nFeature Information:\n{self.rag_db.convert_features_to_structured(original_feature)}\n\n{filtered_rag_results}\n{message_value}"
171179
elif self.args.retrieved_information == 'report':
172180
message_value = f"<signal>\n{filtered_rag_results}\n{message_value}"
173181
else:

ecg_bench/utils/preprocess_utils.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,6 +1047,61 @@ def extract_features(self, ecg):
10471047

10481048
return np.array(features)
10491049

1050+
def extract_rag_features(self, ecg):
    """
    Extract the reduced per-lead feature vector used for RAG applications.

    Iterates over the first axis of ``ecg`` (one row per lead) and appends,
    for every lead and in exactly this order:

    1. max
    2. min
    3. total_power (sum of the Welch PSD)
    4. peak_frequency_power (largest PSD value)
    5. dominant_frequency (frequency at the largest PSD value)
    6. spectral_centroid (0 when the total power is not positive)
    7. Heart Rate Features (from ``self._calculate_heart_rate_features``)
    8. Wavelet Features (from ``self._calculate_wavelet_features``)
    9. average_absolute_difference
    10. root_mean_square_difference

    The order matters: consumers that address features by position rely on it.

    Args:
        ecg: 2-D array indexed as ``ecg[lead, sample]``.

    Returns:
        1-D ``np.ndarray`` with the features of all leads concatenated
        lead by lead.
    """
    features = []

    for lead in range(ecg.shape[0]):
        lead_signal = ecg[lead, :]

        # Basic statistical features (only max and min); computed once and
        # reused for the peak threshold below.
        lead_max = np.max(lead_signal)
        lead_min = np.min(lead_signal)
        features.extend([lead_max, lead_min])

        # Frequency domain features
        freqs, psd = signal.welch(lead_signal, fs=self.target_sf, nperseg=min(1024, len(lead_signal)))
        total_power = np.sum(psd)
        features.extend([
            total_power,            # Total power
            np.max(psd),            # Peak frequency power
            freqs[np.argmax(psd)],  # Dominant frequency
        ])

        # Spectral centroid; guard against dividing by a zero total power.
        if total_power > 0:
            spectral_centroid = np.sum(freqs * psd) / total_power
        else:
            spectral_centroid = 0
        features.append(spectral_centroid)

        # Peak detection: threshold at 30% of the lead's amplitude range above
        # its minimum. A flat lead has no meaningful threshold, so it gets no peaks.
        if lead_max != lead_min:
            peak_height = 0.3 * (lead_max - lead_min) + lead_min
            min_distance = max(int(0.2 * self.target_sf), 1)  # Ensure positive distance
            peaks, _ = signal.find_peaks(lead_signal, height=peak_height, distance=min_distance)
        else:
            peaks = []

        # Heart rate features
        features.extend(self._calculate_heart_rate_features(lead_signal, peaks))

        # Wavelet features
        features.extend(self._calculate_wavelet_features(lead_signal))

        # Non-linear features; the successive differences are computed once.
        successive_diff = np.diff(lead_signal)
        features.append(np.mean(np.abs(successive_diff)))            # Average absolute difference
        features.append(np.sqrt(np.mean(np.square(successive_diff))))  # Root mean square of successive differences

    return np.array(features)
1104+
10501105
def _calculate_heart_rate_features(self, ecg, peaks):
10511106
if len(peaks) > 1:
10521107
# Heart rate
@@ -1118,4 +1173,58 @@ def find_st_deviation(self, ecg, peaks):
11181173
return ecg[st_point] - ecg[peaks[-1]]
11191174
return 0
11201175

1176+
def signal_lead_normalization(ecg):
    """
    Normalize each lead individually using z-score normalization.

    The lead axis is detected from the shape: when ``ecg.shape[0] == 12`` the
    input is treated as ``(leads, samples)``; otherwise it is treated as
    ``(samples, leads)``. A 12x12 input is therefore read as leads-first.
    Every lead is normalized, whatever the number of leads, and the output
    keeps the orientation of the input.

    NOTE(review): this function takes no ``self``. If it sits inside a class
    body it can only be called through the class (or needs ``@staticmethod``)
    -- confirm against the surrounding file.

    Args:
        ecg: 2-D array-like of ECG samples.

    Returns:
        ``np.ndarray`` of dtype ``float32`` with the same shape as ``ecg``,
        where every lead has (approximately) zero mean and unit variance.
        A constant lead maps to zeros thanks to the ``1e-10`` added to the
        standard deviation.

    Raises:
        ValueError: if ``ecg`` is not 2-dimensional.
    """
    ecg = np.asarray(ecg)
    if ecg.ndim != 2:
        raise ValueError(f"Expected 2D array, got shape {ecg.shape}")

    # Statistics are taken along the sample (time) axis of each lead.
    time_axis = 1 if ecg.shape[0] == 12 else 0

    lead_mean = ecg.mean(axis=time_axis, keepdims=True)
    # Small epsilon avoids division by zero for flat leads.
    lead_std = ecg.std(axis=time_axis, keepdims=True) + 1e-10

    return ((ecg - lead_mean) / lead_std).astype(np.float32)
1198+
1199+
def feature_normalization(self, rag_features):
    """
    Normalize RAG features using z-score normalization across the 12 leads.

    The input is read as a lead-major flat vector: the value of feature ``j``
    for lead ``i`` sits at index ``i * len(self.ecg_feature_list) + j``. Each
    feature is standardized with the mean and standard deviation of its 12
    per-lead values from this single vector, not with dataset-wide statistics.

    Args:
        rag_features: 1-D array-like of length
            ``12 * len(self.ecg_feature_list)``.

    Returns:
        1-D ``np.ndarray`` of dtype ``float32`` with the same layout as the
        input. A feature that is identical on all leads maps to zeros thanks
        to the ``1e-10`` added to the standard deviation.

    Raises:
        ValueError: if the input is not 1-D or has an unexpected length.
    """
    num_leads = 12
    features_per_lead = len(self.ecg_feature_list)
    expected_total_features = num_leads * features_per_lead

    rag_features = np.asarray(rag_features)

    if rag_features.ndim != 1:
        raise ValueError(f"Expected 1D array, got shape {rag_features.shape}")

    if len(rag_features) != expected_total_features:
        raise ValueError(f"Expected {expected_total_features} features for 12-lead ECG, got {len(rag_features)}")

    # Rows are leads, columns are features, so column statistics are the
    # per-feature statistics across leads.
    per_lead = rag_features.reshape(num_leads, features_per_lead)
    feature_mean = per_lead.mean(axis=0)
    # Small epsilon avoids division by zero for features constant across leads.
    feature_std = per_lead.std(axis=0) + 1e-10

    normalized = (per_lead - feature_mean) / feature_std
    return normalized.astype(np.float32).reshape(-1)
11211230

0 commit comments

Comments
 (0)