
Commit c109c3a

Shibo Xing authored and facebook-github-bot committed
Fix test_remote_io.py due to mutating public s3 bucket (#997)
Summary:
Fixes #984

### Changes
- Add a private helper to TestDataPipeRemoteIO in test_remote_io.py that gets the S3 object count through the AWS CLI, so the tests no longer assert hard-coded counts against a mutating public bucket
- Add awscli to the test requirements

Pull Request resolved: #997
Reviewed By: ejguan
Differential Revision: D43157757
Pulled By: NivekT
fbshipit-source-id: 7e9ee8299a28a087f88024c3b3e77be3bfe5adf0
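For context on the approach: rather than asserting fixed object counts (which break whenever the public bucket changes), the tests now compute the expected count at run time by shelling out to the AWS CLI. A minimal standalone sketch of that idea, assuming awscli is installed and the bucket allows anonymous listing (the bucket and prefix below are the ones used in the tests):

```python
import json
import subprocess

# Values taken from the tests; any publicly listable bucket/prefix would do.
bucket, prefix = "ai2-public-datasets", "charades/"
cmd = f"aws --output json s3api list-objects --bucket {bucket} --no-sign-request --prefix {prefix}"
out = subprocess.run(cmd, shell=True, check=True, capture_output=True).stdout
# The listing reports each object under "Contents" with its "Key".
keys = {obj["Key"] for obj in json.loads(out).get("Contents", [])}
print(f"{bucket}/{prefix}: {len(keys)} objects")
```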
1 parent b450cfd commit c109c3a

File tree

2 files changed: +60 −49 lines changed

test/requirements.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -13,3 +13,4 @@ protobuf >= 3.9.2, < 3.20
 datasets
 graphviz
 adlfs
+awscli>=1.27.66
```
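Since the updated tests invoke the `aws` binary through `subprocess`, it must also be on the test machine's PATH. A small sketch (not part of the commit) of a fail-fast guard:

```python
import shutil

# Fail fast if the AWS CLI is missing rather than erroring mid-test.
if shutil.which("aws") is None:
    raise RuntimeError("awscli not found; install it, e.g. via test/requirements.txt")
```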

test/test_remote_io.py

Lines changed: 59 additions & 49 deletions
```diff
@@ -5,7 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 
 import io
+import json
 import os
+import subprocess
 import unittest
 import warnings
 from unittest.mock import patch
@@ -223,24 +225,49 @@ def _filepath_fn(url):
         res = list(dl)
         self.assertEqual(sorted(expected), sorted(res))
 
+    def __get_s3_cnt(self, s3_pths: list, recursive=True):
+        """Return the count of the total objects collected from a list of s3 paths"""
+        tot_objs = set()
+        for p in s3_pths:
+            pth_parts = p.split("s3://")[1].split("/", 1)
+            if len(pth_parts) == 1:
+                bkt_name, prefix = pth_parts[0], ""
+            else:
+                bkt_name, prefix = pth_parts
+
+            aws_cmd = f"aws --output json s3api list-objects --bucket {bkt_name} --no-sign-request"
+            if prefix.strip():
+                aws_cmd += f" --prefix {prefix}"
+            if not recursive:
+                aws_cmd += " --delimiter /"
+
+            res = subprocess.run(aws_cmd, shell=True, check=True, capture_output=True)
+            json_res = json.loads(res.stdout)
+            if "Contents" in json_res:
+                objs = [v["Key"] for v in json_res["Contents"]]
+            else:
+                objs = [v["Prefix"] for v in json_res["CommonPrefixes"]]
+            tot_objs |= set(objs)
+
+        return len(tot_objs)
+
     @skipIfNoFSSpecS3
     def test_fsspec_io_iterdatapipe(self):
         input_list = [
-            (["s3://ai2-public-datasets"], 41),  # bucket without '/'
-            (["s3://ai2-public-datasets/charades/"], 18),  # bucket with '/'
-            (
-                [
-                    "s3://ai2-public-datasets/charades/Charades_v1.zip",
-                    "s3://ai2-public-datasets/charades/Charades_v1_flow.tar",
-                    "s3://ai2-public-datasets/charades/Charades_v1_rgb.tar",
-                    "s3://ai2-public-datasets/charades/Charades_v1_480.zip",
-                ],
-                4,
-            ),  # multiple files
+            ["s3://ai2-public-datasets"],  # bucket without '/'
+            ["s3://ai2-public-datasets/charades/"],  # bucket with '/'
+            [
+                "s3://ai2-public-datasets/charades/Charades_v1.zip",
+                "s3://ai2-public-datasets/charades/Charades_v1_flow.tar",
+                "s3://ai2-public-datasets/charades/Charades_v1_rgb.tar",
+                "s3://ai2-public-datasets/charades/Charades_v1_480.zip",
+            ],  # multiple files
         ]
-        for urls, num in input_list:
+        for urls in input_list:
             fsspec_lister_dp = FSSpecFileLister(IterableWrapper(urls), anon=True)
-            self.assertEqual(sum(1 for _ in fsspec_lister_dp), num, f"{urls} failed")
+            self.assertEqual(
+                sum(1 for _ in fsspec_lister_dp), self.__get_s3_cnt(urls, recursive=False), f"{urls} failed"
+            )
 
         url = "s3://ai2-public-datasets/charades/"
         fsspec_loader_dp = FSSpecFileOpener(FSSpecFileLister(IterableWrapper([url]), anon=True), anon=True)
@@ -276,42 +303,33 @@ def test_disabled_s3_io_iterdatapipe(self):
     def test_s3_io_iterdatapipe(self):
         # S3FileLister: different inputs
         input_list = [
-            [["s3://ai2-public-datasets"], 81],  # bucket without '/'
-            [["s3://ai2-public-datasets/"], 81],  # bucket with '/'
-            [["s3://ai2-public-datasets/charades"], 18],  # folder without '/'
-            [["s3://ai2-public-datasets/charades/"], 18],  # folder with '/'
-            [["s3://ai2-public-datasets/charad"], 18],  # prefix
+            ["s3://ai2-public-datasets"],  # bucket without '/'
+            ["s3://ai2-public-datasets/"],  # bucket with '/'
+            ["s3://ai2-public-datasets/charades"],  # folder without '/'
+            ["s3://ai2-public-datasets/charades/"],  # folder with '/'
+            ["s3://ai2-public-datasets/charad"],  # prefix
             [
-                [
-                    "s3://ai2-public-datasets/charades/Charades_v1",
-                    "s3://ai2-public-datasets/charades/Charades_vu17",
-                ],
-                12,
+                "s3://ai2-public-datasets/charades/Charades_v1",
+                "s3://ai2-public-datasets/charades/Charades_vu17",
             ],  # prefixes
-            [["s3://ai2-public-datasets/charades/Charades_v1.zip"], 1],  # single file
+            ["s3://ai2-public-datasets/charades/Charades_v1.zip"],  # single file
             [
-                [
-                    "s3://ai2-public-datasets/charades/Charades_v1.zip",
-                    "s3://ai2-public-datasets/charades/Charades_v1_flow.tar",
-                    "s3://ai2-public-datasets/charades/Charades_v1_rgb.tar",
-                    "s3://ai2-public-datasets/charades/Charades_v1_480.zip",
-                ],
-                4,
+                "s3://ai2-public-datasets/charades/Charades_v1.zip",
+                "s3://ai2-public-datasets/charades/Charades_v1_flow.tar",
+                "s3://ai2-public-datasets/charades/Charades_v1_rgb.tar",
+                "s3://ai2-public-datasets/charades/Charades_v1_480.zip",
            ],  # multiple files
             [
-                [
-                    "s3://ai2-public-datasets/charades/Charades_v1.zip",
-                    "s3://ai2-public-datasets/charades/Charades_v1_flow.tar",
-                    "s3://ai2-public-datasets/charades/Charades_v1_rgb.tar",
-                    "s3://ai2-public-datasets/charades/Charades_v1_480.zip",
-                    "s3://ai2-public-datasets/charades/Charades_vu17",
-                ],
-                10,
+                "s3://ai2-public-datasets/charades/Charades_v1.zip",
+                "s3://ai2-public-datasets/charades/Charades_v1_flow.tar",
+                "s3://ai2-public-datasets/charades/Charades_v1_rgb.tar",
+                "s3://ai2-public-datasets/charades/Charades_v1_480.zip",
+                "s3://ai2-public-datasets/charades/Charades_vu17",
             ],  # files + prefixes
         ]
         for input in input_list:
-            s3_lister_dp = S3FileLister(IterableWrapper(input[0]), region="us-west-2")
-            self.assertEqual(sum(1 for _ in s3_lister_dp), input[1], f"{input[0]} failed")
+            s3_lister_dp = S3FileLister(IterableWrapper(input), region="us-west-2")
+            self.assertEqual(sum(1 for _ in s3_lister_dp), self.__get_s3_cnt(input), f"{input} failed")
 
         # S3FileLister: prefixes + different region
         file_urls = [
@@ -334,14 +352,6 @@ def test_s3_io_iterdatapipe(self):
         for _ in s3_lister_dp:
             pass
 
-        # S3FileLoader: loader
-        input = [
-            "s3://charades-tar-shards/charades-video-0.tar",
-            "s3://charades-tar-shards/charades-video-1.tar",
-        ]  # multiple files
-        s3_loader_dp = S3FileLoader(input, region="us-west-2")
-        self.assertEqual(sum(1 for _ in s3_loader_dp), 2, f"{input} failed")
-
         input = [["s3://aft-vbi-pds/bin-images/100730.jpg"], 1]
         s3_loader_dp = S3FileLoader(input[0], region="us-east-1")
         self.assertEqual(sum(1 for _ in s3_loader_dp), input[1], f"{input[0]} failed")
```
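The branch on "Contents" versus "CommonPrefixes" in `__get_s3_cnt` mirrors the two response shapes `s3api list-objects` can produce: a recursive listing reports individual object keys under "Contents", while a listing with `--delimiter /` groups folder-like paths under "CommonPrefixes". A sketch with hypothetical, abbreviated payloads (real responses carry more fields):

```python
import json

# Hypothetical responses; the key/prefix values are illustrative only.
recursive_resp = json.loads('{"Contents": [{"Key": "charades/Charades_v1.zip"}]}')
shallow_resp = json.loads('{"CommonPrefixes": [{"Prefix": "charades/"}]}')

for resp in (recursive_resp, shallow_resp):
    if "Contents" in resp:
        objs = [v["Key"] for v in resp["Contents"]]  # individual object keys
    else:
        objs = [v["Prefix"] for v in resp["CommonPrefixes"]]  # folder-like prefixes
    print(objs)
```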
