Skip to content

Commit 6b8b683

Browse files
authored
support dict batch in dynabert (PaddlePaddle#1064)
1 parent bd4ecf0 commit 6b8b683

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

paddleslim/nas/ofa/utils/nlp_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,11 @@ def compute_neuron_head_importance(task_name,
7070
data_loader = (data_loader, )
7171
for data in data_loader:
7272
for batch in data:
73-
input_ids, segment_ids, labels = batch
73+
if isinstance(batch, dict):
74+
input_ids, segment_ids, labels = batch['input_ids'], batch[
75+
'token_type_ids'], batch['labels']
76+
else:
77+
input_ids, segment_ids, labels = batch
7478
logits = model(
7579
input_ids, segment_ids, attention_mask=[None, head_mask])
7680
loss = loss_fct(logits, labels)

0 commit comments

Comments
 (0)