
Commit b7ce9d9

update data loader with old mapping names
1 parent 3d1de89 commit b7ce9d9

3 files changed: +32 −26 lines

ecg_bench/scripts/train_1st.sh

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 #!/bin/bash
 
 # models=("stmem" "merl" "mlae" "mtae" "siglip" "clip" "vit")
-models=("merl")
+models=("siglip")
 # data=("ecg-qa-mimic-iv-ecg-250-1250")
 # data=("ecg_instruct_45k_mapped_1250")
 
@@ -10,7 +10,7 @@ for model in "${models[@]}"; do
 python main.py \
 --data=ecg-qa_mimic-iv-ecg_mapped_1250 \
 --model=$model \
---device=cuda:0 \
+--device=cuda:4 \
 --train=first \
 --batch_size=64 \
 --seg_len=1250 \

ecg_bench/scripts/train_2nd.sh

Lines changed: 24 additions & 21 deletions
@@ -1,9 +1,12 @@
 #!/usr/bin/env bash
 # ------------------- CONFIGURABLE LISTS -------------------
-encoders=("stmem" "merl" "mlae" "mtae" "siglip" "clip" "vit")
-encoders_checkpoints=("stmem_256_50_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None" "merl_256_50_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None" "mlae_256_50_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None" "mtae_256_50_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None")
-llms=("gemma-2-2b-it" "llama-3.2-1b-instruct" "qwen2.5-1.5b-instruct")
-datasets=("ecg-qa_ptbxl-250-1250" "ecg-qa-mimic-iv-ecg-250-1250" "ecg-instruct-45k-250-1250" "ecg-instruct-pulse-250-1250" "pretrain-mimic-250-1250") # add more datasets here
+# encoders=("stmem" "merl" "mlae" "mtae" "siglip" "clip" "vit")
+encoders=("merl")
+encoders_checkpoints=("merl_adam_64_50_1024_0.0001_0.9_0.99_1e-08_500_0.01_True_None_None_None_None_1_None_None_False")
+# llms=("gemma-2-2b-it" "llama-3.2-1b-instruct" "qwen2.5-1.5b-instruct")
+llms=("llama-3.2-1b-instruct")
+# datasets=("ecg-qa_ptbxl-250-1250" "ecg-qa-mimic-iv-ecg-250-1250" "ecg-instruct-45k-250-1250" "ecg-instruct-pulse-250-1250" "pretrain-mimic-250-1250") # add more datasets here
+datasets=("ecg_instruct_45k_mapped_1250")
 # ----------------------------------------------------------
 
 for data in "${datasets[@]}"; do
@@ -26,7 +29,7 @@ for data in "${datasets[@]}"; do
 python main.py \
 --data="$data" \
 --model="${encoder}_${llm}" \
---device=cuda:7 \
+--device=cuda:3 \
 --train=second \
 --batch_size=2 \
 --seg_len=1250 \
@@ -37,25 +40,25 @@ for data in "${datasets[@]}"; do
 --attn_implementation=flash_attention_2 \
 --system_prompt=./data/system_prompt_e2e.txt \
 $([ -n "$checkpoint_path" ] && echo "--encoder_checkpoint=$checkpoint_path") \
---dev
+--log
 done
 done
 done
 
 
-models=("vit" "clip" "siglip" )
+# models=("merl")
 
-for model in "${models[@]}"; do
-python main.py \
---data=ecg-qa_mimic-iv-ecg_mapped_1250 \
---model=$model \
---device=cuda:6 \
---train=first \
---batch_size=8 \
---seg_len=1250 \
---epochs=2 \
---instance_normalize \
---attn_implementation=flash_attention_2 \
---image \
---log
-done
+# for model in "${models[@]}"; do
+# python main.py \
+# --data=ecg-qa_mimic-iv-ecg_mapped_1250 \
+# --model=$model \
+# --device=cuda:6 \
+# --train=first \
+# --batch_size=8 \
+# --seg_len=1250 \
+# --epochs=2 \
+# --instance_normalize \
+# --attn_implementation=flash_attention_2 \
+# --image \
+# --log
+# done
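
Both training scripts now pass the underscore-style "mapped" dataset identifiers instead of the older hyphenated names, which is why the checks in data_loader_utils.py below are updated in the same commit. A rough, illustrative sketch of the old-to-new name correspondence suggested by the commented-out lines in this diff (the pairing is an inference for illustration, not code from the repository):

# Illustrative sketch only: old hyphenated dataset identifiers (left) and the
# "mapped" identifiers this commit switches the scripts and loader to (right).
# The exact pairing is inferred from the commented-out lines in the diff.
OLD_TO_MAPPED = {
    "pretrain-mimic-250-1250": "pretrain_mimic_mapped_1250",
    "ecg-qa-mimic-iv-ecg-250-1250": "ecg-qa_mimic-iv-ecg_mapped_1250",
    "ecg-qa_ptbxl-250-1250": "ecg-qa_ptbxl_mapped_1250",
    "ecg-instruct-45k-250-1250": "ecg_instruct_45k_mapped_1250",
}

for old, new in OLD_TO_MAPPED.items():
    print(f"{old:34s} -> {new}")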

ecg_bench/utils/data_loader_utils.py

Lines changed: 6 additions & 3 deletions
@@ -81,9 +81,11 @@ def create_position_ids(self, padded_sequence):
         return position_ids
 
     def get_qa(self, altered_text):
-        if self.args.data == f"pretrain-mimic-{self.args.target_sf}-{self.args.seg_len}":
+        # if self.args.data == f"pretrain-mimic-{self.args.target_sf}-{self.args.seg_len}":
+        if self.args.data == f"pretrain_mimic_mapped_{self.args.seg_len}":
             question, answer = altered_text[0]["value"].replace("\n", "").replace("<ecg>", ""), altered_text[1]["value"]
-        elif self.args.data in [f"ecg-qa-mimic-iv-ecg-{self.args.target_sf}-{self.args.seg_len}", f"ecg-qa-ptbxl-{self.args.target_sf}-{self.args.seg_len}"]:
+        # elif self.args.data in [f"ecg-qa-mimic-iv-ecg-{self.args.target_sf}-{self.args.seg_len}", f"ecg-qa-ptbxl-{self.args.target_sf}-{self.args.seg_len}"]:
+        elif self.args.data in [f"ecg-qa_mimic-iv-ecg_mapped_{self.args.seg_len}", f"ecg-qa_ptbxl_mapped_{self.args.seg_len}"]:
             question_type, question, answer = altered_text[0], altered_text[1], altered_text[2]
             answer = " ".join(answer) if isinstance(answer, list) else answer
         return question, answer
@@ -128,7 +130,8 @@ def setup_conversation_template(self, signal = None):
         return conv
 
     def process_altered_text(self, altered_text):
-        if self.args.data not in [f"ecg-instruct-45k-{self.args.target_sf}-{self.args.seg_len}",
+        if self.args.data not in [#f"ecg-instruct-45k-{self.args.target_sf}-{self.args.seg_len}",
+                                  f"ecg_instruct_45k_mapped_{self.args.seg_len}",
                                   f"ecg-instruct-pulse-{self.args.target_sf}-{self.args.seg_len}",
                                   f"ecg-bench-pulse-{self.args.target_sf}-{self.args.seg_len}"]:
             question, answer = self.get_qa(altered_text)
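
A minimal, self-contained sketch of how the updated get_qa branching behaves with the new dataset names. The SimpleNamespace stand-in for args and the sample record values are assumptions for illustration, not code from the repository:

from types import SimpleNamespace

# Stand-in for the parsed CLI arguments; only data and seg_len are needed
# by the new checks (target_sf no longer appears in the matched names).
args = SimpleNamespace(data="ecg-qa_mimic-iv-ecg_mapped_1250", seg_len=1250)

def get_qa(altered_text):
    # Mirrors the branching in data_loader_utils.py after this commit (sketch).
    if args.data == f"pretrain_mimic_mapped_{args.seg_len}":
        # Instruction-style record: first turn holds the question, second the answer.
        question = altered_text[0]["value"].replace("\n", "").replace("<ecg>", "")
        answer = altered_text[1]["value"]
    elif args.data in [f"ecg-qa_mimic-iv-ecg_mapped_{args.seg_len}",
                       f"ecg-qa_ptbxl_mapped_{args.seg_len}"]:
        # ECG-QA record: (question_type, question, answer); answer may be a list.
        _question_type, question, answer = altered_text[0], altered_text[1], altered_text[2]
        answer = " ".join(answer) if isinstance(answer, list) else answer
    return question, answer

# Example ECG-QA-style record (hypothetical values):
print(get_qa(["single-verify", "Does this ECG show sinus rhythm?", ["yes"]]))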
