Skip to content

Commit c3eb82e

Browse files
Merge pull request #161 from allenai/add-no-grad
Add torch.no_grad, fix greedy_until bug
2 parents cae37ad + 6f0389c commit c3eb82e

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## Unreleased
99

10+
### Fixed
11+
12+
- Added `torch.no_grad()` around model calls in `language_model.py`
13+
- Prevented crashes by using a more robust stop token for `greedy_until` in `language_model.py`
14+
1015
## [v1.0.0rc0](https://github.com/allenai/catwalk/releases/tag/v1.0.0rc0) - 2023-12-19
1116

1217
### Added

catwalk/models/language_model.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,8 @@ def _run_loglikelihood_tokens(
567567
for field_name, tensors in unpadded_batch.items()
568568
}
569569

570-
batch_logits = log_softmax(model(**padded_batch)[0], dim=-1)
570+
with torch.no_grad():
571+
batch_logits = log_softmax(model(**padded_batch)[0], dim=-1)
571572
z = zip(
572573
batch_of_indices,
573574
batch_logits,
@@ -642,8 +643,8 @@ def _run_greedy_until(
642643
if isinstance(untils, str):
643644
untils = [untils]
644645
# if any of the stop phrases are single tokens we can use that for early termination
645-
primary_until = None
646-
for tokenized_until in tokenizer(untils)["input_ids"]:
646+
primary_until = tokenizer.eos_token_id
647+
for tokenized_until in tokenizer(untils, add_special_tokens=False)["input_ids"]:
647648
if len(tokenized_until) == 1:
648649
primary_until = tokenized_until[0]
649650

@@ -652,13 +653,14 @@ def _run_greedy_until(
652653
[tokenized_context[max_gen_toks - model_max_length :]]
653654
).to(model.device)
654655

655-
full_text_tensor = model.generate(
656-
context_tensor,
657-
max_length=context_tensor.shape[1] + max_gen_toks,
658-
eos_token_id=primary_until,
659-
do_sample=False,
660-
pad_token_id=primary_until, # temporary hack to suppress irrelevant warning until batch processing is added
661-
)
656+
with torch.no_grad():
657+
full_text_tensor = model.generate(
658+
context_tensor,
659+
max_length=context_tensor.shape[1] + max_gen_toks,
660+
eos_token_id=primary_until,
661+
do_sample=False,
662+
pad_token_id=primary_until, # temporary hack to suppress irrelevant warning until batch processing is added
663+
)
662664
continuation_tensor = full_text_tensor[0, context_tensor.shape[1] :]
663665
continuation = tokenizer.decode(continuation_tensor.tolist())
664666
raw_continuation = continuation

0 commit comments

Comments
 (0)