|
24 | 24 | from base64 import b64encode, b64decode |
25 | 25 |
|
26 | 26 | import os |
| 27 | +import shutil |
| 28 | +import tempfile |
| 29 | + |
27 | 30 | import click |
28 | 31 |
|
29 | 32 | from requests.exceptions import HTTPError |
@@ -221,8 +224,40 @@ def cp(self, recursive, overwrite, src, dst, headers=None): |
221 | 224 | 'To use this utility, one of the src or dst must be prefixed ' |
222 | 225 | 'with dbfs:/') |
223 | 226 | elif DbfsPath.is_valid(src) and DbfsPath.is_valid(dst): |
224 | | - error_and_quit('Both paths provided are from the DBFS filesystem. ' |
225 | | - 'To copy between the DBFS filesystem, you currently must copy the ' |
226 | | - 'file from DBFS to your local filesystem and then back.') |
| 227 | + with TempDir() as temp_dir: |
| 228 | + # Always copy to <temp_dir>/temp since this will work no matter if it's a |
| 229 | + # recursive or a non-recursive copy. |
| 230 | + temp_path = temp_dir.path('temp') |
| 231 | + self.cp(recursive, True, src, temp_path) |
| 232 | + self.cp(recursive, overwrite, temp_path, dst) |
227 | 233 | else: |
228 | 234 | assert False, 'not reached' |
| 235 | + |
| 236 | + def cat(self, src): |
| 237 | + with TempDir() as temp_dir: |
| 238 | + temp_path = temp_dir.path('temp') |
| 239 | + self.cp(False, True, src, temp_path) |
| 240 | + with open(temp_path) as f: |
| 241 | + click.echo(f.read(), nl=False) |
| 242 | + |
| 243 | + |
| 244 | +class TempDir(object): |
| 245 | + def __init__(self, remove_on_exit=True): |
| 246 | + self._dir = None |
| 247 | + self._path = None |
| 248 | + self._remove = remove_on_exit |
| 249 | + |
| 250 | + def __enter__(self): |
| 251 | + self._path = os.path.abspath(tempfile.mkdtemp()) |
| 252 | + assert os.path.exists(self._path) |
| 253 | + return self |
| 254 | + |
| 255 | + def __exit__(self, tp, val, traceback): |
| 256 | + if self._remove and os.path.exists(self._path): |
| 257 | + shutil.rmtree(self._path) |
| 258 | + |
| 259 | + assert not self._remove or not os.path.exists(self._path) |
| 260 | + assert os.path.exists(os.getcwd()) |
| 261 | + |
| 262 | + def path(self, *path): |
| 263 | + return os.path.join(self._path, *path) |
0 commit comments