Skip to content

Commit 3c434d8

Browse files
lostella and Jasper authored
Backports for v0.14.2 (#3063)
* Fix `iterable.Cached`. (#3060) * Torch: Remove double caching of dataset. (#3061) --------- Co-authored-by: Jasper <schjaspe@amazon.de>
1 parent 536465d commit 3c434d8

File tree

3 files changed

+34
-21
lines changed

3 files changed

+34
-21
lines changed

src/gluonts/itertools.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -305,10 +305,9 @@ def split_into(xs: Sequence, n: int) -> Sequence:
305305
@dataclass
306306
class Cached:
307307
"""
308-
An iterable wrapper, which caches values in a list the first time it is
309-
iterated.
308+
An iterable wrapper, which caches values in a list while iterated.
310309
311-
The primary use-case for this is to avoid re-computing the element of the
310+
The primary use-case for this is to avoid re-computing the elements of the
312311
sequence, in case the inner iterable does it on demand.
313312
314313
This should be used to wrap deterministic iterables, i.e. iterables where
@@ -317,15 +316,21 @@ class Cached:
317316
"""
318317

319318
iterable: SizedIterable
320-
cache: list = field(default_factory=list, init=False)
319+
provider: Iterable = field(init=False)
320+
consumed: list = field(default_factory=list, init=False)
321+
322+
def __post_init__(self):
323+
# ensure we only iterate once over the iterable
324+
self.provider = iter(self.iterable)
321325

322326
def __iter__(self):
323-
if not self.cache:
324-
for element in self.iterable:
325-
yield element
326-
self.cache.append(element)
327-
else:
328-
yield from self.cache
327+
# Yield already provided values first
328+
yield from self.consumed
329+
330+
# Now yield remaining elements.
331+
for element in self.provider:
332+
self.consumed.append(element)
333+
yield element
329334

330335
def __len__(self) -> int:
331336
return len(self.iterable)

src/gluonts/torch/model/estimator.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# express or implied. See the License for the specific language governing
1212
# permissions and limitations under the License.
1313

14-
from typing import NamedTuple, Optional, Iterable, Dict, Any, Union
14+
from typing import NamedTuple, Optional, Iterable, Dict, Any
1515
import logging
1616

1717
import numpy as np
@@ -24,7 +24,7 @@
2424
from gluonts.itertools import Cached
2525
from gluonts.model import Estimator, Predictor
2626
from gluonts.torch.model.predictor import PyTorchPredictor
27-
from gluonts.transform import Transformation, TransformedDataset
27+
from gluonts.transform import Transformation
2828

2929
logger = logging.getLogger(__name__)
3030

@@ -156,18 +156,16 @@ def train_model(
156156
transformation = self.create_transformation()
157157

158158
with env._let(max_idle_transforms=max(len(training_data), 100)):
159-
transformed_training_data: Union[
160-
Cached, TransformedDataset
161-
] = transformation.apply(training_data, is_train=True)
159+
transformed_training_data: Dataset = transformation.apply(
160+
training_data, is_train=True
161+
)
162162
if cache_data:
163163
transformed_training_data = Cached(transformed_training_data)
164164

165165
training_network = self.create_lightning_module()
166166

167167
training_data_loader = self.create_training_data_loader(
168-
Cached(transformed_training_data)
169-
if cache_data
170-
else transformed_training_data,
168+
transformed_training_data,
171169
training_network,
172170
shuffle_buffer_length=shuffle_buffer_length,
173171
)
@@ -176,9 +174,9 @@ def train_model(
176174

177175
if validation_data is not None:
178176
with env._let(max_idle_transforms=max(len(validation_data), 100)):
179-
transformed_validation_data: Union[
180-
Cached, TransformedDataset
181-
] = transformation.apply(validation_data, is_train=True)
177+
transformed_validation_data: Dataset = transformation.apply(
178+
validation_data, is_train=True
179+
)
182180
if cache_data:
183181
transformed_validation_data = Cached(
184182
transformed_validation_data

test/test_itertools.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,16 @@ def test_pickle(iterable: Iterable, assert_content: bool):
119119
assert data == data_copy
120120

121121

122+
def test_cached_reentry():
123+
data = Cached(range(10))
124+
125+
assert len(data) == 10
126+
assert list(take(5, data)) == list(range(5))
127+
assert len(data) == 10
128+
assert list(take(10, data)) == list(range(10))
129+
assert len(data) == 10
130+
131+
122132
@pytest.mark.parametrize(
123133
"given, expected",
124134
[

0 commit comments

Comments (0)