Skip to content

Commit 3330faf

Browse files
committed
!513 refactor recursive dbfs ls and add test
1 parent 0b6a78e commit 3330faf

File tree

3 files changed

+55
-15
lines changed

3 files changed

+55
-15
lines changed

databricks_cli/dbfs/api.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,18 @@ class DbfsApi(object):
9191
def __init__(self, api_client):
9292
self.client = DbfsService(api_client)
9393

94-
def list_files(self, dbfs_path, headers=None):
95-
list_response = self.client.list(dbfs_path.absolute_path, headers=headers)
94+
def _recursive_list(self, **kwargs):
95+
paths = self.client.list_files(**kwargs)
96+
files = [p for p in paths if not p.is_dir]
97+
for p in paths:
98+
files = files + self._recursive_list(p) if p.is_dir else files
99+
return files
100+
101+
def list_files(self, dbfs_path, headers=None, is_recursive=False):
102+
if is_recursive:
103+
list_response = self._recursive_list(dbfs_path, headers)
104+
else:
105+
list_response = self.client.list(dbfs_path.absolute_path, headers=headers)
96106
if 'files' in list_response:
97107
return [FileInfo.from_json(f) for f in list_response['files']]
98108
else:

databricks_cli/dbfs/cli.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,14 @@
3838
@click.option('-l', is_flag=True, default=False,
3939
help="""Displays full information including size, file type
4040
and modification time since Epoch in milliseconds.""")
41+
@click.option('--recursive', '-r', is_flag=True, default=False,
42+
help='Displays all subdirectories and files.')
4143
@click.argument('dbfs_path', nargs=-1, type=DbfsPathClickType())
4244
@debug_option
4345
@profile_option
4446
@eat_exceptions
4547
@provide_api_client
46-
def ls_cli(api_client, l, absolute, dbfs_path): # NOQA
48+
def ls_cli(api_client, l, absolute, recursive, dbfs_path): # NOQA
4749
"""
4850
List files in DBFS.
4951
"""
@@ -53,7 +55,10 @@ def ls_cli(api_client, l, absolute, dbfs_path): # NOQA
5355
dbfs_path = dbfs_path[0]
5456
else:
5557
error_and_quit('ls can take a maximum of one path.')
56-
files = DbfsApi(api_client).list_files(dbfs_path)
58+
59+
DbfsApi(api_client).list_files(dbfs_path, is_recursive=recursive)
60+
absolute = absolute or recursive
61+
5762
table = tabulate([f.to_row(is_long_form=l, is_absolute=absolute) for f in files],
5863
tablefmt='plain')
5964
click.echo(table)

tests/dbfs/test_api.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,26 @@
3636

