Skip to content

Commit 3fe24ff

Browse files
Fix output transform, add test to enforce tokenizer consistency (#73)
*Description of changes:* The bin indexes were shifted by one between the input transform and the output transform. Subtracting 1 from the sampled tokens in the output transform leads to the correct reconstruction of the signal. A test is added to ensure the consistency of the Chronos tokenizer. By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. Co-authored-by: Lorenzo Stella <stellalo@amazon.com> and Abdul Fatir Ansari <ansarnd@amazon.com>
1 parent 02d1a1d commit 3fe24ff

File tree

2 files changed

+40
-2
lines changed

2 files changed

+40
-2
lines changed

src/chronos/chronos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def output_transform(
185185
) -> torch.Tensor:
186186
scale_unsqueezed = scale.unsqueeze(-1).unsqueeze(-1)
187187
indices = torch.clamp(
188-
samples - self.config.n_special_tokens,
188+
samples - self.config.n_special_tokens - 1,
189189
min=0,
190190
max=len(self.centers) - 1,
191191
)

test/test_chronos.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,45 @@
77
import torch
88
import pytest
99

10-
from chronos import ChronosConfig, ChronosPipeline
10+
from chronos import ChronosConfig, ChronosPipeline, MeanScaleUniformBins
11+
12+
13+
@pytest.mark.parametrize("n_numerical_tokens", [5, 10, 27])
@pytest.mark.parametrize("n_special_tokens", [2, 5, 13])
def test_tokenizer_consistency(n_numerical_tokens: int, n_special_tokens: int):
    """Round-trip check for the MeanScaleUniformBins tokenizer.

    Feeding the tokenizer its own bin centers through ``input_transform``
    and decoding with ``output_transform`` (with scaling disabled) must
    reproduce the centers exactly — this guards against the off-by-one
    bin-index shift fixed in this commit.
    """
    vocab_size = n_numerical_tokens + n_special_tokens

    cfg = ChronosConfig(
        tokenizer_class="MeanScaleUniformBins",
        tokenizer_kwargs=dict(low_limit=-1.0, high_limit=1.0),
        n_tokens=vocab_size,
        n_special_tokens=n_special_tokens,
        pad_token_id=0,
        eos_token_id=1,
        use_eos_token=True,
        model_type="seq2seq",
        context_length=512,
        prediction_length=64,
        num_samples=20,
        temperature=1.0,
        top_k=50,
        top_p=1.0,
    )

    tok = cfg.create_tokenizer()
    assert isinstance(tok, MeanScaleUniformBins)

    # Use the bin centers themselves as the input series (batched), and pin
    # the scale to one so that mean-scaling becomes a no-op.
    series = tok.centers.unsqueeze(0)
    unit_scale = torch.ones((1,))

    ids, _, _ = tok.input_transform(series, scale=unit_scale)

    # Strip the trailing EOS token and insert a sample dimension, then decode.
    decoded = tok.output_transform(
        ids[:, :-1].unsqueeze(1),
        scale=unit_scale,
    )

    assert (decoded[0, 0, :] == series).all()
1149

1250

1351
@pytest.mark.xfail

0 commit comments

Comments
 (0)