
Commit 3f483df

Qualcomm AI Engine Direct - Fix broken unpacking in T5 dataset loading (#13625)
### Summary
Fix broken unpacking in T5 dataset loading.

### Test plan
```bash
python backends/qualcomm/tests/test_qnn_delegate.py TestExampleOssScript.test_t5 -s ${device_id} -m ${soc} --build_folder build-android/ --executorch_root . --artifact_dir . --qa_dataset ${path_to_SQuAD-v1.1.csv}
```
1 parent b743cc1 commit 3f483df
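
For context, a minimal sketch of the failure mode this commit addresses, assuming a helper that returns one more value than its callers unpack. The function names below are hypothetical illustrations, not the actual ExecuTorch helpers:

```python
# Hypothetical illustration of the unpacking mismatch; names are made up.

def load_dataset_old():
    # Old-style helper: returns three values, including a pre-built input_list string.
    inputs, targets = [[1, 2]], [[3, 4]]
    input_list = "input_0_0.raw input_0_1.raw input_0_2.raw\n"
    return inputs, targets, input_list


def load_dataset_new():
    # Fixed helper: returns only the two values callers actually unpack.
    inputs, targets = [[1, 2]], [[3, 4]]
    return inputs, targets


# A two-value unpack against the old three-value return fails at the call site:
try:
    inputs, targets = load_dataset_old()
except ValueError as err:
    print(f"broken unpacking: {err}")  # too many values to unpack (expected 2)

# The same unpack against the fixed helper works:
inputs, targets = load_dataset_new()
```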

2 files changed: +11 -6 lines


backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -5710,7 +5710,7 @@ def test_t5(self):
             "python",
             f"{self.executorch_root}/examples/qualcomm/oss_scripts/t5/t5.py",
             "--dataset",
-            self.sentence_dataset,
+            self.qa_dataset,
             "--artifact",
             self.artifact_dir,
             "--build_folder",
@@ -6577,6 +6577,11 @@ def setup_environment():
         help="Location for imagenet dataset",
         type=str,
     )
+    parser.add_argument(
+        "--qa_dataset",
+        help="Location for QA dataset",
+        type=str,
+    )
     parser.add_argument(
         "--sentence_dataset",
         help="Location for sentence dataset",
@@ -6640,6 +6645,7 @@ def setup_environment():
     TestQNN.executorch_root = args.executorch_root
     TestQNN.artifact_dir = args.artifact_dir
     TestQNN.image_dataset = args.image_dataset
+    TestQNN.qa_dataset = args.qa_dataset
     TestQNN.sentence_dataset = args.sentence_dataset
     TestQNN.pretrained_weight = args.pretrained_weight
     TestQNN.model_name = args.model_name
```

examples/qualcomm/utils.py

Lines changed: 4 additions & 5 deletions
```diff
@@ -637,7 +637,7 @@ def __len__(self):
     # prepare input data
     inputs, targets = [], []
     data_loader = get_data_loader()
-    for _, data in enumerate(data_loader):
+    for data in data_loader:
         if len(inputs) >= data_size:
             break
         input_ids = data[0]
@@ -729,9 +729,9 @@ def __getitem__(self, idx):
         dataset, batch_size=1, shuffle=shuffle, collate_fn=collator
     )

-    inputs, targets, input_list = [], [], ""
+    inputs, targets = [], []
     data_loader = get_data_loader(max_hidden_seq_length)
-    for idx, batch in enumerate(data_loader):
+    for batch in data_loader:
         if len(inputs) >= data_size:
             break
         input_ids = batch["input_ids"]
@@ -750,9 +750,8 @@ def __getitem__(self, idx):
             )
         )
         targets.append(labels)
-        input_list += f"input_{idx}_0.raw input_{idx}_1.raw input_{idx}_2.raw\n"

-    return inputs, targets, input_list
+    return inputs, targets


 def setup_common_args_and_variables():
```
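
With `input_list` no longer built or returned by the dataset helper, a caller that still needs those `.raw` file names would presumably reconstruct them from the returned inputs. A rough sketch of that, assuming the naming pattern from the removed line (an illustration, not code from this commit):

```python
# Hypothetical caller-side reconstruction of input_list after this change;
# the "input_{idx}_{n}.raw" pattern mirrors the removed line but is an assumption here.
inputs = [("ids_0", "mask_0", "dec_0"), ("ids_1", "mask_1", "dec_1")]  # placeholder tensors

input_list = ""
for idx, _ in enumerate(inputs):
    input_list += f"input_{idx}_0.raw input_{idx}_1.raw input_{idx}_2.raw\n"

print(input_list, end="")
```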