3737
TEST_DBFS_PATH = DbfsPath('dbfs:/test')
3838
DUMMY_TIME = 1613158406000
39-
TEST_FILE_JSON = {
39+
TEST_FILE_JSON1 = {
4040
'path': '/test',
4141
'is_dir': False,
4242
'file_size': 1,
4343
'modification_time': DUMMY_TIME
4444
}
45-
TEST_FILE_INFO = api.FileInfo(TEST_DBFS_PATH, False, 1, DUMMY_TIME)
45+
TEST_FILE_JSON2 = {
46+
'path': '/dir/test',
47+
'is_dir': False,
48+
'file_size': 1,
49+
'modification_time': DUMMY_TIME
50+
}
51+
TEST_DIR_JSON = {
52+
'path': '/dir',
53+
'is_dir': True,
54+
'file_size': 0,
55+
'modification_time': DUMMY_TIME
56+
}
57+
TEST_FILE_INFO0 = api.FileInfo(TEST_DBFS_PATH, False, 1, DUMMY_TIME)
58+
TEST_FILE_INFO1 = api.FileInfo(TEST_DBFS_PATH2, False, 1, DUMMY_TIME)
4659

4760

4861
def get_resource_does_not_exist_exception():
@@ -74,7 +87,7 @@ def test_to_row_long_form_not_absolute(self):
7487
assert TEST_DBFS_PATH.basename == row[2]
7588

7689
def test_from_json(self):
77-
file_info = api.FileInfo.from_json(TEST_FILE_JSON)
90+
file_info = api.FileInfo.from_json(TEST_FILE_JSON0)
7891
assert file_info.dbfs_path == TEST_DBFS_PATH
7992
assert not file_info.is_dir
8093
assert file_info.file_size == 1
@@ -89,15 +102,26 @@ def dbfs_api():
89102

90103

91104
class TestDbfsApi(object):
105+
def test_list_files_recursive(self, dbfs_api):
106+
json = {
107+
'files': [TEST_FILE_JSON0, TEST_DIR_JSON, TEST_FILE_JSON1]
108+
}
109+
dbfs_api.client.list.return_value = json
110+
files = dbfs_api.list_files("dbfs:/")
111+
112+
assert len(files) == 2
113+
assert TEST_FILE_INFO0 == files[0]
114+
assert TEST_FILE_INFO1 == files[1]
115+
92116
def test_list_files_exists(self, dbfs_api):
93117
json = {
94-
'files': [TEST_FILE_JSON]
118+
'files': [TEST_FILE_JSON0]
95119
}
96120
dbfs_api.client.list.return_value = json
97-
files = dbfs_api.list_files(TEST_DBFS_PATH)
121+
files = dbfs_api.list_files(TEST_DBFS_PATH, is_recursive=True)
98122

99123
assert len(files) == 1
100-
assert TEST_FILE_INFO == files[0]
124+
assert TEST_FILE_INFO0 == files[0]
101125

102126
def test_list_files_does_not_exist(self, dbfs_api):
103127
json = {}
@@ -107,7 +131,7 @@ def test_list_files_does_not_exist(self, dbfs_api):
107131
assert len(files) == 0
108132

109133
def test_file_exists_true(self, dbfs_api):
110-
dbfs_api.client.get_status.return_value = TEST_FILE_JSON
134+
dbfs_api.client.get_status.return_value = TEST_FILE_JSON0
111135
assert dbfs_api.file_exists(TEST_DBFS_PATH)
112136

113137
def test_file_exists_false(self, dbfs_api):
@@ -116,8 +140,8 @@ def test_file_exists_false(self, dbfs_api):
116140
assert not dbfs_api.file_exists(TEST_DBFS_PATH)
117141

118142
def test_get_status(self, dbfs_api):
119-
dbfs_api.client.get_status.return_value = TEST_FILE_JSON
120-
assert dbfs_api.get_status(TEST_DBFS_PATH) == TEST_FILE_INFO
143+
dbfs_api.client.get_status.return_value = TEST_FILE_JSON0
144+
assert dbfs_api.get_status(TEST_DBFS_PATH) == TEST_FILE_INFO0
121145

122146
def test_get_status_fail(self, dbfs_api):
123147
exception = get_resource_does_not_exist_exception()
@@ -151,7 +175,8 @@ def test_put_large_file(self, dbfs_api, tmpdir):
151175
dbfs_api.put_file(test_file_path, TEST_DBFS_PATH, True)
152176
assert api_mock.add_block.call_count == 1
153177
assert test_handle == api_mock.add_block.call_args[0][0]
154-
assert b64encode(b'test').decode() == api_mock.add_block.call_args[0][1]
178+
assert b64encode(b'test').decode(
179+
) == api_mock.add_block.call_args[0][1]
155180
assert api_mock.close.call_count == 1
156181
assert test_handle == api_mock.close.call_args[0][0]
157182

@@ -164,7 +189,7 @@ def test_get_file_check_overwrite(self, dbfs_api, tmpdir):
164189

165190
def test_get_file(self, dbfs_api, tmpdir):
166191
api_mock = dbfs_api.client
167-
api_mock.get_status.return_value = TEST_FILE_JSON
192+
api_mock.get_status.return_value = TEST_FILE_JSON0
168193
api_mock.read.return_value = {
169194
'bytes_read': 1,
170195
'data': b64encode(b'x'),

0 commit comments

Comments
 (0)