Skip to content

Commit 6106799

Browse files
committed
update data_loader by adding self.args.rag
1 parent 521228e commit 6106799

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

ecg_bench/scripts/org_results.sh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ data=("ecg_instruct_45k_mapped_1250")
99
# normalized_rag_features=True
1010

1111
checkpoints=(
12-
'llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_combined_combined_1_system_prompt_True_False'
12+
'llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_feature_combined_1_system_prompt_True_False'
13+
'llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_feature_combined_1_system_prompt_None_False'
14+
'llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_signal_combined_1_system_prompt_True_False'
15+
'llama-3.2-3b-instruct_2_1_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_True_signal_combined_1_system_prompt_None_False'
1316
)
1417

1518
for d in "${data[@]}"; do

ecg_bench/utils/data_loader_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def setup_conversation_template(self, signal = None):
107107
elif 'gemma' in self.args.model:
108108
conv = get_conv_template('gemma')
109109
feature=None
110-
if self.args.retrieval_base in ['feature', 'combined']:
110+
if self.args.rag and self.args.retrieval_base in ['feature', 'combined']:
111111
original_feature=self.rag_db.feature_extractor.extract_rag_features(signal)
112112
feature=original_feature
113113
if self.args.normalized_rag_feature:
@@ -146,7 +146,7 @@ def append_messages_to_conv(self, conv, altered_text, signal=None):
146146
message_value = message_value.replace('<image>', '')
147147
message_value = message_value.replace('<ecg>', '')
148148
message_value = message_value.replace('image', 'signal').replace('Image', 'Signal')
149-
if self.args.retrieval_base in ['feature', 'combined'] or self.args.retrieved_information in ['feature','combined']:
149+
if self.args.rag and (self.args.retrieval_base in ['feature', 'combined'] or self.args.retrieved_information in ['feature','combined']):
150150
original_feature=self.rag_db.feature_extractor.extract_rag_features(signal)
151151
feature=original_feature
152152
if self.args.normalized_rag_feature:

0 commit comments

Comments
 (0)