Skip to content

Commit 910e3b6

Browse files
committed
100% test coverage on wr.torch
1 parent 85bfade commit 910e3b6

File tree

4 files changed

+16
-19
lines changed

4 files changed

+16
-19
lines changed

.github/workflows/static-checking.yml

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -24,12 +24,8 @@ jobs:
2424
uses: actions/setup-python@v1
2525
with:
2626
python-version: ${{ matrix.python-version }}
27-
- name: Install dependencies
28-
run: |
29-
python -m pip install --upgrade pip
30-
pip install -r requirements.txt
31-
pip install -r requirements-dev.txt
32-
pip install -r requirements-torch.txt
27+
- name: Setup Environment
28+
run: ./setup-dev-env.sh
3329
- name: CloudFormation Lint
3430
run: cfn-lint -t testing/cloudformation.yaml
3531
- name: Documentation Lint

awswrangler/torch.py

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44
import os
55
import pathlib
66
import re
7-
import tarfile
87
from collections.abc import Iterable
98
from io import BytesIO
109
from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
@@ -64,12 +63,12 @@ def _fetch_data(self, path: str) -> Any:
6463
def _load_data(data: io.BytesIO, path: str) -> Any:
6564
if path.endswith(".pt"):
6665
data = torch.load(data)
67-
elif path.endswith(".tar.gz") or path.endswith(".tgz"):
68-
tarfile.open(fileobj=data)
66+
elif path.endswith(".tar.gz") or path.endswith(".tgz"): # pragma: no cover
6967
raise NotImplementedError("Tar loader not implemented!")
68+
# tarfile.open(fileobj=data)
7069
# tar = tarfile.open(fileobj=data)
7170
# for member in tar.getmembers():
72-
else:
71+
else: # pragma: no cover
7372
raise NotImplementedError()
7473

7574
return data
@@ -86,10 +85,10 @@ def __getitem__(self, index):
8685
def __len__(self):
8786
return len(self._paths)
8887

89-
def _data_fn(self, data) -> Any:
88+
def _data_fn(self, data) -> Any: # pragma: no cover
9089
raise NotImplementedError()
9190

92-
def _label_fn(self, path: str) -> Any:
91+
def _label_fn(self, path: str) -> Any: # pragma: no cover
9392
raise NotImplementedError()
9493

9594

@@ -100,7 +99,7 @@ def _label_fn(self, path: str) -> torch.Tensor:
10099
label = int(re.findall(r"/(.*?)=(.*?)/", path)[-1][1])
101100
return torch.tensor([label]) # pylint: disable=not-callable
102101

103-
def _data_fn(self, data) -> Any:
102+
def _data_fn(self, data) -> Any: # pragma: no cover
104103
raise NotImplementedError()
105104

106105

@@ -383,9 +382,8 @@ def __iter__(self) -> Union[Iterator[torch.Tensor], Iterator[Tuple[torch.Tensor,
383382
pass
384383
elif isinstance(data, Iterable) and all([isinstance(d, torch.Tensor) for d in data]):
385384
data = zip(*data)
386-
else:
385+
else: # pragma: no cover
387386
raise NotImplementedError(f"ERROR: Type: {type(data)} has not been implemented!")
388-
389387
for d in data:
390388
yield d
391389

@@ -436,7 +434,7 @@ def __init__(
436434
def __iter__(self) -> Union[Iterator[torch.Tensor], Iterator[Tuple[torch.Tensor, torch.Tensor]]]:
437435
"""Iterate over the Dataset."""
438436
if torch.utils.data.get_worker_info() is not None: # type: ignore
439-
raise NotImplementedError()
437+
raise NotImplementedError() # pragma: no cover
440438
db._validate_engine(con=self._con) # pylint: disable=protected-access
441439
with self._con.connect() as con:
442440
cursor: Any = con.execute(self._sql)

testing/test_awswrangler/test_data_lake.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -708,7 +708,7 @@ def test_parquet_validate_schema(bucket, database):
708708
df2 = pd.DataFrame({"id2": [1, 2, 3], "val": ["foo", "boo", "bar"]})
709709
path_file2 = f"s3://{bucket}/test_parquet_file_validate/1.parquet"
710710
wr.s3.to_parquet(df=df2, path=path_file2)
711-
wr.s3.wait_objects_exist(paths=[path_file2])
711+
wr.s3.wait_objects_exist(paths=[path_file2], use_threads=False)
712712
df3 = wr.s3.read_parquet(path=path, validate_schema=False)
713713
assert len(df3.index) == 6
714714
assert len(df3.columns) == 3

testing/test_awswrangler/test_torch.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,8 @@ def test_torch_sql(parameters, db_type, chunksize):
8484

8585
@pytest.mark.parametrize("chunksize", [None, 1, 10])
8686
@pytest.mark.parametrize("db_type", ["mysql", "redshift", "postgresql"])
87-
def test_torch_sql_label(parameters, db_type, chunksize):
87+
@pytest.mark.parametrize("label_col", [2, "c"])
88+
def test_torch_sql_label(parameters, db_type, chunksize, label_col):
8889
schema = parameters[db_type]["schema"]
8990
table = f"test_torch_sql_label_{db_type}_{str(chunksize).lower()}"
9091
engine = wr.catalog.get_engine(connection=f"aws-data-wrangler-{db_type}")
@@ -99,7 +100,9 @@ def test_torch_sql_label(parameters, db_type, chunksize):
99100
chunksize=None,
100101
method=None,
101102
)
102-
ts = list(wr.torch.SQLDataset(f"SELECT * FROM {schema}.{table}", con=engine, chunksize=chunksize, label_col=2))
103+
ts = list(
104+
wr.torch.SQLDataset(f"SELECT * FROM {schema}.{table}", con=engine, chunksize=chunksize, label_col=label_col)
105+
)
103106
assert torch.all(ts[0][0].eq(torch.tensor([1.0, 4.0])))
104107
assert torch.all(ts[0][1].eq(torch.tensor([7], dtype=torch.long)))
105108
assert torch.all(ts[1][0].eq(torch.tensor([2.0, 5.0])))

0 commit comments

Comments
 (0)