|
18 | 18 | package org.apache.toree.magic.builtin |
19 | 19 |
|
20 | 20 | import java.io.{File, PrintStream} |
21 | | -import java.net.URL |
| 21 | +import java.net.{URL, URI} |
22 | 22 | import java.nio.file.{Files, Paths} |
23 | | - |
24 | 23 | import org.apache.toree.magic._ |
25 | 24 | import org.apache.toree.magic.builtin.AddJar._ |
26 | 25 | import org.apache.toree.magic.dependencies._ |
27 | 26 | import org.apache.toree.utils.{ArgumentParsingSupport, DownloadSupport, LogLike, FileUtils} |
28 | 27 | import com.typesafe.config.Config |
| 28 | +import org.apache.hadoop.fs.Path |
29 | 29 | import org.apache.toree.plugins.annotations.Event |
30 | 30 |
|
31 | 31 | object AddJar { |
| 32 | + val HADOOP_FS_SCHEMES = Set("hdfs", "s3", "s3n", "file") |
32 | 33 |
|
33 | 34 | private var jarDir:Option[String] = None |
34 | 35 |
|
@@ -63,18 +64,18 @@ class AddJar |
63 | 64 | private def printStream = new PrintStream(outputStream) |
64 | 65 |
|
65 | 66 | /** |
66 | | - * Retrieves file name from URL. |
| 67 | + * Retrieves file name from a URI. |
67 | 68 | * |
68 | | - * @param location The remote location (URL) |
69 | | - * @return The name of the remote URL, or an empty string if one does not exist |
| 69 | + * @param location a URI |
| 70 | + * @return The file name of the remote URI, or an empty string if one does not exist |
70 | 71 | */ |
71 | 72 | def getFileFromLocation(location: String): String = { |
72 | | - val url = new URL(location) |
73 | | - val file = url.getFile.split("/") |
74 | | - if (file.length > 0) { |
75 | | - file.last |
| 73 | + val uri = new URI(location) |
| 74 | + val pathParts = uri.getPath.split("/") |
| 75 | + if (pathParts.nonEmpty) { |
| 76 | + pathParts.last |
76 | 77 | } else { |
77 | | - "" |
| 78 | + "" |
78 | 79 | } |
79 | 80 | } |
80 | 81 |
|
@@ -122,10 +123,27 @@ class AddJar |
122 | 123 | // Report beginning of download |
123 | 124 | printStream.println(s"Starting download from $jarRemoteLocation") |
124 | 125 |
|
125 | | - downloadFile( |
126 | | - new URL(jarRemoteLocation), |
127 | | - new File(downloadLocation).toURI.toURL |
128 | | - ) |
| 126 | + val jar = URI.create(jarRemoteLocation) |
| 127 | + if (HADOOP_FS_SCHEMES.contains(jar.getScheme)) { |
| 128 | + val conf = kernel.sparkContext.hadoopConfiguration |
| 129 | + val jarPath = new Path(jarRemoteLocation) |
| 130 | + val fs = jarPath.getFileSystem(conf) |
| 131 | + val destPath = if (downloadLocation.startsWith("file:")) { |
| 132 | + new Path(downloadLocation) |
| 133 | + } else { |
| 134 | + new Path("file:" + downloadLocation) |
| 135 | + } |
| 136 | + |
| 137 | + fs.copyToLocalFile( |
| 138 | + false /* keep original file */, |
| 139 | + jarPath, destPath, |
| 140 | + true /* don't create checksum files */) |
| 141 | + } else { |
| 142 | + downloadFile( |
| 143 | + new URL(jarRemoteLocation), |
| 144 | + new File(downloadLocation).toURI.toURL |
| 145 | + ) |
| 146 | + } |
129 | 147 |
|
130 | 148 | // Report download finished |
131 | 149 | printStream.println(s"Finished download of $jarName") |
|
0 commit comments