Skip to content

Commit 6e1dedf

Browse files
committed
Fix custom ref energies shape
1 parent d5afa35 commit 6e1dedf

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

finetune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ def run(args):
490490
custom_refs = custom_refs.to(device)
491491

492492
# Set the custom reference energies
493-
model.heads["energy"].reference.linear.weight.data = custom_refs.unsqueeze(0)
493+
model.heads["energy"].reference.linear.weight.data = custom_refs
494494

495495
# Log some values for verification
496496
logging.info("Custom reference energies set:")

0 commit comments

Comments
 (0)