
Commit b5e49fa

FelixAbrahamsson committed
Fix linting errors/warnings and remove unused imports
1 parent 019b2bc commit b5e49fa

3 files changed: +18 −20 lines changed


datastream/dataset.py

Lines changed: 7 additions & 8 deletions
@@ -4,8 +4,6 @@
     Tuple, Callable, Any, Union, List, TypeVar, Generic, Dict, Optional
 )
 from pathlib import Path
-from functools import partial
-from itertools import repeat, chain
 import numpy as np
 import pandas as pd
 import torch
@@ -15,9 +13,10 @@
 T = TypeVar('T')
 R = TypeVar('R')
 
+
 class Dataset(BaseModel, torch.utils.data.Dataset, Generic[T]):
     '''
-    A ``Dataset[T]`` is a mapping that allows pipelining of functions in a
+    A ``Dataset[T]`` is a mapping that allows pipelining of functions in a
     readable syntax returning an item of type ``T``.
 
     >>> from datastream import Dataset
@@ -67,8 +66,8 @@ def from_subscriptable(subscriptable) -> Dataset:
         Create ``Dataset`` based on subscriptable i.e. implements
         ``__getitem__`` and ``__len__``. Should only be used for simple
         examples as a ``Dataset`` created with this method does not support
-        methods that require a source dataframe (i.e. :func:``Dataset.split``
-        and :func:``Dataset.subset``)
+        methods that require a source dataframe (i.e. :func:`Dataset.split`
+        and :func:`Dataset.subset`)
         '''
 
         return (
@@ -138,7 +137,7 @@ def map(
     def subset(
         self, mask_fn: Callable[
             [pd.DataFrame], Union[pd.Series, np.array, List[bool]]
-        ]
+        ]
     ) -> Dataset[T]:
         '''
         Select a subset of the dataset using a function that receives the
@@ -191,7 +190,7 @@ def split(
         safely use a seed instead of a filepath.
 
         Saved splits can continue from the old split and handles:
-
+
         * New examples
         * Changing test size
         * Adapt after removing examples from dataset
@@ -441,7 +440,7 @@ def test_combine_dataset():
         )
         for index, inner_indices in enumerate(indices)
     )
-
+
 
 def test_split_dataset():
     dataset = Dataset.from_dataframe(pd.DataFrame(dict(

datastream/datastream.py

Lines changed: 4 additions & 3 deletions
@@ -5,9 +5,7 @@
     Dict,
     List,
     Callable,
-    Any,
     Optional,
-    Iterable,
     TypeVar,
     Generic,
     Union,
@@ -28,6 +26,7 @@
 T = TypeVar('T')
 R = TypeVar('R')
 
+
 class Datastream(BaseModel, Generic[T]):
     '''
     ``Datastream[T]`` combines a ``Dataset[T]`` and a sampler into a stream of
@@ -294,7 +293,9 @@ def ZippedMergedDatastream():
         (ZippedMergedDatastream(), 5),
     ])
 
-    it = iter(datastream.data_loader(batch_size=16, n_batches_per_epoch=10))
+    it = iter(datastream.data_loader(
+        batch_size=16, n_batches_per_epoch=10
+    ))
     for _ in range(10):
         print(next(it))
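
The rewrapped test call above shows the typical way a ``Datastream`` is consumed. A rough, self-contained sketch of that pattern, assuming the constructor simply pairs a ``Dataset`` with a default sampler as the class docstring suggests (nothing here beyond ``data_loader``'s arguments comes from this commit):

from datastream import Dataset, Datastream

# Assumed: Datastream(dataset) wraps the dataset with a default sampler,
# per the docstring "combines a Dataset[T] and a sampler into a stream".
dataset = Dataset.from_subscriptable(list(range(100)))
datastream = Datastream(dataset)

# Same call pattern as the test above: a loader yielding
# n_batches_per_epoch batches of batch_size examples per epoch.
it = iter(datastream.data_loader(
    batch_size=16, n_batches_per_epoch=10
))
for _ in range(10):
    batch = next(it)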

datastream/samplers.py

Lines changed: 7 additions & 9 deletions
@@ -1,13 +1,10 @@
 from __future__ import annotations
 from pydantic import BaseModel
-from typing import Tuple, Dict, Callable, Any, Optional, Iterable
+from typing import Tuple, Callable, Iterable
 from functools import partial
-from itertools import repeat, chain, islice
-from collections import namedtuple
-import numpy as np
-import pandas as pd
+from itertools import chain
 import torch
-from datastream.tools import starcompose, star, repeat_map_chain
+from datastream.tools import starcompose, repeat_map_chain
 from datastream import Dataset
 
 
@@ -47,7 +44,7 @@ def weight(self, index):
 
     def update_weights_(self, function):
         self.sampler.weights[:] = function(self.sampler.weights)
-
+
     def update_example_weight_(self, weight, index):
         if hasattr(weight, 'item'):
             weight = weight.item()
@@ -213,7 +210,9 @@ def update_weights_(self, function):
 
     def update_example_weight_(self, weights, index):
         inner_indices = self.from_mapping(index)
-        for sampler, weight, inner_index in zip(self.samplers, weights, inner_indices):
+        for sampler, weight, inner_index in zip(
+            self.samplers, weights, inner_indices
+        ):
             sampler.update_example_weight_(
                 weight, inner_index
             )
@@ -318,7 +317,6 @@ def load_state_dict(self, state_dict):
         sampler.load_state_dict(state_dict)
 
 
-
 class RepeatSampler(BaseModel, torch.utils.data.Sampler):
     sampler: torch.utils.data.Sampler
     length: int
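
``RepeatSampler`` above is declared with just a wrapped ``sampler`` and a ``length``. As a generic, self-contained illustration of that idea only (not the library's implementation, which is pydantic- and torch-based):

from itertools import islice

def repeat_indices(sampler, length):
    # Yield indices from `sampler` over and over, truncated to `length`.
    def repeated():
        while True:
            yield from sampler
    return islice(repeated(), length)

# A short sampler stretched to 10 draws:
print(list(repeat_indices([0, 1, 2], length=10)))
# [0, 1, 2, 0, 1, 2, 0, 1, 2, 0]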
