diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 5a46d737aeeb7..e92e1fd72bf10 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -176,7 +176,8 @@ NULL #' #' @param x Column to compute on. Note the difference in the following methods: #' \itemize{ -#' \item \code{to_json}: it is the column containing the struct or array of the structs. +#' \item \code{to_json}: it is the column containing the struct, array of the structs, +#' the map or array of maps. #' \item \code{from_json}: it is the column containing the JSON string. #' } #' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains @@ -1700,8 +1701,9 @@ setMethod("to_date", }) #' @details -#' \code{to_json}: Converts a column containing a \code{structType} or array of \code{structType} -#' into a Column of JSON string. Resolving the Column can fail if an unsupported type is encountered. +#' \code{to_json}: Converts a column containing a \code{structType}, array of \code{structType}, +#' a \code{mapType} or array of \code{mapType} into a Column of JSON string. +#' Resolving the Column can fail if an unsupported type is encountered. #' #' @rdname column_collection_functions #' @aliases to_json to_json,Column-method @@ -1715,6 +1717,14 @@ setMethod("to_date", #' #' # Converts an array of structs into a JSON array #' df2 <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") +#' df2 <- mutate(df2, people_json = to_json(df2$people)) +#' +#' # Converts a map into a JSON object +#' df2 <- sql("SELECT map('name', 'Bob')) as people") +#' df2 <- mutate(df2, people_json = to_json(df2$people)) +#' +#' # Converts an array of maps into a JSON array +#' df2 <- sql("SELECT array(map('name', 'Bob'), map('name', 'Alice')) as people") #' df2 <- mutate(df2, people_json = to_json(df2$people))} #' @note to_json since 2.2.0 setMethod("to_json", signature(x = "Column"), diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 7abc8720473c1..85a7e0819cff7 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1491,6 +1491,14 @@ test_that("column functions", { j <- collect(select(df, alias(to_json(df$people), "json"))) expect_equal(j[order(j$json), ][1], "[{\"name\":\"Bob\"},{\"name\":\"Alice\"}]") + df <- sql("SELECT map('name', 'Bob') as people") + j <- collect(select(df, alias(to_json(df$people), "json"))) + expect_equal(j[order(j$json), ][1], "{\"name\":\"Bob\"}") + + df <- sql("SELECT array(map('name', 'Bob'), map('name', 'Alice')) as people") + j <- collect(select(df, alias(to_json(df$people), "json"))) + expect_equal(j[order(j$json), ][1], "[{\"name\":\"Bob\"},{\"name\":\"Alice\"}]") + df <- read.json(mapTypeJsonPath) j <- collect(select(df, alias(to_json(df$info), "json"))) expect_equal(j[order(j$json), ][1], "{\"age\":16,\"height\":176.5}") diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 43f57672d9544..76db0fb91e48a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -64,7 +64,8 @@ public final class UTF8String implements Comparable, Externalizable, 5, 5, 5, 5, 6, 6}; - private static boolean isLittleEndian = ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN; + private static final boolean IS_LITTLE_ENDIAN = + ByteOrder.nativeOrder() == 
ByteOrder.LITTLE_ENDIAN; private static final UTF8String COMMA_UTF8 = UTF8String.fromString(","); public static final UTF8String EMPTY_UTF8 = UTF8String.fromString(""); @@ -220,7 +221,7 @@ public long getPrefix() { // After getting the data, we use a mask to mask out data that is not part of the string. long p; long mask = 0; - if (isLittleEndian) { + if (IS_LITTLE_ENDIAN) { if (numBytes >= 8) { p = Platform.getLong(base, offset); } else if (numBytes > 4) { @@ -510,6 +511,21 @@ public UTF8String trim() { } } + /** + * Based on the given trim string, trim this string starting from both ends + * This method searches for each character in the source string, removes the character if it is found + * in the trim string, stops at the first not found. It calls the trimLeft first, then trimRight. + * It returns a new string in which both ends trim characters have been removed. + * @param trimString the trim character string + */ + public UTF8String trim(UTF8String trimString) { + if (trimString != null) { + return trimLeft(trimString).trimRight(trimString); + } else { + return null; + } + } + public UTF8String trimLeft() { int s = 0; // skip all of the space (0x20) in the left side @@ -522,6 +538,40 @@ public UTF8String trimLeft() { } } + /** + * Based on the given trim string, trim this string starting from left end + * This method searches each character in the source string starting from the left end, removes the character if it + * is in the trim string, stops at the first character which is not in the trim string, returns the new string. + * @param trimString the trim character string + */ + public UTF8String trimLeft(UTF8String trimString) { + if (trimString == null) return null; + // the searching byte position in the source string + int srchIdx = 0; + // the first beginning byte position of a non-matching character + int trimIdx = 0; + + while (srchIdx < numBytes) { + UTF8String searchChar = copyUTF8String(srchIdx, srchIdx + numBytesForFirstByte(this.getByte(srchIdx)) - 1); + int searchCharBytes = searchChar.numBytes; + // try to find the matching for the searchChar in the trimString set + if (trimString.find(searchChar, 0) >= 0) { + trimIdx += searchCharBytes; + } else { + // no matching, exit the search + break; + } + srchIdx += searchCharBytes; + } + + if (trimIdx >= numBytes) { + // empty string + return EMPTY_UTF8; + } else { + return copyUTF8String(trimIdx, numBytes - 1); + } + } + public UTF8String trimRight() { int e = numBytes - 1; // skip all of the space (0x20) in the right side @@ -535,6 +585,50 @@ public UTF8String trimRight() { } } + /** + * Based on the given trim string, trim this string starting from right end + * This method searches each character in the source string starting from the right end, removes the character if it + * is in the trim string, stops at the first character which is not in the trim string, returns the new string. 
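+   * Returns null if the given trim string is null.
+   * For example, fromString("ccbaaaa").trimRight(fromString("ba")) returns fromString("cc"),
+   * as exercised by the trimRightWithTrimString test below.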
+ * @param trimString the trim character string + */ + public UTF8String trimRight(UTF8String trimString) { + if (trimString == null) return null; + int charIdx = 0; + // number of characters from the source string + int numChars = 0; + // array of character length for the source string + int[] stringCharLen = new int[numBytes]; + // array of the first byte position for each character in the source string + int[] stringCharPos = new int[numBytes]; + // build the position and length array + while (charIdx < numBytes) { + stringCharPos[numChars] = charIdx; + stringCharLen[numChars] = numBytesForFirstByte(getByte(charIdx)); + charIdx += stringCharLen[numChars]; + numChars ++; + } + + // index trimEnd points to the first no matching byte position from the right side of the source string. + int trimEnd = numBytes - 1; + while (numChars > 0) { + UTF8String searchChar = + copyUTF8String(stringCharPos[numChars - 1], stringCharPos[numChars - 1] + stringCharLen[numChars - 1] - 1); + if (trimString.find(searchChar, 0) >= 0) { + trimEnd -= stringCharLen[numChars - 1]; + } else { + break; + } + numChars --; + } + + if (trimEnd < 0) { + // empty string + return EMPTY_UTF8; + } else { + return copyUTF8String(0, trimEnd); + } + } + public UTF8String reverse() { byte[] result = new byte[this.numBytes]; @@ -1097,10 +1191,23 @@ public UTF8String copy() { @Override public int compareTo(@Nonnull final UTF8String other) { int len = Math.min(numBytes, other.numBytes); - // TODO: compare 8 bytes as unsigned long - for (int i = 0; i < len; i ++) { + int wordMax = (len / 8) * 8; + long roffset = other.offset; + Object rbase = other.base; + for (int i = 0; i < wordMax; i += 8) { + long left = getLong(base, offset + i); + long right = getLong(rbase, roffset + i); + if (left != right) { + if (IS_LITTLE_ENDIAN) { + return Long.compareUnsigned(Long.reverseBytes(left), Long.reverseBytes(right)); + } else { + return Long.compareUnsigned(left, right); + } + } + } + for (int i = wordMax; i < len; i++) { // In UTF-8, the byte should be unsigned, so we should compare them as unsigned int. 
- int res = (getByte(i) & 0xFF) - (other.getByte(i) & 0xFF); + int res = (getByte(i) & 0xFF) - (Platform.getByte(rbase, roffset + i) & 0xFF); if (res != 0) { return res; } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index c376371abdf90..f0860018d5642 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -730,4 +730,61 @@ public void testToLong() throws IOException { assertFalse(negativeInput, UTF8String.fromString(negativeInput).toLong(wrapper)); } } + + @Test + public void trimBothWithTrimString() { + assertEquals(fromString("hello"), fromString(" hello ").trim(fromString(" "))); + assertEquals(fromString("o"), fromString(" hello ").trim(fromString(" hle"))); + assertEquals(fromString("h e"), fromString("ooh e ooo").trim(fromString("o "))); + assertEquals(fromString(""), fromString("ooo...oooo").trim(fromString("o."))); + assertEquals(fromString("b"), fromString("%^b[]@").trim(fromString("][@^%"))); + + assertEquals(EMPTY_UTF8, fromString(" ").trim(fromString(" "))); + + assertEquals(fromString("数据砖头"), fromString(" 数据砖头 ").trim()); + assertEquals(fromString("数"), fromString("a数b").trim(fromString("ab"))); + assertEquals(fromString(""), fromString("a").trim(fromString("a数b"))); + assertEquals(fromString(""), fromString("数数 数数数").trim(fromString("数 "))); + assertEquals(fromString("据砖头"), fromString("数]数[数据砖头#数数").trim(fromString("[数]#"))); + assertEquals(fromString("据砖头数数 "), fromString("数数数据砖头数数 ").trim(fromString("数"))); + } + + @Test + public void trimLeftWithTrimString() { + assertEquals(fromString(" hello "), fromString(" hello ").trimLeft(fromString(""))); + assertEquals(fromString(""), fromString("a").trimLeft(fromString("a"))); + assertEquals(fromString("b"), fromString("b").trimLeft(fromString("a"))); + assertEquals(fromString("ba"), fromString("ba").trimLeft(fromString("a"))); + assertEquals(fromString(""), fromString("aaaaaaa").trimLeft(fromString("a"))); + assertEquals(fromString("trim"), fromString("oabtrim").trimLeft(fromString("bao"))); + assertEquals(fromString("rim "), fromString("ooootrim ").trimLeft(fromString("otm"))); + + assertEquals(EMPTY_UTF8, fromString(" ").trimLeft(fromString(" "))); + + assertEquals(fromString("数据砖头 "), fromString(" 数据砖头 ").trimLeft(fromString(" "))); + assertEquals(fromString("数"), fromString("数").trimLeft(fromString("a"))); + assertEquals(fromString("a"), fromString("a").trimLeft(fromString("数"))); + assertEquals(fromString("砖头数数"), fromString("数数数据砖头数数").trimLeft(fromString("据数"))); + assertEquals(fromString("据砖头数数"), fromString(" 数数数据砖头数数").trimLeft(fromString("数 "))); + assertEquals(fromString("据砖头数数"), fromString("aa数数数据砖头数数").trimLeft(fromString("a数砖"))); + assertEquals(fromString("$S,.$BR"), fromString(",,,,%$S,.$BR").trimLeft(fromString("%,"))); + } + + @Test + public void trimRightWithTrimString() { + assertEquals(fromString(" hello "), fromString(" hello ").trimRight(fromString(""))); + assertEquals(fromString(""), fromString("a").trimRight(fromString("a"))); + assertEquals(fromString("cc"), fromString("ccbaaaa").trimRight(fromString("ba"))); + assertEquals(fromString(""), fromString("aabbbbaaa").trimRight(fromString("ab"))); + assertEquals(fromString(" he"), fromString(" hello ").trimRight(fromString(" ol"))); + assertEquals(fromString("oohell"), 
fromString("oohellooo../*&").trimRight(fromString("./,&%*o"))); + + assertEquals(EMPTY_UTF8, fromString(" ").trimRight(fromString(" "))); + + assertEquals(fromString(" 数据砖头"), fromString(" 数据砖头 ").trimRight(fromString(" "))); + assertEquals(fromString("数数砖头"), fromString("数数砖头数aa数").trimRight(fromString("a数"))); + assertEquals(fromString(""), fromString("数数数据砖ab").trimRight(fromString("数据砖ab"))); + assertEquals(fromString("头"), fromString("头a???/").trimRight(fromString("数?/*&^%a"))); + assertEquals(fromString("头"), fromString("头数b数数 [").trimRight(fromString(" []数b"))); + } } diff --git a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java new file mode 100644 index 0000000000000..618bd42d0e65d --- /dev/null +++ b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java @@ -0,0 +1,408 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.io; + +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import org.apache.spark.util.ThreadUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.annotation.concurrent.GuardedBy; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.InterruptedIOException; +import java.nio.ByteBuffer; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; + +/** + * {@link InputStream} implementation which asynchronously reads ahead from the underlying input + * stream when specified amount of data has been read from the current buffer. It does it by maintaining + * two buffer - active buffer and read ahead buffer. Active buffer contains data which should be returned + * when a read() call is issued. The read ahead buffer is used to asynchronously read from the underlying + * input stream and once the current active buffer is exhausted, we flip the two buffers so that we can + * start reading from the read ahead buffer without being blocked in disk I/O. + */ +public class ReadAheadInputStream extends InputStream { + + private static final Logger logger = LoggerFactory.getLogger(ReadAheadInputStream.class); + + private ReentrantLock stateChangeLock = new ReentrantLock(); + + @GuardedBy("stateChangeLock") + private ByteBuffer activeBuffer; + + @GuardedBy("stateChangeLock") + private ByteBuffer readAheadBuffer; + + @GuardedBy("stateChangeLock") + private boolean endOfStream; + + @GuardedBy("stateChangeLock") + // true if async read is in progress + private boolean readInProgress; + + @GuardedBy("stateChangeLock") + // true if read is aborted due to an exception in reading from underlying input stream. + private boolean readAborted; + + @GuardedBy("stateChangeLock") + private Throwable readException; + + @GuardedBy("stateChangeLock") + // whether the close method is called. 
+ private boolean isClosed; + + @GuardedBy("stateChangeLock") + // true when the close method will close the underlying input stream. This is valid only if + // `isClosed` is true. + private boolean isUnderlyingInputStreamBeingClosed; + + @GuardedBy("stateChangeLock") + // whether there is a read ahead task running, + private boolean isReading; + + // If the remaining data size in the current buffer is below this threshold, + // we issue an async read from the underlying input stream. + private final int readAheadThresholdInBytes; + + private final InputStream underlyingInputStream; + + private final ExecutorService executorService = ThreadUtils.newDaemonSingleThreadExecutor("read-ahead"); + + private final Condition asyncReadComplete = stateChangeLock.newCondition(); + + private static final ThreadLocal oneByte = ThreadLocal.withInitial(() -> new byte[1]); + + /** + * Creates a ReadAheadInputStream with the specified buffer size and read-ahead + * threshold + * + * @param inputStream The underlying input stream. + * @param bufferSizeInBytes The buffer size. + * @param readAheadThresholdInBytes If the active buffer has less data than the read-ahead + * threshold, an async read is triggered. + */ + public ReadAheadInputStream(InputStream inputStream, int bufferSizeInBytes, int readAheadThresholdInBytes) { + Preconditions.checkArgument(bufferSizeInBytes > 0, + "bufferSizeInBytes should be greater than 0, but the value is " + bufferSizeInBytes); + Preconditions.checkArgument(readAheadThresholdInBytes > 0 && + readAheadThresholdInBytes < bufferSizeInBytes, + "readAheadThresholdInBytes should be greater than 0 and less than bufferSizeInBytes, but the" + + "value is " + readAheadThresholdInBytes); + activeBuffer = ByteBuffer.allocate(bufferSizeInBytes); + readAheadBuffer = ByteBuffer.allocate(bufferSizeInBytes); + this.readAheadThresholdInBytes = readAheadThresholdInBytes; + this.underlyingInputStream = inputStream; + activeBuffer.flip(); + readAheadBuffer.flip(); + } + + private boolean isEndOfStream() { + return (!activeBuffer.hasRemaining() && !readAheadBuffer.hasRemaining() && endOfStream); + } + + private void checkReadException() throws IOException { + if (readAborted) { + Throwables.propagateIfPossible(readException, IOException.class); + throw new IOException(readException); + } + } + + /** Read data from underlyingInputStream to readAheadBuffer asynchronously. */ + private void readAsync() throws IOException { + stateChangeLock.lock(); + final byte[] arr = readAheadBuffer.array(); + try { + if (endOfStream || readInProgress) { + return; + } + checkReadException(); + readAheadBuffer.position(0); + readAheadBuffer.flip(); + readInProgress = true; + } finally { + stateChangeLock.unlock(); + } + executorService.execute(new Runnable() { + + @Override + public void run() { + stateChangeLock.lock(); + try { + if (isClosed) { + readInProgress = false; + return; + } + // Flip this so that the close method will not close the underlying input stream when we + // are reading. + isReading = true; + } finally { + stateChangeLock.unlock(); + } + + // Please note that it is safe to release the lock and read into the read ahead buffer + // because either of following two conditions will hold - 1. The active buffer has + // data available to read so the reader will not read from the read ahead buffer. + // 2. This is the first time read is called or the active buffer is exhausted, + // in that case the reader waits for this async read to complete. 
+ // So there is no race condition in both the situations. + int read = 0; + Throwable exception = null; + try { + while (true) { + read = underlyingInputStream.read(arr); + if (0 != read) break; + } + } catch (Throwable ex) { + exception = ex; + if (ex instanceof Error) { + // `readException` may not be reported to the user. Rethrow Error to make sure at least + // The user can see Error in UncaughtExceptionHandler. + throw (Error) ex; + } + } finally { + stateChangeLock.lock(); + if (read < 0 || (exception instanceof EOFException)) { + endOfStream = true; + } else if (exception != null) { + readAborted = true; + readException = exception; + } else { + readAheadBuffer.limit(read); + } + readInProgress = false; + signalAsyncReadComplete(); + stateChangeLock.unlock(); + closeUnderlyingInputStreamIfNecessary(); + } + } + }); + } + + private void closeUnderlyingInputStreamIfNecessary() { + boolean needToCloseUnderlyingInputStream = false; + stateChangeLock.lock(); + try { + isReading = false; + if (isClosed && !isUnderlyingInputStreamBeingClosed) { + // close method cannot close underlyingInputStream because we were reading. + needToCloseUnderlyingInputStream = true; + } + } finally { + stateChangeLock.unlock(); + } + if (needToCloseUnderlyingInputStream) { + try { + underlyingInputStream.close(); + } catch (IOException e) { + logger.warn(e.getMessage(), e); + } + } + } + + private void signalAsyncReadComplete() { + stateChangeLock.lock(); + try { + asyncReadComplete.signalAll(); + } finally { + stateChangeLock.unlock(); + } + } + + private void waitForAsyncReadComplete() throws IOException { + stateChangeLock.lock(); + try { + while (readInProgress) { + asyncReadComplete.await(); + } + } catch (InterruptedException e) { + InterruptedIOException iio = new InterruptedIOException(e.getMessage()); + iio.initCause(e); + throw iio; + } finally { + stateChangeLock.unlock(); + } + checkReadException(); + } + + @Override + public int read() throws IOException { + byte[] oneByteArray = oneByte.get(); + return read(oneByteArray, 0, 1) == -1 ? -1 : oneByteArray[0] & 0xFF; + } + + @Override + public int read(byte[] b, int offset, int len) throws IOException { + if (offset < 0 || len < 0 || len > b.length - offset) { + throw new IndexOutOfBoundsException(); + } + if (len == 0) { + return 0; + } + stateChangeLock.lock(); + try { + return readInternal(b, offset, len); + } finally { + stateChangeLock.unlock(); + } + } + + /** + * flip the active and read ahead buffer + */ + private void swapBuffers() { + ByteBuffer temp = activeBuffer; + activeBuffer = readAheadBuffer; + readAheadBuffer = temp; + } + + /** + * Internal read function which should be called only from read() api. The assumption is that + * the stateChangeLock is already acquired in the caller before calling this function. + */ + private int readInternal(byte[] b, int offset, int len) throws IOException { + assert (stateChangeLock.isLocked()); + if (!activeBuffer.hasRemaining()) { + waitForAsyncReadComplete(); + if (readAheadBuffer.hasRemaining()) { + swapBuffers(); + } else { + // The first read or activeBuffer is skipped. 
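+        // Neither buffer has any data: issue the async read and wait for it to finish,
+        // then either report end of stream or swap the freshly filled buffer in.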
+ readAsync(); + waitForAsyncReadComplete(); + if (isEndOfStream()) { + return -1; + } + swapBuffers(); + } + } else { + checkReadException(); + } + len = Math.min(len, activeBuffer.remaining()); + activeBuffer.get(b, offset, len); + + if (activeBuffer.remaining() <= readAheadThresholdInBytes && !readAheadBuffer.hasRemaining()) { + readAsync(); + } + return len; + } + + @Override + public int available() throws IOException { + stateChangeLock.lock(); + // Make sure we have no integer overflow. + try { + return (int) Math.min((long) Integer.MAX_VALUE, + (long) activeBuffer.remaining() + readAheadBuffer.remaining()); + } finally { + stateChangeLock.unlock(); + } + } + + @Override + public long skip(long n) throws IOException { + if (n <= 0L) { + return 0L; + } + stateChangeLock.lock(); + long skipped; + try { + skipped = skipInternal(n); + } finally { + stateChangeLock.unlock(); + } + return skipped; + } + + /** + * Internal skip function which should be called only from skip() api. The assumption is that + * the stateChangeLock is already acquired in the caller before calling this function. + */ + private long skipInternal(long n) throws IOException { + assert (stateChangeLock.isLocked()); + waitForAsyncReadComplete(); + if (isEndOfStream()) { + return 0; + } + if (available() >= n) { + // we can skip from the internal buffers + int toSkip = (int) n; + if (toSkip <= activeBuffer.remaining()) { + // Only skipping from active buffer is sufficient + activeBuffer.position(toSkip + activeBuffer.position()); + if (activeBuffer.remaining() <= readAheadThresholdInBytes + && !readAheadBuffer.hasRemaining()) { + readAsync(); + } + return n; + } + // We need to skip from both active buffer and read ahead buffer + toSkip -= activeBuffer.remaining(); + activeBuffer.position(0); + activeBuffer.flip(); + readAheadBuffer.position(toSkip + readAheadBuffer.position()); + swapBuffers(); + readAsync(); + return n; + } else { + int skippedBytes = available(); + long toSkip = n - skippedBytes; + activeBuffer.position(0); + activeBuffer.flip(); + readAheadBuffer.position(0); + readAheadBuffer.flip(); + long skippedFromInputStream = underlyingInputStream.skip(toSkip); + readAsync(); + return skippedBytes + skippedFromInputStream; + } + } + + @Override + public void close() throws IOException { + boolean isSafeToCloseUnderlyingInputStream = false; + stateChangeLock.lock(); + try { + if (isClosed) { + return; + } + isClosed = true; + if (!isReading) { + // Nobody is reading, so we can close the underlying input stream in this method. + isSafeToCloseUnderlyingInputStream = true; + // Flip this to make sure the read ahead task will not close the underlying input stream. 
+ isUnderlyingInputStreamBeingClosed = true; + } + } finally { + stateChangeLock.unlock(); + } + + try { + executorService.shutdownNow(); + executorService.awaitTermination(Long.MAX_VALUE, TimeUnit.SECONDS); + } catch (InterruptedException e) { + InterruptedIOException iio = new InterruptedIOException(e.getMessage()); + iio.initCause(e); + throw iio; + } finally { + if (isSafeToCloseUnderlyingInputStream) { + underlyingInputStream.close(); + } + } + } +} diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 4099fb01f2f95..0efae16e9838c 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -89,13 +89,7 @@ public LongArray allocateArray(long size) { long required = size * 8L; MemoryBlock page = taskMemoryManager.allocatePage(required, this); if (page == null || page.size() < required) { - long got = 0; - if (page != null) { - got = page.size(); - taskMemoryManager.freePage(page, this); - } - taskMemoryManager.showMemoryUsage(); - throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); + throwOom(page, required); } used += required; return new LongArray(page); @@ -116,13 +110,7 @@ public void freeArray(LongArray array) { protected MemoryBlock allocatePage(long required) { MemoryBlock page = taskMemoryManager.allocatePage(Math.max(pageSize, required), this); if (page == null || page.size() < required) { - long got = 0; - if (page != null) { - got = page.size(); - taskMemoryManager.freePage(page, this); - } - taskMemoryManager.showMemoryUsage(); - throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); + throwOom(page, required); } used += page.size(); return page; @@ -152,4 +140,14 @@ public void freeMemory(long size) { taskMemoryManager.releaseExecutionMemory(size, this); used -= size; } + + private void throwOom(final MemoryBlock page, final long required) { + long got = 0; + if (page != null) { + got = page.size(); + taskMemoryManager.freePage(page, this); + } + taskMemoryManager.showMemoryUsage(); + throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); + } } diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 610ace30f8a62..4fadfe36cd716 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -283,13 +283,7 @@ private void advanceToNextPage() { } else { currentPage = null; if (reader != null) { - // remove the spill file from disk - File file = spillWriters.removeFirst().getFile(); - if (file != null && file.exists()) { - if (!file.delete()) { - logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); - } - } + handleFailedDelete(); } try { Closeables.close(reader, /* swallowIOException = */ false); @@ -307,13 +301,7 @@ private void advanceToNextPage() { public boolean hasNext() { if (numRecords == 0) { if (reader != null) { - // remove the spill file from disk - File file = spillWriters.removeFirst().getFile(); - if (file != null && file.exists()) { - if (!file.delete()) { - logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); - } - } + handleFailedDelete(); } } return numRecords > 0; @@ -403,6 +391,14 @@ public long spill(long numBytes) throws 
IOException { public void remove() { throw new UnsupportedOperationException(); } + + private void handleFailedDelete() { + // remove the spill file from disk + File file = spillWriters.removeFirst().getFile(); + if (file != null && file.exists() && !file.delete()) { + logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); + } + } } /** diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index de4464080ef55..39eda00dd7efb 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -219,15 +219,7 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, inMemSorter.numRecords()); spillWriters.add(spillWriter); - final UnsafeSorterIterator sortedRecords = inMemSorter.getSortedIterator(); - while (sortedRecords.hasNext()) { - sortedRecords.loadNext(); - final Object baseObject = sortedRecords.getBaseObject(); - final long baseOffset = sortedRecords.getBaseOffset(); - final int recordLength = sortedRecords.getRecordLength(); - spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); - } - spillWriter.close(); + spillIterator(inMemSorter.getSortedIterator(), spillWriter); } final long spillSize = freeMemory(); @@ -488,6 +480,18 @@ public UnsafeSorterIterator getSortedIterator() throws IOException { } } + private static void spillIterator(UnsafeSorterIterator inMemIterator, + UnsafeSorterSpillWriter spillWriter) throws IOException { + while (inMemIterator.hasNext()) { + inMemIterator.loadNext(); + final Object baseObject = inMemIterator.getBaseObject(); + final long baseOffset = inMemIterator.getBaseOffset(); + final int recordLength = inMemIterator.getRecordLength(); + spillWriter.write(baseObject, baseOffset, recordLength, inMemIterator.getKeyPrefix()); + } + spillWriter.close(); + } + /** * An UnsafeSorterIterator that support spilling. */ @@ -503,6 +507,7 @@ class SpillableIterator extends UnsafeSorterIterator { this.numRecords = inMemIterator.getNumRecords(); } + @Override public int getNumRecords() { return numRecords; } @@ -521,14 +526,7 @@ public long spill() throws IOException { // Iterate over the records that have not been returned and spill them. 
final UnsafeSorterSpillWriter spillWriter = new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords); - while (inMemIterator.hasNext()) { - inMemIterator.loadNext(); - final Object baseObject = inMemIterator.getBaseObject(); - final long baseOffset = inMemIterator.getBaseOffset(); - final int recordLength = inMemIterator.getRecordLength(); - spillWriter.write(baseObject, baseOffset, recordLength, inMemIterator.getKeyPrefix()); - } - spillWriter.close(); + spillIterator(inMemIterator, spillWriter); spillWriters.add(spillWriter); nextUpstream = spillWriter.getReader(serializerManager); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 9521ab86a12d5..1e760b0b51988 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -17,20 +17,20 @@ package org.apache.spark.util.collection.unsafe.sort; -import java.io.*; - import com.google.common.io.ByteStreams; import com.google.common.io.Closeables; - import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; import org.apache.spark.io.NioBufferedFileInputStream; +import org.apache.spark.io.ReadAheadInputStream; import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockId; import org.apache.spark.unsafe.Platform; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.*; + /** * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description * of the file format). @@ -72,10 +72,23 @@ public UnsafeSorterSpillReader( bufferSizeBytes = DEFAULT_BUFFER_SIZE_BYTES; } + final double readAheadFraction = + SparkEnv.get() == null ? 0.5 : + SparkEnv.get().conf().getDouble("spark.unsafe.sorter.spill.read.ahead.fraction", 0.5); + + final boolean readAheadEnabled = + SparkEnv.get() == null ? false : + SparkEnv.get().conf().getBoolean("spark.unsafe.sorter.spill.read.ahead.enabled", true); + final InputStream bs = new NioBufferedFileInputStream(file, (int) bufferSizeBytes); try { - this.in = serializerManager.wrapStream(blockId, bs); + if (readAheadEnabled) { + this.in = new ReadAheadInputStream(serializerManager.wrapStream(blockId, bs), + (int) bufferSizeBytes, (int) (bufferSizeBytes * readAheadFraction)); + } else { + this.in = serializerManager.wrapStream(blockId, bs); + } this.din = new DataInputStream(this.in); numRecords = numRecordsRemaining = din.readInt(); } catch (IOException e) { diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 7a5fb9a802354..119b426a9af34 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -217,7 +217,7 @@ private[spark] class ExecutorAllocationManager( * the scheduling task. 
*/ def start(): Unit = { - listenerBus.addListener(listener) + listenerBus.addToManagementQueue(listener) val scheduleTask = new Runnable() { override def run(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 5242ab6f55235..ff960b396dbf1 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -63,7 +63,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) this(sc, new SystemClock) } - sc.addSparkListener(this) + sc.listenerBus.addToManagementQueue(this) override val rpcEnv: RpcEnv = sc.env.rpcEnv diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 136f0af7b2c9e..1821bc87bf626 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -419,7 +419,7 @@ class SparkContext(config: SparkConf) extends Logging { // "_jobProgressListener" should be set up before creating SparkEnv because when creating // "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them. _jobProgressListener = new JobProgressListener(_conf) - listenerBus.addListener(jobProgressListener) + listenerBus.addToStatusQueue(jobProgressListener) // Create the Spark execution environment (cache, map output tracker, etc) _env = createSparkEnv(_conf, isLocal, listenerBus) @@ -442,7 +442,7 @@ class SparkContext(config: SparkConf) extends Logging { _ui = if (conf.getBoolean("spark.ui.enabled", true)) { - Some(SparkUI.createLiveUI(this, _conf, listenerBus, _jobProgressListener, + Some(SparkUI.createLiveUI(this, _conf, _jobProgressListener, _env.securityManager, appName, startTime = startTime)) } else { // For tests, do not enable the UI @@ -522,7 +522,7 @@ class SparkContext(config: SparkConf) extends Logging { new EventLoggingListener(_applicationId, _applicationAttemptId, _eventLogDir.get, _conf, _hadoopConfiguration) logger.start() - listenerBus.addListener(logger) + listenerBus.addToEventLogQueue(logger) Some(logger) } else { None @@ -1563,7 +1563,7 @@ class SparkContext(config: SparkConf) extends Logging { */ @DeveloperApi def addSparkListener(listener: SparkListenerInterface) { - listenerBus.addListener(listener) + listenerBus.addToSharedQueue(listener) } /** @@ -1879,8 +1879,7 @@ class SparkContext(config: SparkConf) extends Logging { */ def stop(): Unit = { if (LiveListenerBus.withinListenerThread.value) { - throw new SparkException( - s"Cannot stop SparkContext within listener thread of ${LiveListenerBus.name}") + throw new SparkException(s"Cannot stop SparkContext within listener bus thread.") } // Use the stopping variable to ensure no contention for the stop scenario. // Still track the stopped variable for use elsewhere in the code. 
@@ -2378,7 +2377,7 @@ class SparkContext(config: SparkConf) extends Logging { " parameter from breaking Spark's ability to find a valid constructor.") } } - listenerBus.addListener(listener) + listenerBus.addToSharedQueue(listener) logInfo(s"Registered listener $className") } } catch { diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala index 51c3d9b158cbe..ecc82d7ac8001 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala @@ -94,7 +94,7 @@ private[deploy] object DependencyUtils { hadoopConf: Configuration, secMgr: SecurityManager): String = { require(fileList != null, "fileList cannot be null.") - fileList.split(",") + Utils.stringToSeq(fileList) .map(downloadFile(_, targetDir, sparkConf, hadoopConf, secMgr)) .mkString(",") } @@ -121,6 +121,11 @@ private[deploy] object DependencyUtils { uri.getScheme match { case "file" | "local" => path + case "http" | "https" | "ftp" if Utils.isTesting => + // This is only used for SparkSubmitSuite unit test. Instead of downloading file remotely, + // return a dummy local path instead. + val file = new File(uri.getPath) + new File(targetDir, file.getName).toURI.toString case _ => val fname = new Path(uri).getName() val localFile = Utils.doFetchFile(uri.toString(), targetDir, fname, sparkConf, secMgr, @@ -131,7 +136,7 @@ private[deploy] object DependencyUtils { def resolveGlobPaths(paths: String, hadoopConf: Configuration): String = { require(paths != null, "paths cannot be null.") - paths.split(",").map(_.trim).filter(_.nonEmpty).flatMap { path => + Utils.stringToSeq(paths).flatMap { path => val uri = Utils.resolveURI(path) uri.getScheme match { case "local" | "http" | "https" | "ftp" => Array(path) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index ea9c9bdaede76..286a4379d2040 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -25,11 +25,11 @@ import java.text.ParseException import scala.annotation.tailrec import scala.collection.mutable.{ArrayBuffer, HashMap, Map} -import scala.util.Properties +import scala.util.{Properties, Try} import org.apache.commons.lang3.StringUtils import org.apache.hadoop.conf.{Configuration => HadoopConfiguration} -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.ivy.Ivy @@ -48,6 +48,7 @@ import org.apache.spark._ import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.rest._ import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.util._ @@ -367,6 +368,52 @@ object SparkSubmit extends CommandLineUtils with Logging { }.orNull } + // When running in YARN, for some remote resources with scheme: + // 1. Hadoop FileSystem doesn't support them. + // 2. We explicitly bypass Hadoop FileSystem with "spark.yarn.dist.forceDownloadSchemes". + // We will download them to local disk prior to add to YARN's distributed cache. + // For yarn client mode, since we already download them with above code, so we only need to + // figure out the local path and replace the remote one. 
+ if (clusterManager == YARN) { + sparkConf.setIfMissing(SecurityManager.SPARK_AUTH_SECRET_CONF, "unused") + val secMgr = new SecurityManager(sparkConf) + val forceDownloadSchemes = sparkConf.get(FORCE_DOWNLOAD_SCHEMES) + + def shouldDownload(scheme: String): Boolean = { + forceDownloadSchemes.contains(scheme) || + Try { FileSystem.getFileSystemClass(scheme, hadoopConf) }.isFailure + } + + def downloadResource(resource: String): String = { + val uri = Utils.resolveURI(resource) + uri.getScheme match { + case "local" | "file" => resource + case e if shouldDownload(e) => + val file = new File(targetDir, new Path(uri).getName) + if (file.exists()) { + file.toURI.toString + } else { + downloadFile(resource, targetDir, sparkConf, hadoopConf, secMgr) + } + case _ => uri.toString + } + } + + args.primaryResource = Option(args.primaryResource).map { downloadResource }.orNull + args.files = Option(args.files).map { files => + Utils.stringToSeq(files).map(downloadResource).mkString(",") + }.orNull + args.pyFiles = Option(args.pyFiles).map { pyFiles => + Utils.stringToSeq(pyFiles).map(downloadResource).mkString(",") + }.orNull + args.jars = Option(args.jars).map { jars => + Utils.stringToSeq(jars).map(downloadResource).mkString(",") + }.orNull + args.archives = Option(args.archives).map { archives => + Utils.stringToSeq(archives).map(downloadResource).mkString(",") + }.orNull + } + // If we're running a python app, set the main class to our specific python runner if (args.isPython && deployMode == CLIENT) { if (args.primaryResource == PYSPARK_SHELL) { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 0d3769a735869..44a2815b81a73 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -385,4 +385,29 @@ package object config { .checkValue(v => v > 0 && v <= Int.MaxValue, s"The buffer size must be greater than 0 and less than ${Int.MaxValue}.") .createWithDefault(1024 * 1024) + + private[spark] val UNROLL_MEMORY_CHECK_PERIOD = + ConfigBuilder("spark.storage.unrollMemoryCheckPeriod") + .internal() + .doc("The memory check period is used to determine how often we should check whether " + + "there is a need to request more memory when we try to unroll the given block in memory.") + .longConf + .createWithDefault(16) + + private[spark] val UNROLL_MEMORY_GROWTH_FACTOR = + ConfigBuilder("spark.storage.unrollMemoryGrowthFactor") + .internal() + .doc("Memory to request as a multiple of the size that used to unroll the block.") + .doubleConf + .createWithDefault(1.5) + + private[spark] val FORCE_DOWNLOAD_SCHEMES = + ConfigBuilder("spark.yarn.dist.forceDownloadSchemes") + .doc("Comma-separated list of schemes for which files will be downloaded to the " + + "local disk prior to being added to YARN's distributed cache. 
For use in cases " + + "where the YARN service does not support schemes that are supported by Spark, like http, " + + "https and ftp.") + .stringConf + .toSequence + .createWithDefault(Nil) } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index 23e31823f4930..ac33e68abb490 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -68,8 +68,8 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) val graphite = propertyToOption(GRAPHITE_KEY_PROTOCOL).map(_.toLowerCase(Locale.ROOT)) match { - case Some("udp") => new GraphiteUDP(new InetSocketAddress(host, port)) - case Some("tcp") | None => new Graphite(new InetSocketAddress(host, port)) + case Some("udp") => new GraphiteUDP(host, port) + case Some("tcp") | None => new Graphite(host, port) case Some(p) => throw new Exception(s"Invalid Graphite protocol: $p") } diff --git a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala new file mode 100644 index 0000000000000..8605e1da161c7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} + +import com.codahale.metrics.{Gauge, Timer} + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.util.Utils + +/** + * An asynchronous queue for events. All events posted to this queue will be delivered to the child + * listeners in a separate thread. + * + * Delivery will only begin when the `start()` method is called. The `stop()` method should be + * called when no more events need to be delivered. + */ +private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveListenerBusMetrics) + extends SparkListenerBus + with Logging { + + import AsyncEventQueue._ + + // Cap the capacity of the queue so we get an explicit error (rather than an OOM exception) if + // it's perpetually being added to more quickly than it's being drained. 
+ private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent]( + conf.get(LISTENER_BUS_EVENT_QUEUE_CAPACITY)) + + // Keep the event count separately, so that waitUntilEmpty() can be implemented properly; + // this allows that method to return only when the events in the queue have been fully + // processed (instead of just dequeued). + private val eventCount = new AtomicLong() + + /** A counter for dropped events. It will be reset every time we log it. */ + private val droppedEventsCounter = new AtomicLong(0L) + + /** When `droppedEventsCounter` was logged last time in milliseconds. */ + @volatile private var lastReportTimestamp = 0L + + private val logDroppedEvent = new AtomicBoolean(false) + + private var sc: SparkContext = null + + private val started = new AtomicBoolean(false) + private val stopped = new AtomicBoolean(false) + + private val droppedEvents = metrics.metricRegistry.counter(s"queue.$name.numDroppedEvents") + private val processingTime = metrics.metricRegistry.timer(s"queue.$name.listenerProcessingTime") + + // Remove the queue size gauge first, in case it was created by a previous incarnation of + // this queue that was removed from the listener bus. + metrics.metricRegistry.remove(s"queue.$name.size") + metrics.metricRegistry.register(s"queue.$name.size", new Gauge[Int] { + override def getValue: Int = eventQueue.size() + }) + + private val dispatchThread = new Thread(s"spark-listener-group-$name") { + setDaemon(true) + override def run(): Unit = Utils.tryOrStopSparkContext(sc) { + dispatch() + } + } + + private def dispatch(): Unit = LiveListenerBus.withinListenerThread.withValue(true) { + try { + var next: SparkListenerEvent = eventQueue.take() + while (next != POISON_PILL) { + val ctx = processingTime.time() + try { + super.postToAll(next) + } finally { + ctx.stop() + } + eventCount.decrementAndGet() + next = eventQueue.take() + } + eventCount.decrementAndGet() + } catch { + case ie: InterruptedException => + logInfo(s"Stopping listener queue $name.", ie) + } + } + + override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = { + metrics.getTimerForListenerClass(listener.getClass.asSubclass(classOf[SparkListenerInterface])) + } + + /** + * Start an asynchronous thread to dispatch events to the underlying listeners. + * + * @param sc Used to stop the SparkContext in case the async dispatcher fails. + */ + private[scheduler] def start(sc: SparkContext): Unit = { + if (started.compareAndSet(false, true)) { + this.sc = sc + dispatchThread.start() + } else { + throw new IllegalStateException(s"$name already started!") + } + } + + /** + * Stop the listener bus. It will wait until the queued events have been processed, but new + * events will be dropped. + */ + private[scheduler] def stop(): Unit = { + if (!started.get()) { + throw new IllegalStateException(s"Attempted to stop $name that has not yet started!") + } + if (stopped.compareAndSet(false, true)) { + eventQueue.put(POISON_PILL) + eventCount.incrementAndGet() + } + dispatchThread.join() + } + + def post(event: SparkListenerEvent): Unit = { + if (stopped.get()) { + return + } + + eventCount.incrementAndGet() + if (eventQueue.offer(event)) { + return + } + + eventCount.decrementAndGet() + droppedEvents.inc() + droppedEventsCounter.incrementAndGet() + if (logDroppedEvent.compareAndSet(false, true)) { + // Only log the following message once to avoid duplicated annoying logs. + logError(s"Dropping event from queue $name. 
" + + "This likely means one of the listeners is too slow and cannot keep up with " + + "the rate at which tasks are being started by the scheduler.") + } + logTrace(s"Dropping event $event") + + val droppedCount = droppedEventsCounter.get + if (droppedCount > 0) { + // Don't log too frequently + if (System.currentTimeMillis() - lastReportTimestamp >= 60 * 1000) { + // There may be multiple threads trying to decrease droppedEventsCounter. + // Use "compareAndSet" to make sure only one thread can win. + // And if another thread is increasing droppedEventsCounter, "compareAndSet" will fail and + // then that thread will update it. + if (droppedEventsCounter.compareAndSet(droppedCount, 0)) { + val prevLastReportTimestamp = lastReportTimestamp + lastReportTimestamp = System.currentTimeMillis() + val previous = new java.util.Date(prevLastReportTimestamp) + logWarning(s"Dropped $droppedEvents events from $name since $previous.") + } + } + } + } + + /** + * For testing only. Wait until there are no more events in the queue. + * + * @return true if the queue is empty. + */ + def waitUntilEmpty(deadline: Long): Boolean = { + while (eventCount.get() != 0) { + if (System.currentTimeMillis > deadline) { + return false + } + Thread.sleep(10) + } + true + } + +} + +private object AsyncEventQueue { + + val POISON_PILL = new SparkListenerEvent() { } + +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 7d5e9809dd7b2..2f93c497c5771 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -17,20 +17,22 @@ package org.apache.spark.scheduler +import java.util.{List => JList} import java.util.concurrent._ import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong} +import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.reflect.ClassTag import scala.util.DynamicVariable -import com.codahale.metrics.{Counter, Gauge, MetricRegistry, Timer} +import com.codahale.metrics.{Counter, MetricRegistry, Timer} import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.metrics.source.Source -import org.apache.spark.util.Utils /** * Asynchronously passes SparkListenerEvents to registered SparkListeners. @@ -39,20 +41,13 @@ import org.apache.spark.util.Utils * has started will events be actually propagated to all attached listeners. This listener bus * is stopped when `stop()` is called, and it will drop further events after stopping. */ -private[spark] class LiveListenerBus(conf: SparkConf) extends SparkListenerBus { - - self => +private[spark] class LiveListenerBus(conf: SparkConf) { import LiveListenerBus._ private var sparkContext: SparkContext = _ - // Cap the capacity of the event queue so we get an explicit error (rather than - // an OOM exception) if it's perpetually being added to more quickly than it's being drained. 
- private val eventQueue = - new LinkedBlockingQueue[SparkListenerEvent](conf.get(LISTENER_BUS_EVENT_QUEUE_CAPACITY)) - - private[spark] val metrics = new LiveListenerBusMetrics(conf, eventQueue) + private[spark] val metrics = new LiveListenerBusMetrics(conf) // Indicate if `start()` is called private val started = new AtomicBoolean(false) @@ -65,53 +60,76 @@ private[spark] class LiveListenerBus(conf: SparkConf) extends SparkListenerBus { /** When `droppedEventsCounter` was logged last time in milliseconds. */ @volatile private var lastReportTimestamp = 0L - // Indicate if we are processing some event - // Guarded by `self` - private var processingEvent = false - - private val logDroppedEvent = new AtomicBoolean(false) - - // A counter that represents the number of events produced and consumed in the queue - private val eventLock = new Semaphore(0) - - private val listenerThread = new Thread(name) { - setDaemon(true) - override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) { - LiveListenerBus.withinListenerThread.withValue(true) { - val timer = metrics.eventProcessingTime - while (true) { - eventLock.acquire() - self.synchronized { - processingEvent = true - } - try { - val event = eventQueue.poll - if (event == null) { - // Get out of the while loop and shutdown the daemon thread - if (!stopped.get) { - throw new IllegalStateException("Polling `null` from eventQueue means" + - " the listener bus has been stopped. So `stopped` must be true") - } - return - } - val timerContext = timer.time() - try { - postToAll(event) - } finally { - timerContext.stop() - } - } finally { - self.synchronized { - processingEvent = false - } - } + private val queues = new CopyOnWriteArrayList[AsyncEventQueue]() + + /** Add a listener to queue shared by all non-internal listeners. */ + def addToSharedQueue(listener: SparkListenerInterface): Unit = { + addToQueue(listener, SHARED_QUEUE) + } + + /** Add a listener to the executor management queue. */ + def addToManagementQueue(listener: SparkListenerInterface): Unit = { + addToQueue(listener, EXECUTOR_MANAGEMENT_QUEUE) + } + + /** Add a listener to the application status queue. */ + def addToStatusQueue(listener: SparkListenerInterface): Unit = { + addToQueue(listener, APP_STATUS_QUEUE) + } + + /** Add a listener to the event log queue. */ + def addToEventLogQueue(listener: SparkListenerInterface): Unit = { + addToQueue(listener, EVENT_LOG_QUEUE) + } + + /** + * Add a listener to a specific queue, creating a new queue if needed. Queues are independent + * of each other (each one uses a separate thread for delivering events), allowing slower + * listeners to be somewhat isolated from others. + */ + private def addToQueue(listener: SparkListenerInterface, queue: String): Unit = synchronized { + if (stopped.get()) { + throw new IllegalStateException("LiveListenerBus is stopped.") + } + + queues.asScala.find(_.name == queue) match { + case Some(queue) => + queue.addListener(listener) + + case None => + val newQueue = new AsyncEventQueue(queue, conf, metrics) + newQueue.addListener(listener) + if (started.get()) { + newQueue.start(sparkContext) } - } + queues.add(newQueue) } } - override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = { - metrics.getTimerForListenerClass(listener.getClass.asSubclass(classOf[SparkListenerInterface])) + def removeListener(listener: SparkListenerInterface): Unit = synchronized { + // Remove listener from all queues it was added to, and stop queues that have become empty. 
+ queues.asScala + .filter { queue => + queue.removeListener(listener) + queue.listeners.isEmpty() + } + .foreach { toRemove => + if (started.get() && !stopped.get()) { + toRemove.stop() + } + queues.remove(toRemove) + } + } + + /** Post an event to all queues. */ + def post(event: SparkListenerEvent): Unit = { + if (!stopped.get()) { + metrics.numEventsPosted.inc() + val it = queues.iterator() + while (it.hasNext()) { + it.next().post(event) + } + } } /** @@ -123,46 +141,14 @@ private[spark] class LiveListenerBus(conf: SparkConf) extends SparkListenerBus { * * @param sc Used to stop the SparkContext in case the listener thread dies. */ - def start(sc: SparkContext, metricsSystem: MetricsSystem): Unit = { - if (started.compareAndSet(false, true)) { - sparkContext = sc - metricsSystem.registerSource(metrics) - listenerThread.start() - } else { - throw new IllegalStateException(s"$name already started!") - } - } - - def post(event: SparkListenerEvent): Unit = { - if (stopped.get) { - // Drop further events to make `listenerThread` exit ASAP - logDebug(s"$name has already stopped! Dropping event $event") - return - } - metrics.numEventsPosted.inc() - val eventAdded = eventQueue.offer(event) - if (eventAdded) { - eventLock.release() - } else { - onDropEvent(event) + def start(sc: SparkContext, metricsSystem: MetricsSystem): Unit = synchronized { + if (!started.compareAndSet(false, true)) { + throw new IllegalStateException("LiveListenerBus already started.") } - val droppedEvents = droppedEventsCounter.get - if (droppedEvents > 0) { - // Don't log too frequently - if (System.currentTimeMillis() - lastReportTimestamp >= 60 * 1000) { - // There may be multiple threads trying to decrease droppedEventsCounter. - // Use "compareAndSet" to make sure only one thread can win. - // And if another thread is increasing droppedEventsCounter, "compareAndSet" will fail and - // then that thread will update it. - if (droppedEventsCounter.compareAndSet(droppedEvents, 0)) { - val prevLastReportTimestamp = lastReportTimestamp - lastReportTimestamp = System.currentTimeMillis() - logWarning(s"Dropped $droppedEvents SparkListenerEvents since " + - new java.util.Date(prevLastReportTimestamp)) - } - } - } + this.sparkContext = sc + queues.asScala.foreach(_.start(sc)) + metricsSystem.registerSource(metrics) } /** @@ -173,80 +159,64 @@ private[spark] class LiveListenerBus(conf: SparkConf) extends SparkListenerBus { */ @throws(classOf[TimeoutException]) def waitUntilEmpty(timeoutMillis: Long): Unit = { - val finishTime = System.currentTimeMillis + timeoutMillis - while (!queueIsEmpty) { - if (System.currentTimeMillis > finishTime) { - throw new TimeoutException( - s"The event queue is not empty after $timeoutMillis milliseconds") + val deadline = System.currentTimeMillis + timeoutMillis + queues.asScala.foreach { queue => + if (!queue.waitUntilEmpty(deadline)) { + throw new TimeoutException(s"The event queue is not empty after $timeoutMillis ms.") } - /* Sleep rather than using wait/notify, because this is used only for testing and - * wait/notify add overhead in the general case. */ - Thread.sleep(10) } } - /** - * For testing only. Return whether the listener daemon thread is still alive. - * Exposed for testing. - */ - def listenerThreadIsAlive: Boolean = listenerThread.isAlive - - /** - * Return whether the event queue is empty. - * - * The use of synchronized here guarantees that all events that once belonged to this queue - * have already been processed by all attached listeners, if this returns true. 
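
Taken together, the per-queue methods change how callers use the bus: listeners are registered against a named queue, events fan out to every queue, and draining happens per queue. A hedged usage sketch (assumes an existing SparkContext `sc`; the anonymous listeners are placeholders, and since LiveListenerBus is private[spark] this is illustrative only):

import org.apache.spark.SparkConf
import org.apache.spark.scheduler._

val bus = new LiveListenerBus(new SparkConf())

val statusListener = new SparkListener {
  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = println(s"appStatus queue saw job ${jobEnd.jobId}")
}
val eventLogListener = new SparkListener {
  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = println(s"eventLog queue saw job ${jobEnd.jobId}")
}

bus.addToStatusQueue(statusListener)      // lands on the "appStatus" queue
bus.addToEventLogQueue(eventLogListener)  // lands on the "eventLog" queue
bus.start(sc, sc.env.metricsSystem)       // one dispatch thread per queue
bus.post(SparkListenerJobEnd(0, System.currentTimeMillis(), JobSucceeded))
bus.waitUntilEmpty(10000)                 // test-only: drains every queue or throws TimeoutException
bus.removeListener(statusListener)        // queues left empty are stopped and removed
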
- */ - private def queueIsEmpty: Boolean = synchronized { eventQueue.isEmpty && !processingEvent } - /** * Stop the listener bus. It will wait until the queued events have been processed, but drop the * new events after stopping. */ def stop(): Unit = { if (!started.get()) { - throw new IllegalStateException(s"Attempted to stop $name that has not yet started!") + throw new IllegalStateException(s"Attempted to stop bus that has not yet started!") } - if (stopped.compareAndSet(false, true)) { - // Call eventLock.release() so that listenerThread will poll `null` from `eventQueue` and know - // `stop` is called. - eventLock.release() - listenerThread.join() - } else { - // Keep quiet + + if (!stopped.compareAndSet(false, true)) { + return } - } - /** - * If the event queue exceeds its capacity, the new events will be dropped. The subclasses will be - * notified with the dropped events. - * - * Note: `onDropEvent` can be called in any thread. - */ - def onDropEvent(event: SparkListenerEvent): Unit = { - metrics.numDroppedEvents.inc() - droppedEventsCounter.incrementAndGet() - if (logDroppedEvent.compareAndSet(false, true)) { - // Only log the following message once to avoid duplicated annoying logs. - logError("Dropping SparkListenerEvent because no remaining room in event queue. " + - "This likely means one of the SparkListeners is too slow and cannot keep up with " + - "the rate at which tasks are being started by the scheduler.") + synchronized { + queues.asScala.foreach(_.stop()) + queues.clear() } - logTrace(s"Dropping event $event") } + + // For testing only. + private[spark] def findListenersByClass[T <: SparkListenerInterface : ClassTag](): Seq[T] = { + queues.asScala.flatMap { queue => queue.findListenersByClass[T]() } + } + + // For testing only. + private[spark] def listeners: JList[SparkListenerInterface] = { + queues.asScala.flatMap(_.listeners.asScala).asJava + } + + // For testing only. + private[scheduler] def activeQueues(): Set[String] = { + queues.asScala.map(_.name).toSet + } + } private[spark] object LiveListenerBus { // Allows for Context to check whether stop() call is made within listener thread val withinListenerThread: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) - /** The thread name of Spark listener bus */ - val name = "SparkListenerBus" + private[scheduler] val SHARED_QUEUE = "shared" + + private[scheduler] val APP_STATUS_QUEUE = "appStatus" + + private[scheduler] val EXECUTOR_MANAGEMENT_QUEUE = "executorManagement" + + private[scheduler] val EVENT_LOG_QUEUE = "eventLog" } -private[spark] class LiveListenerBusMetrics( - conf: SparkConf, - queue: LinkedBlockingQueue[_]) +private[spark] class LiveListenerBusMetrics(conf: SparkConf) extends Source with Logging { override val sourceName: String = "LiveListenerBus" @@ -260,25 +230,6 @@ private[spark] class LiveListenerBusMetrics( */ val numEventsPosted: Counter = metricRegistry.counter(MetricRegistry.name("numEventsPosted")) - /** - * The total number of events that were dropped without being delivered to listeners. - */ - val numDroppedEvents: Counter = metricRegistry.counter(MetricRegistry.name("numEventsDropped")) - - /** - * The amount of time taken to post a single event to all listeners. - */ - val eventProcessingTime: Timer = metricRegistry.timer(MetricRegistry.name("eventProcessingTime")) - - /** - * The number of messages waiting in the queue. 
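
With the bus-level numDroppedEvents, eventProcessingTime and queueSize metrics removed, monitoring moves to per-queue metric names. A small sketch of reading them back from the registry; the "queue.shared.*" names are taken from the test helpers added later in this diff, everything else is illustrative:

import com.codahale.metrics.MetricRegistry

def sharedQueueStats(registry: MetricRegistry): (Long, Int, Long) = {
  val dropped = registry.counter("queue.shared.numDroppedEvents").getCount
  // The size gauge is only registered once the shared queue has been created.
  val size = registry.getGauges().get("queue.shared.size").getValue.asInstanceOf[Int]
  val processed = registry.timer("queue.shared.listenerProcessingTime").getCount
  (dropped, size, processed)
}
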
- */ - val queueSize: Gauge[Int] = { - metricRegistry.register(MetricRegistry.name("queueSize"), new Gauge[Int]{ - override def getValue: Int = queue.size() - }) - } - // Guarded by synchronization. private val perListenerClassTimers = mutable.Map[String, Timer]() @@ -303,5 +254,5 @@ private[spark] class LiveListenerBusMetrics( } } } -} +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index aaacabe79ace4..b4b5938c307e2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -988,11 +988,16 @@ private[spark] class BlockManager( logWarning(s"Putting block $blockId failed") } res + } catch { + // Since removeBlockInternal may throw exception, + // we should print exception first to show root cause. + case NonFatal(e) => + logWarning(s"Putting block $blockId failed due to exception $e.") + throw e } finally { // This cleanup is performed in a finally block rather than a `catch` to avoid having to // catch and properly re-throw InterruptedException. if (exceptionWasThrown) { - logWarning(s"Putting block $blockId failed due to an exception") // If an exception was thrown then it's possible that the code in `putBody` has already // notified the master about the availability of this block, so we need to send an update // to remove this block location. diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 90e3af2d0ec74..eb2201d142ffb 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -29,6 +29,7 @@ import com.google.common.io.ByteStreams import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.{UNROLL_MEMORY_CHECK_PERIOD, UNROLL_MEMORY_GROWTH_FACTOR} import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.serializer.{SerializationStream, SerializerManager} import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel, StreamBlockId} @@ -190,11 +191,11 @@ private[spark] class MemoryStore( // Initial per-task memory to request for unrolling blocks (bytes). 
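
The MemoryStore hunk above starts importing UNROLL_MEMORY_CHECK_PERIOD and UNROLL_MEMORY_GROWTH_FACTOR so that the hard-coded 16 and 1.5 below can come from configuration. A sketch of how these entries are presumably declared in org.apache.spark.internal.config; the keys, doc strings and builder chain are assumptions based on the old defaults, not copied from this patch:

private[spark] val UNROLL_MEMORY_CHECK_PERIOD =
  ConfigBuilder("spark.storage.unrollMemoryCheckPeriod")   // assumed key
    .internal()
    .doc("How often, in elements, to check whether unrolling needs more memory.")
    .longConf
    .createWithDefault(16)

private[spark] val UNROLL_MEMORY_GROWTH_FACTOR =
  ConfigBuilder("spark.storage.unrollMemoryGrowthFactor")  // assumed key
    .internal()
    .doc("Memory to request as a multiple of the current unroll buffer size.")
    .doubleConf
    .createWithDefault(1.5)
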
val initialMemoryThreshold = unrollMemoryThreshold // How often to check whether we need to request more memory - val memoryCheckPeriod = 16 + val memoryCheckPeriod = conf.get(UNROLL_MEMORY_CHECK_PERIOD) // Memory currently reserved by this task for this particular unrolling operation var memoryThreshold = initialMemoryThreshold // Memory to request as a multiple of current vector size - val memoryGrowthFactor = 1.5 + val memoryGrowthFactor = conf.get(UNROLL_MEMORY_GROWTH_FACTOR) // Keep track of unroll memory used by this particular block / putIterator() operation var unrollMemoryUsedByThisBlock = 0L // Underlying vector for unrolling the block @@ -325,6 +326,12 @@ private[spark] class MemoryStore( // Whether there is still enough memory for us to continue unrolling this block var keepUnrolling = true + // Number of elements unrolled so far + var elementsUnrolled = 0L + // How often to check whether we need to request more memory + val memoryCheckPeriod = conf.get(UNROLL_MEMORY_CHECK_PERIOD) + // Memory to request as a multiple of current bbos size + val memoryGrowthFactor = conf.get(UNROLL_MEMORY_GROWTH_FACTOR) // Initial per-task memory to request for unrolling blocks (bytes). val initialMemoryThreshold = unrollMemoryThreshold // Keep track of unroll memory used by this particular block / putIterator() operation @@ -359,7 +366,7 @@ private[spark] class MemoryStore( def reserveAdditionalMemoryIfNecessary(): Unit = { if (bbos.size > unrollMemoryUsedByThisBlock) { - val amountToRequest = bbos.size - unrollMemoryUsedByThisBlock + val amountToRequest = (bbos.size * memoryGrowthFactor - unrollMemoryUsedByThisBlock).toLong keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) if (keepUnrolling) { unrollMemoryUsedByThisBlock += amountToRequest @@ -370,7 +377,10 @@ private[spark] class MemoryStore( // Unroll this block safely, checking whether we have exceeded our threshold while (values.hasNext && keepUnrolling) { serializationStream.writeObject(values.next())(classTag) - reserveAdditionalMemoryIfNecessary() + elementsUnrolled += 1 + if (elementsUnrolled % memoryCheckPeriod == 0) { + reserveAdditionalMemoryIfNecessary() + } } // Make sure that we have enough memory to store the block. 
By this point, it is possible that diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index f3fcf2778d39e..6e94073238a56 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -162,13 +162,14 @@ private[spark] object SparkUI { def createLiveUI( sc: SparkContext, conf: SparkConf, - listenerBus: SparkListenerBus, jobProgressListener: JobProgressListener, securityManager: SecurityManager, appName: String, startTime: Long): SparkUI = { - create(Some(sc), conf, listenerBus, securityManager, appName, - jobProgressListener = Some(jobProgressListener), startTime = startTime) + create(Some(sc), conf, + sc.listenerBus.addToStatusQueue, + securityManager, appName, jobProgressListener = Some(jobProgressListener), + startTime = startTime) } def createHistoryUI( @@ -179,8 +180,7 @@ private[spark] object SparkUI { basePath: String, lastUpdateTime: Option[Long], startTime: Long): SparkUI = { - val sparkUI = create( - None, conf, listenerBus, securityManager, appName, basePath, + val sparkUI = create(None, conf, listenerBus.addListener, securityManager, appName, basePath, lastUpdateTime = lastUpdateTime, startTime = startTime) val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory], @@ -202,7 +202,7 @@ private[spark] object SparkUI { private def create( sc: Option[SparkContext], conf: SparkConf, - listenerBus: SparkListenerBus, + addListenerFn: SparkListenerInterface => Unit, securityManager: SecurityManager, appName: String, basePath: String = "", @@ -212,7 +212,7 @@ private[spark] object SparkUI { val _jobProgressListener: JobProgressListener = jobProgressListener.getOrElse { val listener = new JobProgressListener(conf) - listenerBus.addListener(listener) + addListenerFn(listener) listener } @@ -222,11 +222,11 @@ private[spark] object SparkUI { val storageListener = new StorageListener(storageStatusListener) val operationGraphListener = new RDDOperationGraphListener(conf) - listenerBus.addListener(environmentListener) - listenerBus.addListener(storageStatusListener) - listenerBus.addListener(executorsListener) - listenerBus.addListener(storageListener) - listenerBus.addListener(operationGraphListener) + addListenerFn(environmentListener) + addListenerFn(storageStatusListener) + addListenerFn(executorsListener) + addListenerFn(storageListener) + addListenerFn(operationGraphListener) new SparkUI(sc, conf, securityManager, environmentListener, storageStatusListener, executorsListener, _jobProgressListener, storageListener, operationGraphListener, diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index d9c87f69d8a54..5acec0d0f54c9 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -110,7 +110,7 @@ private[spark] object UIData { def hasOutput: Boolean = outputBytes > 0 def hasShuffleRead: Boolean = shuffleReadTotalBytes > 0 def hasShuffleWrite: Boolean = shuffleWriteBytes > 0 - def hasBytesSpilled: Boolean = memoryBytesSpilled > 0 && diskBytesSpilled > 0 + def hasBytesSpilled: Boolean = memoryBytesSpilled > 0 || diskBytesSpilled > 0 } /** diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index bc08808a4d292..836e33c36d9a1 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ 
b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2684,6 +2684,9 @@ private[spark] object Utils extends Logging { redact(redactionPattern, kvs.toArray) } + def stringToSeq(str: String): Seq[String] = { + str.split(",").map(_.trim()).filter(_.nonEmpty) + } } private[util] object CallerContext extends Logging { diff --git a/core/src/test/java/org/apache/spark/io/NioBufferedFileInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java similarity index 87% rename from core/src/test/java/org/apache/spark/io/NioBufferedFileInputStreamSuite.java rename to core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java index 2c1a34a607592..3440e1aea2f46 100644 --- a/core/src/test/java/org/apache/spark/io/NioBufferedFileInputStreamSuite.java +++ b/core/src/test/java/org/apache/spark/io/GenericFileInputStreamSuite.java @@ -31,11 +31,13 @@ /** * Tests functionality of {@link NioBufferedFileInputStream} */ -public class NioBufferedFileInputStreamSuite { +public abstract class GenericFileInputStreamSuite { private byte[] randomBytes; - private File inputFile; + protected File inputFile; + + protected InputStream inputStream; @Before public void setUp() throws IOException { @@ -52,7 +54,6 @@ public void tearDown() { @Test public void testReadOneByte() throws IOException { - InputStream inputStream = new NioBufferedFileInputStream(inputFile); for (int i = 0; i < randomBytes.length; i++) { assertEquals(randomBytes[i], (byte) inputStream.read()); } @@ -60,7 +61,6 @@ public void testReadOneByte() throws IOException { @Test public void testReadMultipleBytes() throws IOException { - InputStream inputStream = new NioBufferedFileInputStream(inputFile); byte[] readBytes = new byte[8 * 1024]; int i = 0; while (i < randomBytes.length) { @@ -74,7 +74,6 @@ public void testReadMultipleBytes() throws IOException { @Test public void testBytesSkipped() throws IOException { - InputStream inputStream = new NioBufferedFileInputStream(inputFile); assertEquals(1024, inputStream.skip(1024)); for (int i = 1024; i < randomBytes.length; i++) { assertEquals(randomBytes[i], (byte) inputStream.read()); @@ -83,7 +82,6 @@ public void testBytesSkipped() throws IOException { @Test public void testBytesSkippedAfterRead() throws IOException { - InputStream inputStream = new NioBufferedFileInputStream(inputFile); for (int i = 0; i < 1024; i++) { assertEquals(randomBytes[i], (byte) inputStream.read()); } @@ -95,7 +93,6 @@ public void testBytesSkippedAfterRead() throws IOException { @Test public void testNegativeBytesSkippedAfterRead() throws IOException { - InputStream inputStream = new NioBufferedFileInputStream(inputFile); for (int i = 0; i < 1024; i++) { assertEquals(randomBytes[i], (byte) inputStream.read()); } @@ -111,7 +108,6 @@ public void testNegativeBytesSkippedAfterRead() throws IOException { @Test public void testSkipFromFileChannel() throws IOException { - InputStream inputStream = new NioBufferedFileInputStream(inputFile, 10); // Since the buffer is smaller than the skipped bytes, this will guarantee // we skip from underlying file channel. 
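
As a quick check of the stringToSeq helper added to Utils earlier in this hunk: it trims each entry and drops empties (illustrative assertions only).

assert(Utils.stringToSeq(" a, ,b,, c ") == Seq("a", "b", "c"))
assert(Utils.stringToSeq("") == Nil)
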
assertEquals(1024, inputStream.skip(1024)); @@ -128,7 +124,6 @@ public void testSkipFromFileChannel() throws IOException { @Test public void testBytesSkippedAfterEOF() throws IOException { - InputStream inputStream = new NioBufferedFileInputStream(inputFile); assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1)); assertEquals(-1, inputStream.read()); } diff --git a/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java new file mode 100644 index 0000000000000..211b33a1a9fb0 --- /dev/null +++ b/core/src/test/java/org/apache/spark/io/NioBufferedInputStreamSuite.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.io; + +import org.junit.Before; + +import java.io.IOException; + +/** + * Tests functionality of {@link NioBufferedFileInputStream} + */ +public class NioBufferedInputStreamSuite extends GenericFileInputStreamSuite { + + @Before + public void setUp() throws IOException { + super.setUp(); + inputStream = new NioBufferedFileInputStream(inputFile); + } +} diff --git a/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java b/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java new file mode 100644 index 0000000000000..5008f93b7e409 --- /dev/null +++ b/core/src/test/java/org/apache/spark/io/ReadAheadInputStreamSuite.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.io; + +import org.junit.Before; + +import java.io.IOException; + +/** + * Tests functionality of {@link NioBufferedFileInputStream} + */ +public class ReadAheadInputStreamSuite extends GenericFileInputStreamSuite { + + @Before + public void setUp() throws IOException { + super.setUp(); + inputStream = new ReadAheadInputStream(new NioBufferedFileInputStream(inputFile), 8 * 1024, 4 * 1024); + } +} diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 7da4bae0ab7eb..a91e09b7cb69f 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -49,6 +49,11 @@ class ExecutorAllocationManagerSuite contexts.foreach(_.stop()) } + private def post(bus: LiveListenerBus, event: SparkListenerEvent): Unit = { + bus.post(event) + bus.waitUntilEmpty(1000) + } + test("verify min/max executors") { val conf = new SparkConf() .setMaster("myDummyLocalExternalClusterManager") @@ -95,7 +100,7 @@ class ExecutorAllocationManagerSuite test("add executors") { sc = createSparkContext(1, 10, 1) val manager = sc.executorAllocationManager.get - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Keep adding until the limit is reached assert(numExecutorsTarget(manager) === 1) @@ -140,7 +145,7 @@ class ExecutorAllocationManagerSuite test("add executors capped by num pending tasks") { sc = createSparkContext(0, 10, 0) val manager = sc.executorAllocationManager.get - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 5))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 5))) // Verify that we're capped at number of tasks in the stage assert(numExecutorsTarget(manager) === 0) @@ -156,10 +161,10 @@ class ExecutorAllocationManagerSuite assert(numExecutorsToAdd(manager) === 1) // Verify that running a task doesn't affect the target - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 3))) - sc.listenerBus.postToAll(SparkListenerExecutorAdded( + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(1, 3))) + post(sc.listenerBus, SparkListenerExecutorAdded( 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) - sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) + post(sc.listenerBus, SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) assert(numExecutorsTarget(manager) === 5) assert(addExecutors(manager) === 1) assert(numExecutorsTarget(manager) === 6) @@ -172,9 +177,9 @@ class ExecutorAllocationManagerSuite assert(numExecutorsToAdd(manager) === 1) // Verify that re-running a task doesn't blow things up - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, 3))) - sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(0, 0, "executor-1"))) - sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(1, 0, "executor-1"))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(2, 3))) + post(sc.listenerBus, SparkListenerTaskStart(2, 0, createTaskInfo(0, 0, "executor-1"))) + post(sc.listenerBus, SparkListenerTaskStart(2, 0, createTaskInfo(1, 0, "executor-1"))) assert(addExecutors(manager) === 1) assert(numExecutorsTarget(manager) === 9) assert(numExecutorsToAdd(manager) === 2) @@ -183,7 
+188,7 @@ class ExecutorAllocationManagerSuite assert(numExecutorsToAdd(manager) === 1) // Verify that running a task once we're at our limit doesn't blow things up - sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, createTaskInfo(0, 1, "executor-1"))) + post(sc.listenerBus, SparkListenerTaskStart(2, 0, createTaskInfo(0, 1, "executor-1"))) assert(addExecutors(manager) === 0) assert(numExecutorsTarget(manager) === 10) } @@ -193,13 +198,13 @@ class ExecutorAllocationManagerSuite val manager = sc.executorAllocationManager.get // Verify that we're capped at number of tasks including the speculative ones in the stage - sc.listenerBus.postToAll(SparkListenerSpeculativeTaskSubmitted(1)) + post(sc.listenerBus, SparkListenerSpeculativeTaskSubmitted(1)) assert(numExecutorsTarget(manager) === 0) assert(numExecutorsToAdd(manager) === 1) assert(addExecutors(manager) === 1) - sc.listenerBus.postToAll(SparkListenerSpeculativeTaskSubmitted(1)) - sc.listenerBus.postToAll(SparkListenerSpeculativeTaskSubmitted(1)) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 2))) + post(sc.listenerBus, SparkListenerSpeculativeTaskSubmitted(1)) + post(sc.listenerBus, SparkListenerSpeculativeTaskSubmitted(1)) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(1, 2))) assert(numExecutorsTarget(manager) === 1) assert(numExecutorsToAdd(manager) === 2) assert(addExecutors(manager) === 2) @@ -210,13 +215,13 @@ class ExecutorAllocationManagerSuite assert(numExecutorsToAdd(manager) === 1) // Verify that running a task doesn't affect the target - sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) + post(sc.listenerBus, SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) assert(numExecutorsTarget(manager) === 5) assert(addExecutors(manager) === 0) assert(numExecutorsToAdd(manager) === 1) // Verify that running a speculative task doesn't affect the target - sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-2", true))) + post(sc.listenerBus, SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-2", true))) assert(numExecutorsTarget(manager) === 5) assert(addExecutors(manager) === 0) assert(numExecutorsToAdd(manager) === 1) @@ -225,7 +230,7 @@ class ExecutorAllocationManagerSuite test("cancel pending executors when no longer needed") { sc = createSparkContext(0, 10, 0) val manager = sc.executorAllocationManager.get - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, 5))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(2, 5))) assert(numExecutorsTarget(manager) === 0) assert(numExecutorsToAdd(manager) === 1) @@ -236,15 +241,15 @@ class ExecutorAllocationManagerSuite assert(numExecutorsTarget(manager) === 3) val task1Info = createTaskInfo(0, 0, "executor-1") - sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, task1Info)) + post(sc.listenerBus, SparkListenerTaskStart(2, 0, task1Info)) assert(numExecutorsToAdd(manager) === 4) assert(addExecutors(manager) === 2) val task2Info = createTaskInfo(1, 0, "executor-1") - sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, task2Info)) - sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, Success, task1Info, null)) - sc.listenerBus.postToAll(SparkListenerTaskEnd(2, 0, null, Success, task2Info, null)) + post(sc.listenerBus, SparkListenerTaskStart(2, 0, task2Info)) + post(sc.listenerBus, SparkListenerTaskEnd(2, 0, null, Success, task1Info, null)) + post(sc.listenerBus, SparkListenerTaskEnd(2, 0, null, Success, 
task2Info, null)) assert(adjustRequestedExecutors(manager) === -1) } @@ -352,21 +357,22 @@ class ExecutorAllocationManagerSuite sc = createSparkContext(5, 12, 5) val manager = sc.executorAllocationManager.get - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 8))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 8))) // Remove when numExecutorsTarget is the same as the current number of executors assert(addExecutors(manager) === 1) assert(addExecutors(manager) === 2) (1 to 8).map { i => createTaskInfo(i, i, s"$i") }.foreach { - info => sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, info)) } + info => post(sc.listenerBus, SparkListenerTaskStart(0, 0, info)) } assert(executorIds(manager).size === 8) assert(numExecutorsTarget(manager) === 8) assert(maxNumExecutorsNeeded(manager) == 8) assert(!removeExecutor(manager, "1")) // won't work since numExecutorsTarget == numExecutors // Remove executors when numExecutorsTarget is lower than current number of executors - (1 to 3).map { i => createTaskInfo(i, i, s"$i") }.foreach { - info => sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, Success, info, null)) } + (1 to 3).map { i => createTaskInfo(i, i, s"$i") }.foreach { info => + post(sc.listenerBus, SparkListenerTaskEnd(0, 0, null, Success, info, null)) + } adjustRequestedExecutors(manager) assert(executorIds(manager).size === 8) assert(numExecutorsTarget(manager) === 5) @@ -378,7 +384,7 @@ class ExecutorAllocationManagerSuite onExecutorRemoved(manager, "3") // numExecutorsTarget is lower than minNumExecutors - sc.listenerBus.postToAll( + post(sc.listenerBus, SparkListenerTaskEnd(0, 0, null, Success, createTaskInfo(4, 4, "4"), null)) assert(executorIds(manager).size === 5) assert(numExecutorsTarget(manager) === 5) @@ -390,7 +396,7 @@ class ExecutorAllocationManagerSuite test ("interleaving add and remove") { sc = createSparkContext(5, 12, 5) val manager = sc.executorAllocationManager.get - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Add a few executors assert(addExecutors(manager) === 1) @@ -569,7 +575,7 @@ class ExecutorAllocationManagerSuite val clock = new ManualClock(2020L) val manager = sc.executorAllocationManager.get manager.setClock(clock) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1000))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 1000))) // Scheduler queue backlogged onSchedulerBacklogged(manager) @@ -682,26 +688,26 @@ class ExecutorAllocationManagerSuite // Starting a stage should start the add timer val numTasks = 10 - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, numTasks))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, numTasks))) assert(addTime(manager) !== NOT_SET) // Starting a subset of the tasks should not cancel the add timer val taskInfos = (0 to numTasks - 1).map { i => createTaskInfo(i, i, "executor-1") } - taskInfos.tail.foreach { info => sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, info)) } + taskInfos.tail.foreach { info => post(sc.listenerBus, SparkListenerTaskStart(0, 0, info)) } assert(addTime(manager) !== NOT_SET) // Starting all remaining tasks should cancel the add timer - sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, taskInfos.head)) + post(sc.listenerBus, SparkListenerTaskStart(0, 0, taskInfos.head)) assert(addTime(manager) === NOT_SET) // Start two different stages // 
The add timer should be canceled only if all tasks in both stages start running - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, numTasks))) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, numTasks))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(1, numTasks))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(2, numTasks))) assert(addTime(manager) !== NOT_SET) - taskInfos.foreach { info => sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, info)) } + taskInfos.foreach { info => post(sc.listenerBus, SparkListenerTaskStart(1, 0, info)) } assert(addTime(manager) !== NOT_SET) - taskInfos.foreach { info => sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, info)) } + taskInfos.foreach { info => post(sc.listenerBus, SparkListenerTaskStart(2, 0, info)) } assert(addTime(manager) === NOT_SET) } @@ -715,22 +721,22 @@ class ExecutorAllocationManagerSuite assert(removeTimes(manager).size === 5) // Starting a task cancel the remove timer for that executor - sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) - sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(1, 1, "executor-1"))) - sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(2, 2, "executor-2"))) + post(sc.listenerBus, SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) + post(sc.listenerBus, SparkListenerTaskStart(0, 0, createTaskInfo(1, 1, "executor-1"))) + post(sc.listenerBus, SparkListenerTaskStart(0, 0, createTaskInfo(2, 2, "executor-2"))) assert(removeTimes(manager).size === 3) assert(!removeTimes(manager).contains("executor-1")) assert(!removeTimes(manager).contains("executor-2")) // Finishing all tasks running on an executor should start the remove timer for that executor - sc.listenerBus.postToAll(SparkListenerTaskEnd( + post(sc.listenerBus, SparkListenerTaskEnd( 0, 0, "task-type", Success, createTaskInfo(0, 0, "executor-1"), new TaskMetrics)) - sc.listenerBus.postToAll(SparkListenerTaskEnd( + post(sc.listenerBus, SparkListenerTaskEnd( 0, 0, "task-type", Success, createTaskInfo(2, 2, "executor-2"), new TaskMetrics)) assert(removeTimes(manager).size === 4) assert(!removeTimes(manager).contains("executor-1")) // executor-1 has not finished yet assert(removeTimes(manager).contains("executor-2")) - sc.listenerBus.postToAll(SparkListenerTaskEnd( + post(sc.listenerBus, SparkListenerTaskEnd( 0, 0, "task-type", Success, createTaskInfo(1, 1, "executor-1"), new TaskMetrics)) assert(removeTimes(manager).size === 5) assert(removeTimes(manager).contains("executor-1")) // executor-1 has now finished @@ -743,13 +749,13 @@ class ExecutorAllocationManagerSuite assert(removeTimes(manager).isEmpty) // New executors have registered - sc.listenerBus.postToAll(SparkListenerExecutorAdded( + post(sc.listenerBus, SparkListenerExecutorAdded( 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) assert(executorIds(manager).size === 1) assert(executorIds(manager).contains("executor-1")) assert(removeTimes(manager).size === 1) assert(removeTimes(manager).contains("executor-1")) - sc.listenerBus.postToAll(SparkListenerExecutorAdded( + post(sc.listenerBus, SparkListenerExecutorAdded( 0L, "executor-2", new ExecutorInfo("host2", 1, Map.empty))) assert(executorIds(manager).size === 2) assert(executorIds(manager).contains("executor-2")) @@ -757,14 +763,14 @@ class ExecutorAllocationManagerSuite assert(removeTimes(manager).contains("executor-2")) // Existing executors have 
disconnected - sc.listenerBus.postToAll(SparkListenerExecutorRemoved(0L, "executor-1", "")) + post(sc.listenerBus, SparkListenerExecutorRemoved(0L, "executor-1", "")) assert(executorIds(manager).size === 1) assert(!executorIds(manager).contains("executor-1")) assert(removeTimes(manager).size === 1) assert(!removeTimes(manager).contains("executor-1")) // Unknown executor has disconnected - sc.listenerBus.postToAll(SparkListenerExecutorRemoved(0L, "executor-3", "")) + post(sc.listenerBus, SparkListenerExecutorRemoved(0L, "executor-3", "")) assert(executorIds(manager).size === 1) assert(removeTimes(manager).size === 1) } @@ -775,8 +781,8 @@ class ExecutorAllocationManagerSuite assert(executorIds(manager).isEmpty) assert(removeTimes(manager).isEmpty) - sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) - sc.listenerBus.postToAll(SparkListenerExecutorAdded( + post(sc.listenerBus, SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) + post(sc.listenerBus, SparkListenerExecutorAdded( 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) assert(executorIds(manager).size === 1) assert(executorIds(manager).contains("executor-1")) @@ -788,15 +794,15 @@ class ExecutorAllocationManagerSuite val manager = sc.executorAllocationManager.get assert(executorIds(manager).isEmpty) assert(removeTimes(manager).isEmpty) - sc.listenerBus.postToAll(SparkListenerExecutorAdded( + post(sc.listenerBus, SparkListenerExecutorAdded( 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty))) - sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) + post(sc.listenerBus, SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1"))) assert(executorIds(manager).size === 1) assert(executorIds(manager).contains("executor-1")) assert(removeTimes(manager).size === 0) - sc.listenerBus.postToAll(SparkListenerExecutorAdded( + post(sc.listenerBus, SparkListenerExecutorAdded( 0L, "executor-2", new ExecutorInfo("host1", 1, Map.empty))) assert(executorIds(manager).size === 2) assert(executorIds(manager).contains("executor-2")) @@ -809,7 +815,7 @@ class ExecutorAllocationManagerSuite sc = createSparkContext(0, 100000, 0) val manager = sc.executorAllocationManager.get val stage1 = createStageInfo(0, 1000) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(stage1)) + post(sc.listenerBus, SparkListenerStageSubmitted(stage1)) assert(addExecutors(manager) === 1) assert(addExecutors(manager) === 2) @@ -820,12 +826,12 @@ class ExecutorAllocationManagerSuite onExecutorAdded(manager, s"executor-$i") } assert(executorIds(manager).size === 15) - sc.listenerBus.postToAll(SparkListenerStageCompleted(stage1)) + post(sc.listenerBus, SparkListenerStageCompleted(stage1)) adjustRequestedExecutors(manager) assert(numExecutorsTarget(manager) === 0) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 1000))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(1, 1000))) addExecutors(manager) assert(numExecutorsTarget(manager) === 16) } @@ -842,7 +848,7 @@ class ExecutorAllocationManagerSuite // Verify whether the initial number of executors is kept with no pending tasks assert(numExecutorsTarget(manager) === 3) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 2))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(1, 2))) clock.advance(100L) assert(maxNumExecutorsNeeded(manager) === 2) @@ -892,7 +898,7 @@ class ExecutorAllocationManagerSuite Seq.empty ) val stageInfo1 = 
createStageInfo(1, 5, localityPreferences1) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(stageInfo1)) + post(sc.listenerBus, SparkListenerStageSubmitted(stageInfo1)) assert(localityAwareTasks(manager) === 3) assert(hostToLocalTaskCount(manager) === @@ -904,13 +910,13 @@ class ExecutorAllocationManagerSuite Seq.empty ) val stageInfo2 = createStageInfo(2, 3, localityPreferences2) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(stageInfo2)) + post(sc.listenerBus, SparkListenerStageSubmitted(stageInfo2)) assert(localityAwareTasks(manager) === 5) assert(hostToLocalTaskCount(manager) === Map("host1" -> 2, "host2" -> 4, "host3" -> 4, "host4" -> 3, "host5" -> 2)) - sc.listenerBus.postToAll(SparkListenerStageCompleted(stageInfo1)) + post(sc.listenerBus, SparkListenerStageCompleted(stageInfo1)) assert(localityAwareTasks(manager) === 2) assert(hostToLocalTaskCount(manager) === Map("host2" -> 1, "host3" -> 2, "host4" -> 1, "host5" -> 2)) @@ -921,16 +927,16 @@ class ExecutorAllocationManagerSuite val manager = sc.executorAllocationManager.get assert(maxNumExecutorsNeeded(manager) === 0) - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 1))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 1))) assert(maxNumExecutorsNeeded(manager) === 1) val taskInfo = createTaskInfo(1, 1, "executor-1") - sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, taskInfo)) + post(sc.listenerBus, SparkListenerTaskStart(0, 0, taskInfo)) assert(maxNumExecutorsNeeded(manager) === 1) // If the task is failed, we expect it to be resubmitted later. val taskEndReason = ExceptionFailure(null, null, null, null, None) - sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, null)) + post(sc.listenerBus, SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, null)) assert(maxNumExecutorsNeeded(manager) === 1) } @@ -942,7 +948,7 @@ class ExecutorAllocationManagerSuite // Allocation manager is reset when adding executor requests are sent without reporting back // executor added. - sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 10))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 10))) assert(addExecutors(manager) === 1) assert(numExecutorsTarget(manager) === 2) @@ -957,7 +963,7 @@ class ExecutorAllocationManagerSuite assert(executorIds(manager) === Set.empty) // Allocation manager is reset when executors are added. 
- sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, 10))) + post(sc.listenerBus, SparkListenerStageSubmitted(createStageInfo(0, 10))) addExecutors(manager) addExecutors(manager) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 4d69ce844d2ea..ad801bf8519a6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -897,6 +897,71 @@ class SparkSubmitSuite sysProps("spark.submit.pyFiles") should (startWith("/")) } + test("download remote resource if it is not supported by yarn service") { + testRemoteResources(isHttpSchemeBlacklisted = false, supportMockHttpFs = false) + } + + test("avoid downloading remote resource if it is supported by yarn service") { + testRemoteResources(isHttpSchemeBlacklisted = false, supportMockHttpFs = true) + } + + test("force download from blacklisted schemes") { + testRemoteResources(isHttpSchemeBlacklisted = true, supportMockHttpFs = true) + } + + private def testRemoteResources(isHttpSchemeBlacklisted: Boolean, + supportMockHttpFs: Boolean): Unit = { + val hadoopConf = new Configuration() + updateConfWithFakeS3Fs(hadoopConf) + if (supportMockHttpFs) { + hadoopConf.set("fs.http.impl", classOf[TestFileSystem].getCanonicalName) + hadoopConf.set("fs.http.impl.disable.cache", "true") + } + + val tmpDir = Utils.createTempDir() + val mainResource = File.createTempFile("tmpPy", ".py", tmpDir) + val tmpS3Jar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir) + val tmpS3JarPath = s"s3a://${new File(tmpS3Jar.toURI).getAbsolutePath}" + val tmpHttpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir) + val tmpHttpJarPath = s"http://${new File(tmpHttpJar.toURI).getAbsolutePath}" + + val args = Seq( + "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"), + "--name", "testApp", + "--master", "yarn", + "--deploy-mode", "client", + "--jars", s"$tmpS3JarPath,$tmpHttpJarPath", + s"s3a://$mainResource" + ) ++ ( + if (isHttpSchemeBlacklisted) { + Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http,https") + } else { + Nil + } + ) + + val appArgs = new SparkSubmitArguments(args) + val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf))._3 + + val jars = sysProps("spark.yarn.dist.jars").split(",").toSet + + // The URI of remote S3 resource should still be remote. + assert(jars.contains(tmpS3JarPath)) + + if (supportMockHttpFs) { + // If Http FS is supported by yarn service, the URI of remote http resource should + // still be remote. + assert(jars.contains(tmpHttpJarPath)) + } else { + // If Http FS is not supported by yarn service, or http scheme is configured to be force + // downloading, the URI of remote http resource should be changed to a local one. + val jarName = new File(tmpHttpJar.toURI).getName + val localHttpJar = jars.filter(_.contains(jarName)) + localHttpJar.size should be(1) + localHttpJar.head should startWith("file:") + } + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. 
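
The three remote-resource tests above reduce to one rule: a remote URI is kept as-is unless its scheme is either force-listed for download or unsupported by the cluster's Hadoop FileSystem. A hedged restatement (simplified, not the actual SparkSubmit logic), reusing the stringToSeq helper added earlier in this diff:

def shouldDownload(scheme: String, forceDownloadSchemes: Seq[String], fsSupportsScheme: Boolean): Boolean =
  forceDownloadSchemes.contains(scheme) || !fsSupportsScheme

val forced = Utils.stringToSeq("http,https")                      // spark.yarn.dist.forceDownloadSchemes
assert(shouldDownload("http", forced, fsSupportsScheme = true))   // blacklisted scheme: download locally
assert(!shouldDownload("http", Nil, fsSupportsScheme = true))     // scheme supported by YARN's FS: keep remote
assert(shouldDownload("http", Nil, fsSupportsScheme = false))     // unsupported scheme: download locally
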
private def runSparkSubmit(args: Seq[String]): Unit = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 703fc1b34c387..6222e576d1ce9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -751,7 +751,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // Helper functions to extract commonly used code in Fetch Failure test cases private def setupStageAbortTest(sc: SparkContext) { - sc.listenerBus.addListener(new EndListener()) + sc.listenerBus.addToSharedQueue(new EndListener()) ended = false jobResult = null } diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 0afd07b851cf9..6b42775ccb0f6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -164,9 +164,9 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit // A comprehensive test on JSON de/serialization of all events is in JsonProtocolSuite eventLogger.start() listenerBus.start(Mockito.mock(classOf[SparkContext]), Mockito.mock(classOf[MetricsSystem])) - listenerBus.addListener(eventLogger) - listenerBus.postToAll(applicationStart) - listenerBus.postToAll(applicationEnd) + listenerBus.addToEventLogQueue(eventLogger) + listenerBus.post(applicationStart) + listenerBus.post(applicationEnd) listenerBus.stop() eventLogger.stop() diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 995df1dd52010..d061c7845f4a6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -34,6 +34,8 @@ import org.apache.spark.util.{ResetSystemProperties, RpcUtils} class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Matchers with ResetSystemProperties { + import LiveListenerBus._ + /** Length of time to wait while draining listener events. 
*/ val WAIT_TIMEOUT_MILLIS = 10000 @@ -42,18 +44,28 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match private val mockSparkContext: SparkContext = Mockito.mock(classOf[SparkContext]) private val mockMetricsSystem: MetricsSystem = Mockito.mock(classOf[MetricsSystem]) + private def numDroppedEvents(bus: LiveListenerBus): Long = { + bus.metrics.metricRegistry.counter(s"queue.$SHARED_QUEUE.numDroppedEvents").getCount + } + + private def queueSize(bus: LiveListenerBus): Int = { + bus.metrics.metricRegistry.getGauges().get(s"queue.$SHARED_QUEUE.size").getValue() + .asInstanceOf[Int] + } + + private def eventProcessingTimeCount(bus: LiveListenerBus): Long = { + bus.metrics.metricRegistry.timer(s"queue.$SHARED_QUEUE.listenerProcessingTime").getCount() + } + test("don't call sc.stop in listener") { sc = new SparkContext("local", "SparkListenerSuite", new SparkConf()) val listener = new SparkContextStoppingListener(sc) - val bus = new LiveListenerBus(sc.conf) - bus.addListener(listener) - // Starting listener bus should flush all buffered events - bus.start(sc, sc.env.metricsSystem) - bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + sc.listenerBus.addToSharedQueue(listener) + sc.listenerBus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + sc.stop() - bus.stop() assert(listener.sparkExSeen) } @@ -61,13 +73,13 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val conf = new SparkConf() val counter = new BasicJobCounter val bus = new LiveListenerBus(conf) - bus.addListener(counter) + bus.addToSharedQueue(counter) // Metrics are initially empty. assert(bus.metrics.numEventsPosted.getCount === 0) - assert(bus.metrics.numDroppedEvents.getCount === 0) - assert(bus.metrics.queueSize.getValue === 0) - assert(bus.metrics.eventProcessingTime.getCount === 0) + assert(numDroppedEvents(bus) === 0) + assert(queueSize(bus) === 0) + assert(eventProcessingTimeCount(bus) === 0) // Post five events: (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } @@ -75,7 +87,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Five messages should be marked as received and queued, but no messages should be posted to // listeners yet because the the listener bus hasn't been started. 
assert(bus.metrics.numEventsPosted.getCount === 5) - assert(bus.metrics.queueSize.getValue === 5) + assert(queueSize(bus) === 5) assert(counter.count === 0) // Starting listener bus should flush all buffered events @@ -83,18 +95,14 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match Mockito.verify(mockMetricsSystem).registerSource(bus.metrics) bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(counter.count === 5) - assert(bus.metrics.queueSize.getValue === 0) - assert(bus.metrics.eventProcessingTime.getCount === 5) + assert(queueSize(bus) === 0) + assert(eventProcessingTimeCount(bus) === 5) // After listener bus has stopped, posting events should not increment counter bus.stop() (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } assert(counter.count === 5) - assert(bus.metrics.numEventsPosted.getCount === 5) - - // Make sure per-listener-class timers were created: - assert(bus.metrics.getTimerForListenerClass( - classOf[BasicJobCounter].asSubclass(classOf[SparkListenerInterface])).get.getCount == 5) + assert(eventProcessingTimeCount(bus) === 5) // Listener bus must not be started twice intercept[IllegalStateException] { @@ -135,7 +143,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val bus = new LiveListenerBus(new SparkConf()) val blockingListener = new BlockingListener - bus.addListener(blockingListener) + bus.addToSharedQueue(blockingListener) bus.start(mockSparkContext, mockMetricsSystem) bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) @@ -168,7 +176,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val listenerStarted = new Semaphore(0) val listenerWait = new Semaphore(0) - bus.addListener(new SparkListener { + bus.addToSharedQueue(new SparkListener { override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { listenerStarted.release() listenerWait.acquire() @@ -180,20 +188,19 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Post a message to the listener bus and wait for processing to begin: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) listenerStarted.acquire() - assert(bus.metrics.queueSize.getValue === 0) - assert(bus.metrics.numDroppedEvents.getCount === 0) + assert(queueSize(bus) === 0) + assert(numDroppedEvents(bus) === 0) // If we post an additional message then it should remain in the queue because the listener is // busy processing the first event: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - assert(bus.metrics.queueSize.getValue === 1) - assert(bus.metrics.numDroppedEvents.getCount === 0) + assert(queueSize(bus) === 1) + assert(numDroppedEvents(bus) === 0) // The queue is now full, so any additional events posted to the listener will be dropped: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - assert(bus.metrics.queueSize.getValue === 1) - assert(bus.metrics.numDroppedEvents.getCount === 1) - + assert(queueSize(bus) === 1) + assert(numDroppedEvents(bus) === 1) // Allow the the remaining events to be processed so we can stop the listener bus: listenerWait.release(2) @@ -419,9 +426,9 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val bus = new LiveListenerBus(new SparkConf()) // Propagate events to bad listener first - bus.addListener(badListener) - bus.addListener(jobCounter1) - bus.addListener(jobCounter2) + bus.addToSharedQueue(badListener) + bus.addToSharedQueue(jobCounter1) + 
bus.addToSharedQueue(jobCounter2) bus.start(mockSparkContext, mockMetricsSystem) // Post events to all listeners, and wait until the queue is drained @@ -429,7 +436,6 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) // The exception should be caught, and the event should be propagated to other listeners - assert(bus.listenerThreadIsAlive) assert(jobCounter1.count === 5) assert(jobCounter2.count === 5) } @@ -449,6 +455,31 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match .count(_.isInstanceOf[FirehoseListenerThatAcceptsSparkConf]) should be (1) } + test("add and remove listeners to/from LiveListenerBus queues") { + val bus = new LiveListenerBus(new SparkConf(false)) + val counter1 = new BasicJobCounter() + val counter2 = new BasicJobCounter() + val counter3 = new BasicJobCounter() + + bus.addToSharedQueue(counter1) + bus.addToStatusQueue(counter2) + bus.addToStatusQueue(counter3) + assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 3) + + bus.removeListener(counter1) + assert(bus.activeQueues() === Set(APP_STATUS_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 2) + + bus.removeListener(counter2) + assert(bus.activeQueues() === Set(APP_STATUS_QUEUE)) + assert(bus.findListenersByClass[BasicJobCounter]().size === 1) + + bus.removeListener(counter3) + assert(bus.activeQueues().isEmpty) + assert(bus.findListenersByClass[BasicJobCounter]().isEmpty) + } + /** * Assert that the given list of numbers has an average that is greater than zero. */ diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index 1cb52593e7060..79f02f2e50bbd 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.storage._ * Test various functionality in the StorageListener that supports the StorageTab. 
*/ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { - private var bus: LiveListenerBus = _ + private var bus: SparkListenerBus = _ private var storageStatusListener: StorageStatusListener = _ private var storageListener: StorageListener = _ private val memAndDisk = StorageLevel.MEMORY_AND_DISK @@ -43,7 +43,7 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter { before { val conf = new SparkConf() - bus = new LiveListenerBus(conf) + bus = new ReplayListenerBus() storageStatusListener = new StorageStatusListener(conf) storageListener = new StorageListener(storageStatusListener) bus.addListener(storageStatusListener) diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 9ac753861dd84..e534e38213fb1 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -139,10 +139,10 @@ machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar mail-1.4.7.jar mesos-1.3.0-shaded-protobuf.jar -metrics-core-3.1.2.jar -metrics-graphite-3.1.2.jar -metrics-json-3.1.2.jar -metrics-jvm-3.1.2.jar +metrics-core-3.1.5.jar +metrics-graphite-3.1.5.jar +metrics-json-3.1.5.jar +metrics-jvm-3.1.5.jar minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.9.9.Final.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index d39747e9ee058..02c5a19d173be 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -140,10 +140,10 @@ machinist_2.11-0.6.1.jar macro-compat_2.11-1.1.1.jar mail-1.4.7.jar mesos-1.3.0-shaded-protobuf.jar -metrics-core-3.1.2.jar -metrics-graphite-3.1.2.jar -metrics-json-3.1.2.jar -metrics-jvm-3.1.2.jar +metrics-core-3.1.5.jar +metrics-graphite-3.1.5.jar +metrics-json-3.1.5.jar +metrics-jvm-3.1.5.jar minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.9.9.Final.jar diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index e4a74556d4f26..432639588cc2b 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -211,6 +211,15 @@ To use a custom metrics.properties for the application master and executors, upd Comma-separated list of jars to be placed in the working directory of each executor. + + spark.yarn.dist.forceDownloadSchemes + (none) + + Comma-separated list of schemes for which files will be downloaded to the local disk prior to + being added to YARN's distributed cache. For use in cases where the YARN service does not + support schemes that are supported by Spark, like http, https and ftp. + + spark.executor.instances 2 diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 95d704014742c..5db60cc996e75 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1333,7 +1333,7 @@ the following case-insensitive options: customSchema - The custom schema to use for reading data from JDBC connectors. For example, "id DECIMAL(38, 0), name STRING"). The column names should be identical to the corresponding column names of JDBC table. Users can specify the corresponding data types of Spark SQL instead of using the defaults. This option applies only to reading. + The custom schema to use for reading data from JDBC connectors. For example, "id DECIMAL(38, 0), name STRING". You can also specify partial fields, and the others use the default type mapping. For example, "id DECIMAL(38, 0)". The column names should be identical to the corresponding column names of JDBC table. Users can specify the corresponding data types of Spark SQL instead of using the defaults. This option applies only to reading. 
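
For the customSchema note above, a hedged Scala illustration (connection URL, table and credentials are placeholders; assumes an active SparkSession `spark`):

val jdbcDF = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql:dbserver")
  .option("dbtable", "schema.people")
  .option("user", "username")
  .option("password", "password")
  // Override only the columns you care about; unlisted columns keep the default type mapping.
  .option("customSchema", "id DECIMAL(38, 0), name STRING")
  .load()
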
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index f6095e26f435c..fe3306e1e50d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -337,14 +337,17 @@ object Word2VecModel extends MLReadable[Word2VecModel] { DefaultParamsWriter.saveMetadata(instance, path, sc) val wordVectors = instance.wordVectors.getVectors - val dataSeq = wordVectors.toSeq.map { case (word, vector) => Data(word, vector) } val dataPath = new Path(path, "data").toString val bufferSizeInBytes = Utils.byteStringAsBytes( sc.conf.get("spark.kryoserializer.buffer.max", "64m")) val numPartitions = Word2VecModelWriter.calculateNumberOfPartitions( bufferSizeInBytes, instance.wordVectors.wordIndex.size, instance.getVectorSize) - sparkSession.createDataFrame(dataSeq) + val spark = sparkSession + import spark.implicits._ + spark.createDataset[(String, Array[Float])](wordVectors.toSeq) .repartition(numPartitions) + .map { case (word, vector) => Data(word, vector) } + .toDF() .write .parquet(dataPath) } diff --git a/pom.xml b/pom.xml index af511c3e2e5df..42ffe2fb405f5 100644 --- a/pom.xml +++ b/pom.xml @@ -138,7 +138,7 @@ 0.8.4 2.4.0 2.0.8 - 3.1.2 + 3.1.5 1.7.7 hadoop2 0.9.3 @@ -2637,6 +2637,13 @@ + + kubernetes + + resource-managers/kubernetes/core + + + hive-thriftserver diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a7046043e0376..a33f6dcf31fc0 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -997,12 +997,20 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): def show_profiles(self): """ Print the profile stats to stdout """ - self.profiler_collector.show_profiles() + if self.profiler_collector is not None: + self.profiler_collector.show_profiles() + else: + raise RuntimeError("'spark.python.profile' configuration must be set " + "to 'true' to enable Python profile.") def dump_profiles(self, path): """ Dump the profile stats into directory `path` """ - self.profiler_collector.dump_profiles(path) + if self.profiler_collector is not None: + self.profiler_collector.dump_profiles(path) + else: + raise RuntimeError("'spark.python.profile' configuration must be set " + "to 'true' to enable Python profile.") def getConf(self): conf = SparkConf() diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 7cb8d62f212cb..09cdf9b6a629a 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -146,8 +146,7 @@ def __init__(self, rawPredictionCol="rawPrediction", labelCol="label", super(BinaryClassificationEvaluator, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid) - self._setDefault(rawPredictionCol="rawPrediction", labelCol="label", - metricName="areaUnderROC") + self._setDefault(metricName="areaUnderROC") kwargs = self._input_kwargs self._set(**kwargs) @@ -224,8 +223,7 @@ def __init__(self, predictionCol="prediction", labelCol="label", super(RegressionEvaluator, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid) - self._setDefault(predictionCol="prediction", labelCol="label", - metricName="rmse") + self._setDefault(metricName="rmse") kwargs = self._input_kwargs self._set(**kwargs) @@ -297,8 +295,7 @@ def __init__(self, predictionCol="prediction", labelCol="label", 
super(MulticlassClassificationEvaluator, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid) - self._setDefault(predictionCol="prediction", labelCol="label", - metricName="f1") + self._setDefault(metricName="f1") kwargs = self._input_kwargs self._set(**kwargs) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index d5c2a7518b18f..660b19ad2a7c4 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -97,7 +97,7 @@ def load_stream(self, stream): def _load_stream_without_unbatching(self, stream): """ - Return an iterator of deserialized batches (lists) of objects from the input stream. + Return an iterator of deserialized batches (iterable) of objects from the input stream. if the serializer does not operate on batches the default implementation returns an iterator of single element lists. """ @@ -343,6 +343,10 @@ def _load_stream_without_unbatching(self, stream): key_batch_stream = self.key_ser._load_stream_without_unbatching(stream) val_batch_stream = self.val_ser._load_stream_without_unbatching(stream) for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream): + # For double-zipped RDDs, the batches can be iterators from other PairDeserializer, + # instead of lists. We need to convert them to lists if needed. + key_batch = key_batch if hasattr(key_batch, '__len__') else list(key_batch) + val_batch = val_batch if hasattr(val_batch, '__len__') else list(val_batch) if len(key_batch) != len(val_batch): raise ValueError("Can not deserialize PairRDD with different number of items" " in batches: (%d, %d)" % (len(key_batch), len(val_batch))) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 0e76182e0e02d..399bef02d9cc4 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1884,9 +1884,9 @@ def json_tuple(col, *fields): @since(2.1) def from_json(col, schema, options={}): """ - Parses a column containing a JSON string into a [[StructType]] or [[ArrayType]] - of [[StructType]]s with the specified schema. Returns `null`, in the case of an unparseable - string. + Parses a column containing a JSON string into a :class:`StructType` or :class:`ArrayType` + of :class:`StructType`\\s with the specified schema. Returns `null`, in the case of an + unparseable string. :param col: string column in json format :param schema: a StructType or ArrayType of StructType to use when parsing the json column. @@ -1921,10 +1921,12 @@ def from_json(col, schema, options={}): @since(2.1) def to_json(col, options={}): """ - Converts a column containing a [[StructType]] or [[ArrayType]] of [[StructType]]s into a - JSON string. Throws an exception, in the case of an unsupported type. + Converts a column containing a :class:`StructType`, :class:`ArrayType` of + :class:`StructType`\\s, a :class:`MapType` or :class:`ArrayType` of :class:`MapType`\\s + into a JSON string. Throws an exception, in the case of an unsupported type. - :param col: name of column containing the struct or array of the structs + :param col: name of column containing the struct, array of the structs, the map or + array of the maps. :param options: options to control converting. 
accepts the same options as the json datasource >>> from pyspark.sql import Row @@ -1937,6 +1939,14 @@ def to_json(col, options={}): >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(to_json(df.value).alias("json")).collect() [Row(json=u'[{"age":2,"name":"Alice"},{"age":3,"name":"Bob"}]')] + >>> data = [(1, {"name": "Alice"})] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(to_json(df.value).alias("json")).collect() + [Row(json=u'{"name":"Alice"}')] + >>> data = [(1, [{"name": "Alice"}, {"name": "Bob"}])] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(to_json(df.value).alias("json")).collect() + [Row(json=u'[{"name":"Alice"},{"name":"Bob"}]')] """ sc = SparkContext._active_spark_context diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 920cf009f599d..aaf520fa8019f 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -483,7 +483,9 @@ def __init__(self, fields=None): self.names = [f.name for f in fields] assert all(isinstance(f, StructField) for f in fields),\ "fields should be a list of StructField" - self._needSerializeAnyField = any(f.needConversion() for f in self) + # Precalculated list of fields that need conversion with fromInternal/toInternal functions + self._needConversion = [f.needConversion() for f in self] + self._needSerializeAnyField = any(self._needConversion) def add(self, field, data_type=None, nullable=True, metadata=None): """ @@ -528,7 +530,9 @@ def add(self, field, data_type=None, nullable=True, metadata=None): data_type_f = data_type self.fields.append(StructField(field, data_type_f, nullable, metadata)) self.names.append(field) - self._needSerializeAnyField = any(f.needConversion() for f in self) + # Precalculated list of fields that need conversion with fromInternal/toInternal functions + self._needConversion = [f.needConversion() for f in self] + self._needSerializeAnyField = any(self._needConversion) return self def __iter__(self): @@ -590,13 +594,17 @@ def toInternal(self, obj): return if self._needSerializeAnyField: + # Only calling toInternal function for fields that need conversion if isinstance(obj, dict): - return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields)) + return tuple(f.toInternal(obj.get(n)) if c else obj.get(n) + for n, f, c in zip(self.names, self.fields, self._needConversion)) elif isinstance(obj, (tuple, list)): - return tuple(f.toInternal(v) for f, v in zip(self.fields, obj)) + return tuple(f.toInternal(v) if c else v + for f, v, c in zip(self.fields, obj, self._needConversion)) elif hasattr(obj, "__dict__"): d = obj.__dict__ - return tuple(f.toInternal(d.get(n)) for n, f in zip(self.names, self.fields)) + return tuple(f.toInternal(d.get(n)) if c else d.get(n) + for n, f, c in zip(self.names, self.fields, self._needConversion)) else: raise ValueError("Unexpected tuple %r with StructType" % obj) else: @@ -619,7 +627,9 @@ def fromInternal(self, obj): # it's already converted by pickler return obj if self._needSerializeAnyField: - values = [f.fromInternal(v) for f, v in zip(self.fields, obj)] + # Only calling fromInternal function for fields that need conversion + values = [f.fromInternal(v) if c else v + for f, v, c in zip(self.fields, obj, self._needConversion)] else: values = obj return _create_row(self.names, values) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 000dd1eb8e481..da99872da2f0e 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -644,6 +644,18 @@ 
def test_cartesian_chaining(self): set([(x, (y, y)) for x in range(10) for y in range(10)]) ) + def test_zip_chaining(self): + # Tests for SPARK-21985 + rdd = self.sc.parallelize('abc', 2) + self.assertSetEqual( + set(rdd.zip(rdd).zip(rdd).collect()), + set([((x, x), x) for x in 'abc']) + ) + self.assertSetEqual( + set(rdd.zip(rdd.zip(rdd)).collect()), + set([(x, (x, x)) for x in 'abc']) + ) + def test_deleting_input_files(self): # Regression test for SPARK-1025 tempFile = tempfile.NamedTemporaryFile(delete=False) @@ -1284,6 +1296,22 @@ def heavy_foo(x): rdd.foreach(heavy_foo) +class ProfilerTests2(unittest.TestCase): + def test_profiler_disabled(self): + sc = SparkContext(conf=SparkConf().set("spark.python.profile", "false")) + try: + self.assertRaisesRegexp( + RuntimeError, + "'spark.python.profile' configuration must be set", + lambda: sc.show_profiles()) + self.assertRaisesRegexp( + RuntimeError, + "'spark.python.profile' configuration must be set", + lambda: sc.dump_profiles("/tmp/abc")) + finally: + sc.stop() + + class InputFormatTests(ReusedPySparkTestCase): @classmethod diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml new file mode 100644 index 0000000000000..1637c0f7aa716 --- /dev/null +++ b/resource-managers/kubernetes/core/pom.xml @@ -0,0 +1,102 @@ + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.3.0-SNAPSHOT + ../../../pom.xml + + + spark-kubernetes_2.11 + jar + Spark Project Kubernetes + + kubernetes + 2.2.13 + + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + + io.fabric8 + kubernetes-client + ${kubernetes.client.version} + + + com.fasterxml.jackson.core + jackson-core + + + com.fasterxml.jackson.core + jackson-databind + + + com.fasterxml.jackson.core + jackson-annotations + + + com.fasterxml.jackson.dataformat + jackson-dataformat-yaml + + + + + + + com.fasterxml.jackson.dataformat + jackson-dataformat-yaml + ${fasterxml.jackson.version} + + + + + com.google.guava + guava + + + + + org.mockito + mockito-core + test + + + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/ConfigurationUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/ConfigurationUtils.scala new file mode 100644 index 0000000000000..aafb6f3aabe6d --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/ConfigurationUtils.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.kubernetes + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.OptionalConfigEntry + +object ConfigurationUtils extends Logging { + def parseKeyValuePairs( + maybeKeyValues: Option[String], + configKey: String, + keyValueType: String): Map[String, String] = { + + maybeKeyValues.map(keyValues => { + keyValues.split(",").map(_.trim).filterNot(_.isEmpty).map(keyValue => { + keyValue.split("=", 2).toSeq match { + case Seq(k, v) => + (k, v) + case _ => + throw new SparkException(s"Custom $keyValueType set by $configKey must be a" + + s" comma-separated list of key-value pairs, with format =." + + s" Got value: $keyValue. All values: $keyValues") + } + }).toMap + }).getOrElse(Map.empty[String, String]) + } + + def parsePrefixedKeyValuePairs( + sparkConf: SparkConf, + prefix: String, + configType: String): Map[String, String] = { + val fromPrefix = sparkConf.getAllWithPrefix(prefix) + fromPrefix.groupBy(_._1).foreach { + case (key, values) => + require(values.size == 1, + s"Cannot have multiple values for a given $configType key, got key $key with" + + s" values $values") + } + fromPrefix.toMap + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/OptionRequirements.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/OptionRequirements.scala new file mode 100644 index 0000000000000..eda43de0a9a5b --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/OptionRequirements.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.kubernetes + +private[spark] object OptionRequirements { + + def requireBothOrNeitherDefined( + opt1: Option[_], + opt2: Option[_], + errMessageWhenFirstIsMissing: String, + errMessageWhenSecondIsMissing: String): Unit = { + requireSecondIfFirstIsDefined(opt1, opt2, errMessageWhenSecondIsMissing) + requireSecondIfFirstIsDefined(opt2, opt1, errMessageWhenFirstIsMissing) + } + + def requireSecondIfFirstIsDefined( + opt1: Option[_], opt2: Option[_], errMessageWhenSecondIsMissing: String): Unit = { + opt1.foreach { _ => + require(opt2.isDefined, errMessageWhenSecondIsMissing) + } + } + + def requireNandDefined(opt1: Option[_], opt2: Option[_], errMessage: String): Unit = { + opt1.foreach { _ => require(opt2.isEmpty, errMessage) } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/SparkKubernetesClientFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/SparkKubernetesClientFactory.scala new file mode 100644 index 0000000000000..d2729a2db2fa0 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/SparkKubernetesClientFactory.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.kubernetes + +import java.io.File + +import com.google.common.base.Charsets +import com.google.common.io.Files +import io.fabric8.kubernetes.client.{Config, ConfigBuilder, DefaultKubernetesClient, KubernetesClient} +import io.fabric8.kubernetes.client.utils.HttpClientUtils +import okhttp3.Dispatcher + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.kubernetes.config._ +import org.apache.spark.util.ThreadUtils + +/** + * Spark-opinionated builder for Kubernetes clients. It uses a prefix plus common suffixes to + * parse configuration keys, similar to the manner in which Spark's SecurityManager parses SSL + * options for different components. 
+ */ +private[spark] object SparkKubernetesClientFactory { + + def createKubernetesClient( + master: String, + namespace: Option[String], + kubernetesAuthConfPrefix: String, + sparkConf: SparkConf, + maybeServiceAccountToken: Option[File], + maybeServiceAccountCaCert: Option[File]): KubernetesClient = { + val oauthTokenFileConf = s"$kubernetesAuthConfPrefix.$OAUTH_TOKEN_FILE_CONF_SUFFIX" + val oauthTokenConf = s"$kubernetesAuthConfPrefix.$OAUTH_TOKEN_CONF_SUFFIX" + val oauthTokenFile = sparkConf.getOption(oauthTokenFileConf) + .map(new File(_)) + .orElse(maybeServiceAccountToken) + val oauthTokenValue = sparkConf.getOption(oauthTokenConf) + OptionRequirements.requireNandDefined( + oauthTokenFile, + oauthTokenValue, + s"Cannot specify OAuth token through both a file $oauthTokenFileConf and a" + + s" value $oauthTokenConf.") + + val caCertFile = sparkConf + .getOption(s"$kubernetesAuthConfPrefix.$CA_CERT_FILE_CONF_SUFFIX") + .orElse(maybeServiceAccountCaCert.map(_.getAbsolutePath)) + val clientKeyFile = sparkConf + .getOption(s"$kubernetesAuthConfPrefix.$CLIENT_KEY_FILE_CONF_SUFFIX") + val clientCertFile = sparkConf + .getOption(s"$kubernetesAuthConfPrefix.$CLIENT_CERT_FILE_CONF_SUFFIX") + val dispatcher = new Dispatcher( + ThreadUtils.newDaemonCachedThreadPool("kubernetes-dispatcher")) + val config = new ConfigBuilder() + .withApiVersion("v1") + .withMasterUrl(master) + .withWebsocketPingInterval(0) + .withOption(oauthTokenValue) { + (token, configBuilder) => configBuilder.withOauthToken(token) + }.withOption(oauthTokenFile) { + (file, configBuilder) => + configBuilder.withOauthToken(Files.toString(file, Charsets.UTF_8)) + }.withOption(caCertFile) { + (file, configBuilder) => configBuilder.withCaCertFile(file) + }.withOption(clientKeyFile) { + (file, configBuilder) => configBuilder.withClientKeyFile(file) + }.withOption(clientCertFile) { + (file, configBuilder) => configBuilder.withClientCertFile(file) + }.withOption(namespace) { + (ns, configBuilder) => configBuilder.withNamespace(ns) + }.build() + val baseHttpClient = HttpClientUtils.createHttpClient(config) + val httpClientWithCustomDispatcher = baseHttpClient.newBuilder() + .dispatcher(dispatcher) + .build() + new DefaultKubernetesClient(httpClientWithCustomDispatcher, config) + } + + private implicit class OptionConfigurableConfigBuilder(configBuilder: ConfigBuilder) { + + def withOption[T] + (option: Option[T]) + (configurator: ((T, ConfigBuilder) => ConfigBuilder)): OptionConfigurableConfigBuilder = { + new OptionConfigurableConfigBuilder(option.map { opt => + configurator(opt, configBuilder) + }.getOrElse(configBuilder)) + } + + def build(): Config = configBuilder.build() + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/config.scala new file mode 100644 index 0000000000000..2e37749df4233 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/config.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.kubernetes + +import org.apache.spark.{SPARK_VERSION => sparkVersion} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.ConfigBuilder +import org.apache.spark.network.util.ByteUnit + +package object config extends Logging { + + private[spark] val KUBERNETES_NAMESPACE = + ConfigBuilder("spark.kubernetes.namespace") + .doc("The namespace that will be used for running the driver and executor pods. When using" + + " spark-submit in cluster mode, this can also be passed to spark-submit via the" + + " --kubernetes-namespace command line argument.") + .stringConf + .createWithDefault("default") + + private[spark] val DRIVER_DOCKER_IMAGE = + ConfigBuilder("spark.kubernetes.driver.docker.image") + .doc("Docker image to use for the driver. Specify this using the standard Docker tag format.") + .stringConf + .createWithDefault(s"spark-driver:$sparkVersion") + + private[spark] val EXECUTOR_DOCKER_IMAGE = + ConfigBuilder("spark.kubernetes.executor.docker.image") + .doc("Docker image to use for the executors. Specify this using the standard Docker tag" + + " format.") + .stringConf + .createWithDefault(s"spark-executor:$sparkVersion") + + private[spark] val DOCKER_IMAGE_PULL_POLICY = + ConfigBuilder("spark.kubernetes.docker.image.pullPolicy") + .doc("Docker image pull policy when pulling any docker image in Kubernetes integration") + .stringConf + .createWithDefault("IfNotPresent") + + private[spark] val APISERVER_AUTH_SUBMISSION_CONF_PREFIX = + "spark.kubernetes.authenticate.submission" + private[spark] val APISERVER_AUTH_DRIVER_CONF_PREFIX = + "spark.kubernetes.authenticate.driver" + private[spark] val APISERVER_AUTH_DRIVER_MOUNTED_CONF_PREFIX = + "spark.kubernetes.authenticate.driver.mounted" + private[spark] val OAUTH_TOKEN_CONF_SUFFIX = "oauthToken" + private[spark] val OAUTH_TOKEN_FILE_CONF_SUFFIX = "oauthTokenFile" + private[spark] val CLIENT_KEY_FILE_CONF_SUFFIX = "clientKeyFile" + private[spark] val CLIENT_CERT_FILE_CONF_SUFFIX = "clientCertFile" + private[spark] val CA_CERT_FILE_CONF_SUFFIX = "caCertFile" + + private[spark] val KUBERNETES_SERVICE_ACCOUNT_NAME = + ConfigBuilder(s"$APISERVER_AUTH_DRIVER_CONF_PREFIX.serviceAccountName") + .doc("Service account that is used when running the driver pod. The driver pod uses" + + " this service account when requesting executor pods from the API server. If specific" + + " credentials are given for the driver pod to use, the driver will favor" + + " using those credentials instead.") + .stringConf + .createOptional + + // Note that while we set a default for this when we start up the + // scheduler, the specific default value is dynamically determined + // based on the executor memory. + private[spark] val KUBERNETES_EXECUTOR_MEMORY_OVERHEAD = + ConfigBuilder("spark.kubernetes.executor.memoryOverhead") + .doc("The amount of off-heap memory (in megabytes) to be allocated per executor. This" + + " is memory that accounts for things like VM overheads, interned strings, other native" + + " overheads, etc. This tends to grow with the executor size. 
(typically 6-10%).") + .bytesConf(ByteUnit.MiB) + .createOptional + + private[spark] val KUBERNETES_DRIVER_MEMORY_OVERHEAD = + ConfigBuilder("spark.kubernetes.driver.memoryOverhead") + .doc("The amount of off-heap memory (in megabytes) to be allocated for the driver and the" + + " driver submission server. This is memory that accounts for things like VM overheads," + + " interned strings, other native overheads, etc. This tends to grow with the driver's" + + " memory size (typically 6-10%).") + .bytesConf(ByteUnit.MiB) + .createOptional + + private[spark] val KUBERNETES_DRIVER_LABEL_PREFIX = "spark.kubernetes.driver.label." + private[spark] val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation." + private[spark] val KUBERNETES_EXECUTOR_LABEL_PREFIX = "spark.kubernetes.executor.label." + private[spark] val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." + + private[spark] val KUBERNETES_DRIVER_LABELS = + ConfigBuilder("spark.kubernetes.driver.labels") + .doc("Custom labels that will be added to the driver pod. This should be a comma-separated" + + " list of label key-value pairs, where each label is in the format key=value. Note that" + + " Spark also adds its own labels to the driver pod for bookkeeping purposes.") + .stringConf + .createOptional + + private[spark] val KUBERNETES_DRIVER_ENV_KEY = "spark.kubernetes.driverEnv." + + private[spark] val KUBERNETES_DRIVER_ANNOTATIONS = + ConfigBuilder("spark.kubernetes.driver.annotations") + .doc("Custom annotations that will be added to the driver pod. This should be a" + + " comma-separated list of annotation key-value pairs, where each annotation is in the" + + " format key=value.") + .stringConf + .createOptional + + private[spark] val KUBERNETES_EXECUTOR_LABELS = + ConfigBuilder("spark.kubernetes.executor.labels") + .doc("Custom labels that will be added to the executor pods. This should be a" + + " comma-separated list of label key-value pairs, where each label is in the format" + + " key=value.") + .stringConf + .createOptional + + private[spark] val KUBERNETES_EXECUTOR_ANNOTATIONS = + ConfigBuilder("spark.kubernetes.executor.annotations") + .doc("Custom annotations that will be added to the executor pods. This should be a" + + " comma-separated list of annotation key-value pairs, where each annotation is in the" + + " format key=value.") + .stringConf + .createOptional + + private[spark] val KUBERNETES_DRIVER_SECRETS_PREFIX = "spark.kubernetes.driver.secrets." + private[spark] val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets." + + private[spark] val KUBERNETES_DRIVER_POD_NAME = + ConfigBuilder("spark.kubernetes.driver.pod.name") + .doc("Name of the driver pod.") + .stringConf + .createOptional + + private[spark] val KUBERNETES_EXECUTOR_POD_NAME_PREFIX = + ConfigBuilder("spark.kubernetes.executor.podNamePrefix") + .doc("Prefix to use in front of the executor pod names.") + .internal() + .stringConf + .createWithDefault("spark") + + private[spark] val KUBERNETES_ALLOCATION_BATCH_SIZE = + ConfigBuilder("spark.kubernetes.allocation.batch.size") + .doc("Number of pods to launch at once in each round of dynamic allocation. ") + .intConf + .createWithDefault(5) + + private[spark] val KUBERNETES_ALLOCATION_BATCH_DELAY = + ConfigBuilder("spark.kubernetes.allocation.batch.delay") + .doc("Number of seconds to wait between each round of executor allocation. 
") + .longConf + .createWithDefault(1) + + private[spark] val INIT_CONTAINER_JARS_DOWNLOAD_LOCATION = + ConfigBuilder("spark.kubernetes.mountdependencies.jarsDownloadDir") + .doc("Location to download jars to in the driver and executors. When using" + + " spark-submit, this directory must be empty and will be mounted as an empty directory" + + " volume on the driver and executor pod.") + .stringConf + .createWithDefault("/var/spark-data/spark-jars") + + private[spark] val KUBERNETES_EXECUTOR_LIMIT_CORES = + ConfigBuilder("spark.kubernetes.executor.limit.cores") + .doc("Specify the hard cpu limit for a single executor pod") + .stringConf + .createOptional + + private[spark] val KUBERNETES_NODE_SELECTOR_PREFIX = "spark.kubernetes.node.selector." + + private[spark] def resolveK8sMaster(rawMasterString: String): String = { + if (!rawMasterString.startsWith("k8s://")) { + throw new IllegalArgumentException("Master URL should start with k8s:// in Kubernetes mode.") + } + val masterWithoutK8sPrefix = rawMasterString.replaceFirst("k8s://", "") + if (masterWithoutK8sPrefix.startsWith("http://") + || masterWithoutK8sPrefix.startsWith("https://")) { + masterWithoutK8sPrefix + } else { + val resolvedURL = s"https://$masterWithoutK8sPrefix" + logDebug(s"No scheme specified for kubernetes master URL, so defaulting to https. Resolved" + + s" URL is $resolvedURL") + resolvedURL + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/constants.scala new file mode 100644 index 0000000000000..07bf56e8bf0dc --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/constants.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.kubernetes + +package object constants { + // Labels + private[spark] val SPARK_APP_ID_LABEL = "spark-app-selector" + private[spark] val SPARK_EXECUTOR_ID_LABEL = "spark-exec-id" + private[spark] val SPARK_ROLE_LABEL = "spark-role" + private[spark] val SPARK_POD_DRIVER_ROLE = "driver" + private[spark] val SPARK_POD_EXECUTOR_ROLE = "executor" + + // Credentials secrets + private[spark] val DRIVER_CREDENTIALS_SECRETS_BASE_DIR = + "/mnt/secrets/spark-kubernetes-credentials" + private[spark] val DRIVER_CREDENTIALS_CA_CERT_SECRET_NAME = "ca-cert" + private[spark] val DRIVER_CREDENTIALS_CA_CERT_PATH = + s"$DRIVER_CREDENTIALS_SECRETS_BASE_DIR/$DRIVER_CREDENTIALS_CA_CERT_SECRET_NAME" + private[spark] val DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME = "client-key" + private[spark] val DRIVER_CREDENTIALS_CLIENT_KEY_PATH = + s"$DRIVER_CREDENTIALS_SECRETS_BASE_DIR/$DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME" + private[spark] val DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME = "client-cert" + private[spark] val DRIVER_CREDENTIALS_CLIENT_CERT_PATH = + s"$DRIVER_CREDENTIALS_SECRETS_BASE_DIR/$DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME" + private[spark] val DRIVER_CREDENTIALS_OAUTH_TOKEN_SECRET_NAME = "oauth-token" + private[spark] val DRIVER_CREDENTIALS_OAUTH_TOKEN_PATH = + s"$DRIVER_CREDENTIALS_SECRETS_BASE_DIR/$DRIVER_CREDENTIALS_OAUTH_TOKEN_SECRET_NAME" + private[spark] val DRIVER_CREDENTIALS_SECRET_VOLUME_NAME = "kubernetes-credentials" + + // Default and fixed ports + private[spark] val DEFAULT_DRIVER_PORT = 7078 + private[spark] val DEFAULT_BLOCKMANAGER_PORT = 7079 + private[spark] val DEFAULT_UI_PORT = 4040 + private[spark] val BLOCK_MANAGER_PORT_NAME = "blockmanager" + private[spark] val DRIVER_PORT_NAME = "driver-rpc-port" + private[spark] val EXECUTOR_PORT_NAME = "executor" + + // Environment Variables + private[spark] val ENV_EXECUTOR_PORT = "SPARK_EXECUTOR_PORT" + private[spark] val ENV_DRIVER_URL = "SPARK_DRIVER_URL" + private[spark] val ENV_EXECUTOR_CORES = "SPARK_EXECUTOR_CORES" + private[spark] val ENV_EXECUTOR_MEMORY = "SPARK_EXECUTOR_MEMORY" + private[spark] val ENV_APPLICATION_ID = "SPARK_APPLICATION_ID" + private[spark] val ENV_EXECUTOR_ID = "SPARK_EXECUTOR_ID" + private[spark] val ENV_EXECUTOR_POD_IP = "SPARK_EXECUTOR_POD_IP" + private[spark] val ENV_DRIVER_MEMORY = "SPARK_DRIVER_MEMORY" + private[spark] val ENV_SUBMIT_EXTRA_CLASSPATH = "SPARK_SUBMIT_EXTRA_CLASSPATH" + private[spark] val ENV_EXECUTOR_EXTRA_CLASSPATH = "SPARK_EXECUTOR_EXTRA_CLASSPATH" + private[spark] val ENV_MOUNTED_CLASSPATH = "SPARK_MOUNTED_CLASSPATH" + private[spark] val ENV_DRIVER_MAIN_CLASS = "SPARK_DRIVER_CLASS" + private[spark] val ENV_DRIVER_ARGS = "SPARK_DRIVER_ARGS" + private[spark] val ENV_DRIVER_JAVA_OPTS = "SPARK_DRIVER_JAVA_OPTS" + private[spark] val ENV_MOUNTED_FILES_DIR = "SPARK_MOUNTED_FILES_DIR" + private[spark] val ENV_PYSPARK_FILES = "PYSPARK_FILES" + private[spark] val ENV_PYSPARK_PRIMARY = "PYSPARK_PRIMARY" + private[spark] val ENV_JAVA_OPT_PREFIX = "SPARK_JAVA_OPT_" + + // Miscellaneous + private[spark] val DRIVER_CONTAINER_NAME = "spark-kubernetes-driver" + private[spark] val KUBERNETES_MASTER_INTERNAL_URL = "https://kubernetes.default.svc" + private[spark] val MEMORY_OVERHEAD_FACTOR = 0.10 + private[spark] val MEMORY_OVERHEAD_MIN_MIB = 384L +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/ExecutorPodFactory.scala 
b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/ExecutorPodFactory.scala new file mode 100644 index 0000000000000..8d493606d1f60 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/ExecutorPodFactory.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.kubernetes + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, ContainerPortBuilder, EnvVar, EnvVarBuilder, EnvVarSourceBuilder, Pod, PodBuilder, QuantityBuilder} +import org.apache.commons.io.FilenameUtils + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.deploy.kubernetes.ConfigurationUtils +import org.apache.spark.deploy.kubernetes.config._ +import org.apache.spark.deploy.kubernetes.constants._ +import org.apache.spark.util.Utils + +// Configures executor pods. Construct one of these with a SparkConf to set up properties that are +// common across all executors. Then, pass in dynamic parameters into createExecutorPod. 
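// Illustrative usage sketch only, not part of this file: the intended call pattern is the one
// used by KubernetesClusterSchedulerBackend later in this change. Build one factory per
// application from the SparkConf, then request pods with per-executor parameters. The names
// sparkConf, driverUrl, driverPod and kubernetesClient are assumed to be in scope.
//
//   val factory: ExecutorPodFactory = new ExecutorPodFactoryImpl(sparkConf)
//   val pod = factory.createExecutorPod(
//     executorId = "1",
//     applicationId = "spark-app-id",
//     driverUrl = driverUrl,
//     executorEnvs = sparkConf.getExecutorEnv,
//     driverPod = driverPod,
//     nodeToLocalTaskCount = Map.empty)
//   kubernetesClient.pods().create(pod)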
+private[spark] trait ExecutorPodFactory { + def createExecutorPod( + executorId: String, + applicationId: String, + driverUrl: String, + executorEnvs: Seq[(String, String)], + driverPod: Pod, + nodeToLocalTaskCount: Map[String, Int]): Pod +} + +private[spark] class ExecutorPodFactoryImpl(sparkConf: SparkConf) + extends ExecutorPodFactory { + + import ExecutorPodFactoryImpl._ + + private val executorExtraClasspath = sparkConf.get( + org.apache.spark.internal.config.EXECUTOR_CLASS_PATH) + private val executorJarsDownloadDir = sparkConf.get(INIT_CONTAINER_JARS_DOWNLOAD_LOCATION) + + private val executorLabels = ConfigurationUtils.combinePrefixedKeyValuePairsWithDeprecatedConf( + sparkConf, + KUBERNETES_EXECUTOR_LABEL_PREFIX, + KUBERNETES_EXECUTOR_LABELS, + "executor label") + require( + !executorLabels.contains(SPARK_APP_ID_LABEL), + s"Custom executor labels cannot contain $SPARK_APP_ID_LABEL as it is reserved for Spark.") + require( + !executorLabels.contains(SPARK_EXECUTOR_ID_LABEL), + s"Custom executor labels cannot contain $SPARK_EXECUTOR_ID_LABEL as it is reserved for" + + s" Spark.") + + private val executorAnnotations = + ConfigurationUtils.combinePrefixedKeyValuePairsWithDeprecatedConf( + sparkConf, + KUBERNETES_EXECUTOR_ANNOTATION_PREFIX, + KUBERNETES_EXECUTOR_ANNOTATIONS, + "executor annotation") + private val nodeSelector = + ConfigurationUtils.parsePrefixedKeyValuePairs( + sparkConf, + KUBERNETES_NODE_SELECTOR_PREFIX, + "node selector") + + private val executorDockerImage = sparkConf.get(EXECUTOR_DOCKER_IMAGE) + private val dockerImagePullPolicy = sparkConf.get(DOCKER_IMAGE_PULL_POLICY) + private val executorPort = sparkConf.getInt("spark.executor.port", DEFAULT_STATIC_PORT) + private val blockmanagerPort = sparkConf + .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) + private val kubernetesDriverPodName = sparkConf + .get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse(throw new SparkException("Must specify the driver pod name")) + + private val executorPodNamePrefix = sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX) + + private val executorMemoryMiB = sparkConf.get(org.apache.spark.internal.config.EXECUTOR_MEMORY) + private val executorMemoryString = sparkConf.get( + org.apache.spark.internal.config.EXECUTOR_MEMORY.key, + org.apache.spark.internal.config.EXECUTOR_MEMORY.defaultValueString) + + private val memoryOverheadMiB = sparkConf + .get(KUBERNETES_EXECUTOR_MEMORY_OVERHEAD) + .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt, + MEMORY_OVERHEAD_MIN_MIB)) + private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB + + private val executorCores = sparkConf.getDouble("spark.executor.cores", 1d) + private val executorLimitCores = sparkConf.getOption(KUBERNETES_EXECUTOR_LIMIT_CORES.key) + + override def createExecutorPod( + executorId: String, + applicationId: String, + driverUrl: String, + executorEnvs: Seq[(String, String)], + driverPod: Pod, + nodeToLocalTaskCount: Map[String, Int]): Pod = { + val name = s"$executorPodNamePrefix-exec-$executorId" + + // hostname must be no longer than 63 characters, so take the last 63 characters of the pod + // name as the hostname. 
This preserves uniqueness since the end of name contains + // executorId and applicationId + val hostname = name.substring(Math.max(0, name.length - 63)) + val resolvedExecutorLabels = Map( + SPARK_EXECUTOR_ID_LABEL -> executorId, + SPARK_APP_ID_LABEL -> applicationId, + SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ + executorLabels + val executorMemoryQuantity = new QuantityBuilder(false) + .withAmount(s"${executorMemoryMiB}Mi") + .build() + val executorMemoryLimitQuantity = new QuantityBuilder(false) + .withAmount(s"${executorMemoryWithOverhead}Mi") + .build() + val executorCpuQuantity = new QuantityBuilder(false) + .withAmount(executorCores.toString) + .build() + val executorExtraClasspathEnv = executorExtraClasspath.map { cp => + new EnvVarBuilder() + .withName(ENV_EXECUTOR_EXTRA_CLASSPATH) + .withValue(cp) + .build() + } + val executorExtraJavaOptionsEnv = sparkConf + .get(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS) + .map { opts => + val delimitedOpts = Utils.splitCommandString(opts) + delimitedOpts.zipWithIndex.map { + case (opt, index) => + new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() + } + }.getOrElse(Seq.empty[EnvVar]) + val executorEnv = (Seq( + (ENV_EXECUTOR_PORT, executorPort.toString), + (ENV_DRIVER_URL, driverUrl), + // Executor backend expects integral value for executor cores, so round it up to an int. + (ENV_EXECUTOR_CORES, math.ceil(executorCores).toInt.toString), + (ENV_EXECUTOR_MEMORY, executorMemoryString), + (ENV_APPLICATION_ID, applicationId), + (ENV_EXECUTOR_ID, executorId), + (ENV_MOUNTED_CLASSPATH, s"$executorJarsDownloadDir/*")) ++ executorEnvs) + .map(env => new EnvVarBuilder() + .withName(env._1) + .withValue(env._2) + .build() + ) ++ Seq( + new EnvVarBuilder() + .withName(ENV_EXECUTOR_POD_IP) + .withValueFrom(new EnvVarSourceBuilder() + .withNewFieldRef("v1", "status.podIP") + .build()) + .build() + ) ++ executorExtraJavaOptionsEnv ++ executorExtraClasspathEnv.toSeq + val requiredPorts = Seq( + (EXECUTOR_PORT_NAME, executorPort), + (BLOCK_MANAGER_PORT_NAME, blockmanagerPort)) + .map(port => { + new ContainerPortBuilder() + .withName(port._1) + .withContainerPort(port._2) + .build() + }) + + val executorContainer = new ContainerBuilder() + .withName(s"executor") + .withImage(executorDockerImage) + .withImagePullPolicy(dockerImagePullPolicy) + .withNewResources() + .addToRequests("memory", executorMemoryQuantity) + .addToLimits("memory", executorMemoryLimitQuantity) + .addToRequests("cpu", executorCpuQuantity) + .endResources() + .addAllToEnv(executorEnv.asJava) + .withPorts(requiredPorts.asJava) + .build() + + val executorPod = new PodBuilder() + .withNewMetadata() + .withName(name) + .withLabels(resolvedExecutorLabels.asJava) + .withAnnotations(executorAnnotations.asJava) + .withOwnerReferences() + .addNewOwnerReference() + .withController(true) + .withApiVersion(driverPod.getApiVersion) + .withKind(driverPod.getKind) + .withName(driverPod.getMetadata.getName) + .withUid(driverPod.getMetadata.getUid) + .endOwnerReference() + .endMetadata() + .withNewSpec() + .withHostname(hostname) + .withRestartPolicy("Never") + .withNodeSelector(nodeSelector.asJava) + .endSpec() + .build() + + val containerWithExecutorLimitCores = executorLimitCores.map { + limitCores => + val executorCpuLimitQuantity = new QuantityBuilder(false) + .withAmount(limitCores) + .build() + new ContainerBuilder(executorContainer) + .editResources() + .addToLimits("cpu", executorCpuLimitQuantity) + .endResources() + .build() + 
}.getOrElse(executorContainer) + + new PodBuilder(executorPod) + .editSpec() + .addToContainers(containerWithExecutorLimitCores) + .endSpec() + .build() + } +} + +private object ExecutorPodFactoryImpl { + private val DEFAULT_STATIC_PORT = 10000 +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterManager.scala new file mode 100644 index 0000000000000..6666c0a091f87 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterManager.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.kubernetes + +import java.io.File + +import io.fabric8.kubernetes.client.Config + +import org.apache.spark.SparkContext +import org.apache.spark.deploy.kubernetes.{ConfigurationUtils, SparkKubernetesClientFactory} +import org.apache.spark.deploy.kubernetes.config._ +import org.apache.spark.deploy.kubernetes.constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} +import org.apache.spark.util.{ThreadUtils, Utils} + +private[spark] class KubernetesClusterManager extends ExternalClusterManager with Logging { + + override def canCreate(masterURL: String): Boolean = masterURL.startsWith("k8s") + + override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler = { + new TaskSchedulerImpl(sc) + } + + override def createSchedulerBackend(sc: SparkContext, masterURL: String, scheduler: TaskScheduler) + : SchedulerBackend = { + val sparkConf = sc.getConf + + val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( + KUBERNETES_MASTER_INTERNAL_URL, + Some(sparkConf.get(KUBERNETES_NAMESPACE)), + APISERVER_AUTH_DRIVER_MOUNTED_CONF_PREFIX, + sparkConf, + Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), + Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) + + val executorPodFactory = new ExecutorPodFactoryImpl(sparkConf) + val allocatorExecutor = ThreadUtils + .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") + val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( + "kubernetes-executor-requests") + new KubernetesClusterSchedulerBackend( + scheduler.asInstanceOf[TaskSchedulerImpl], + sc.env.rpcEnv, + executorPodFactory, + kubernetesClient, + allocatorExecutor, + requestExecutorsService) + } + + override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { + 
scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend) + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala new file mode 100644 index 0000000000000..35589361d046a --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala @@ -0,0 +1,466 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.kubernetes + +import java.io.Closeable +import java.net.InetAddress +import java.util.concurrent.{ConcurrentHashMap, ExecutorService, ScheduledExecutorService, TimeUnit} +import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, AtomicReference} + +import scala.collection.{concurrent, mutable} +import scala.collection.JavaConverters._ +import scala.concurrent.{ExecutionContext, Future} + +import io.fabric8.kubernetes.api.model._ +import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action + +import org.apache.spark.SparkException +import org.apache.spark.deploy.kubernetes.config._ +import org.apache.spark.deploy.kubernetes.constants._ +import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv} +import org.apache.spark.scheduler.{ExecutorExited, SlaveLost, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.util.Utils + +private[spark] class KubernetesClusterSchedulerBackend( + scheduler: TaskSchedulerImpl, + rpcEnv: RpcEnv, + executorPodFactory: ExecutorPodFactory, + kubernetesClient: KubernetesClient, + allocatorExecutor: ScheduledExecutorService, + requestExecutorsService: ExecutorService) + extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { + + import KubernetesClusterSchedulerBackend._ + + private val EXECUTOR_ID_COUNTER = new AtomicLong(0L) + private val RUNNING_EXECUTOR_PODS_LOCK = new Object + // Indexed by executor IDs and guarded by RUNNING_EXECUTOR_PODS_LOCK. + private val runningExecutorsToPods = new mutable.HashMap[String, Pod] + // Indexed by executor pod names and guarded by RUNNING_EXECUTOR_PODS_LOCK. + private val runningPodsToExecutors = new mutable.HashMap[String, String] + // TODO(varun): Get rid of this lock object by my making the underlying map a concurrent hash map. 
+ private val EXECUTOR_PODS_BY_IPS_LOCK = new Object + // Indexed by executor IP addrs and guarded by EXECUTOR_PODS_BY_IPS_LOCK + private val executorPodsByIPs = new mutable.HashMap[String, Pod] + private val podsWithKnownExitReasons: concurrent.Map[String, ExecutorExited] = + new ConcurrentHashMap[String, ExecutorExited]().asScala + private val disconnectedPodsByExecutorIdPendingRemoval = + new ConcurrentHashMap[String, Pod]().asScala + + private val kubernetesNamespace = conf.get(KUBERNETES_NAMESPACE) + + private val kubernetesDriverPodName = conf + .get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse( + throw new SparkException("Must specify the driver pod name")) + private implicit val requestExecutorContext = ExecutionContext.fromExecutorService( + requestExecutorsService) + + private val driverPod = try { + kubernetesClient.pods().inNamespace(kubernetesNamespace). + withName(kubernetesDriverPodName).get() + } catch { + case throwable: Throwable => + logError(s"Executor cannot find driver pod.", throwable) + throw new SparkException(s"Executor cannot find driver pod", throwable) + } + + override val minRegisteredRatio = + if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { + 0.8 + } else { + super.minRegisteredRatio + } + + private val executorWatchResource = new AtomicReference[Closeable] + protected var totalExpectedExecutors = new AtomicInteger(0) + + private val driverUrl = RpcEndpointAddress( + conf.get("spark.driver.host"), + conf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString + + private val initialExecutors = getInitialTargetExecutorNumber() + + private val podAllocationInterval = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY) + require(podAllocationInterval > 0, s"Allocation batch delay " + + s"${KUBERNETES_ALLOCATION_BATCH_DELAY} " + + s"is ${podAllocationInterval}, should be a positive integer") + + private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE) + require(podAllocationSize > 0, s"Allocation batch size " + + s"${KUBERNETES_ALLOCATION_BATCH_SIZE} " + + s"is ${podAllocationSize}, should be a positive integer") + + private val allocatorRunnable = new Runnable { + + // Maintains a map of executor id to count of checks performed to learn the loss reason + // for an executor. + private val executorReasonCheckAttemptCounts = new mutable.HashMap[String, Int] + + override def run(): Unit = { + handleDisconnectedExecutors() + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + if (totalRegisteredExecutors.get() < runningExecutorsToPods.size) { + logDebug("Waiting for pending executors before scaling") + } else if (totalExpectedExecutors.get() <= runningExecutorsToPods.size) { + logDebug("Maximum allowed executor limit reached. Not scaling up further.") + } else { + val nodeToLocalTaskCount = getNodesWithLocalTaskCounts + for (i <- 0 until math.min( + totalExpectedExecutors.get - runningExecutorsToPods.size, podAllocationSize)) { + val (executorId, pod) = allocateNewExecutorPod(nodeToLocalTaskCount) + runningExecutorsToPods.put(executorId, pod) + runningPodsToExecutors.put(pod.getMetadata.getName, executorId) + logInfo( + s"Requesting a new executor, total executors is now ${runningExecutorsToPods.size}") + } + } + } + } + + def handleDisconnectedExecutors(): Unit = { + // For each disconnected executor, synchronize with the loss reasons that may have been found + // by the executor pod watcher. If the loss reason was discovered by the watcher, + // inform the parent class with removeExecutor. 
+ val disconnectedPodsByExecutorIdPendingRemovalCopy = + Map.empty ++ disconnectedPodsByExecutorIdPendingRemoval + disconnectedPodsByExecutorIdPendingRemovalCopy.foreach { case (executorId, executorPod) => + val knownExitReason = podsWithKnownExitReasons.remove(executorPod.getMetadata.getName) + knownExitReason.fold { + removeExecutorOrIncrementLossReasonCheckCount(executorId) + } { executorExited => + logDebug(s"Removing executor $executorId with loss reason " + executorExited.message) + removeExecutor(executorId, executorExited) + // We keep around executors that have exit conditions caused by the application. This + // allows them to be debugged later on. Otherwise, mark them as to be deleted from the + // the API server. + if (!executorExited.exitCausedByApp) { + deleteExecutorFromClusterAndDataStructures(executorId) + } + } + } + } + + def removeExecutorOrIncrementLossReasonCheckCount(executorId: String): Unit = { + val reasonCheckCount = executorReasonCheckAttemptCounts.getOrElse(executorId, 0) + if (reasonCheckCount >= MAX_EXECUTOR_LOST_REASON_CHECKS) { + removeExecutor(executorId, SlaveLost("Executor lost for unknown reasons.")) + deleteExecutorFromClusterAndDataStructures(executorId) + } else { + executorReasonCheckAttemptCounts.put(executorId, reasonCheckCount + 1) + } + } + + def deleteExecutorFromClusterAndDataStructures(executorId: String): Unit = { + disconnectedPodsByExecutorIdPendingRemoval -= executorId + executorReasonCheckAttemptCounts -= executorId + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + runningExecutorsToPods.remove(executorId).map { pod => + kubernetesClient.pods().delete(pod) + runningPodsToExecutors.remove(pod.getMetadata.getName) + }.getOrElse(logWarning(s"Unable to remove pod for unknown executor $executorId")) + } + } + } + + private def getInitialTargetExecutorNumber(defaultNumExecutors: Int = 1): Int = { + if (Utils.isDynamicAllocationEnabled(conf)) { + val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", 0) + val initialNumExecutors = Utils.getDynamicAllocationInitialExecutors(conf) + val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", 1) + require(initialNumExecutors >= minNumExecutors && initialNumExecutors <= maxNumExecutors, + s"initial executor number $initialNumExecutors must between min executor number " + + s"$minNumExecutors and max executor number $maxNumExecutors") + + initialNumExecutors + } else { + conf.getInt("spark.executor.instances", defaultNumExecutors) + } + + } + + override def applicationId(): String = conf.get("spark.app.id", super.applicationId()) + + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= initialExecutors * minRegisteredRatio + } + + override def start(): Unit = { + super.start() + executorWatchResource.set( + kubernetesClient + .pods() + .withLabel(SPARK_APP_ID_LABEL, applicationId()) + .watch(new ExecutorPodsWatcher())) + + allocatorExecutor.scheduleWithFixedDelay( + allocatorRunnable, 0L, podAllocationInterval, TimeUnit.SECONDS) + + if (!Utils.isDynamicAllocationEnabled(conf)) { + doRequestTotalExecutors(initialExecutors) + } + } + + override def stop(): Unit = { + // stop allocation of new resources and caches. + allocatorExecutor.shutdown() + + // send stop message to executors so they shut down cleanly + super.stop() + + // then delete the executor pods + // TODO investigate why Utils.tryLogNonFatalError() doesn't work in this context. 
+ // When using Utils.tryLogNonFatalError some of the code fails but without any logs or + // indication as to why. + try { + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + runningExecutorsToPods.values.foreach(kubernetesClient.pods().delete(_)) + runningExecutorsToPods.clear() + runningPodsToExecutors.clear() + } + EXECUTOR_PODS_BY_IPS_LOCK.synchronized { + executorPodsByIPs.clear() + } + val resource = executorWatchResource.getAndSet(null) + if (resource != null) { + resource.close() + } + } catch { + case e: Throwable => logError("Uncaught exception while shutting down controllers.", e) + } + try { + logInfo("Closing kubernetes client") + kubernetesClient.close() + } catch { + case e: Throwable => logError("Uncaught exception closing Kubernetes client.", e) + } + } + + /** + * @return A map of K8s cluster nodes to the number of tasks that could benefit from data + * locality if an executor launches on the cluster node. + */ + private def getNodesWithLocalTaskCounts() : Map[String, Int] = { + val executorPodsWithIPs = EXECUTOR_PODS_BY_IPS_LOCK.synchronized { + executorPodsByIPs.values.toList // toList makes a defensive copy. + } + val nodeToLocalTaskCount = mutable.Map[String, Int]() ++ + KubernetesClusterSchedulerBackend.this.synchronized { + hostToLocalTaskCount + } + for (pod <- executorPodsWithIPs) { + // Remove cluster nodes that are running our executors already. + // TODO: This prefers spreading out executors across nodes. In case users want + // consolidating executors on fewer nodes, introduce a flag. See the spark.deploy.spreadOut + // flag that Spark standalone has: https://spark.apache.org/docs/latest/spark-standalone.html + nodeToLocalTaskCount.remove(pod.getSpec.getNodeName).nonEmpty || + nodeToLocalTaskCount.remove(pod.getStatus.getHostIP).nonEmpty || + nodeToLocalTaskCount.remove( + InetAddress.getByName(pod.getStatus.getHostIP).getCanonicalHostName).nonEmpty + } + nodeToLocalTaskCount.toMap[String, Int] + } + + /** + * Allocates a new executor pod + * + * @param nodeToLocalTaskCount A map of K8s cluster nodes to the number of tasks that could + * benefit from data locality if an executor launches on the cluster + * node. + * @return A tuple of the new executor name and the Pod data structure. 
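+ * The Pod in the returned tuple is the object returned by the API server for the newly created executor pod.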
+ */ + private def allocateNewExecutorPod(nodeToLocalTaskCount: Map[String, Int]): (String, Pod) = { + val executorId = EXECUTOR_ID_COUNTER.incrementAndGet().toString + val executorPod = executorPodFactory.createExecutorPod( + executorId, + applicationId(), + driverUrl, + conf.getExecutorEnv, + driverPod, + nodeToLocalTaskCount) + try { + (executorId, kubernetesClient.pods.create(executorPod)) + } catch { + case throwable: Throwable => + logError("Failed to allocate executor pod.", throwable) + throw throwable + } + } + + override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = Future[Boolean] { + totalExpectedExecutors.set(requestedTotal) + true + } + + override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] { + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + for (executor <- executorIds) { + val maybeRemovedExecutor = runningExecutorsToPods.remove(executor) + maybeRemovedExecutor.foreach { executorPod => + kubernetesClient.pods().delete(executorPod) + disconnectedPodsByExecutorIdPendingRemoval(executor) = executorPod + runningPodsToExecutors.remove(executorPod.getMetadata.getName) + } + if (maybeRemovedExecutor.isEmpty) { + logWarning(s"Unable to remove pod for unknown executor $executor") + } + } + } + true + } + + def getExecutorPodByIP(podIP: String): Option[Pod] = { + EXECUTOR_PODS_BY_IPS_LOCK.synchronized { + executorPodsByIPs.get(podIP) + } + } + + private class ExecutorPodsWatcher extends Watcher[Pod] { + + private val DEFAULT_CONTAINER_FAILURE_EXIT_STATUS = -1 + + override def eventReceived(action: Action, pod: Pod): Unit = { + if (action == Action.MODIFIED && pod.getStatus.getPhase == "Running" + && pod.getMetadata.getDeletionTimestamp == null) { + val podIP = pod.getStatus.getPodIP + val clusterNodeName = pod.getSpec.getNodeName + logDebug(s"Executor pod $pod ready, launched at $clusterNodeName as IP $podIP.") + EXECUTOR_PODS_BY_IPS_LOCK.synchronized { + executorPodsByIPs += ((podIP, pod)) + } + } else if ((action == Action.MODIFIED && pod.getMetadata.getDeletionTimestamp != null) || + action == Action.DELETED || action == Action.ERROR) { + val podName = pod.getMetadata.getName + val podIP = pod.getStatus.getPodIP + logDebug(s"Executor pod $podName at IP $podIP was at $action.") + if (podIP != null) { + EXECUTOR_PODS_BY_IPS_LOCK.synchronized { + executorPodsByIPs -= podIP + } + } + if (action == Action.ERROR) { + logInfo(s"Received pod $podName exited event. Reason: " + pod.getStatus.getReason) + handleErroredPod(pod) + } else if (action == Action.DELETED) { + logInfo(s"Received delete pod $podName event. Reason: " + pod.getStatus.getReason) + handleDeletedPod(pod) + } + } + } + + override def onClose(cause: KubernetesClientException): Unit = { + logDebug("Executor pod watch closed.", cause) + } + + def getExecutorExitStatus(pod: Pod): Int = { + val containerStatuses = pod.getStatus.getContainerStatuses + if (!containerStatuses.isEmpty) { + // we assume the first container represents the pod status. This assumption may not hold + // true in the future. Revisit this if side-car containers start running inside executor + // pods. 
+ getExecutorExitStatus(containerStatuses.get(0)) + } else DEFAULT_CONTAINER_FAILURE_EXIT_STATUS + } + + def getExecutorExitStatus(containerStatus: ContainerStatus): Int = { + Option(containerStatus.getState).map(containerState => + Option(containerState.getTerminated).map(containerStateTerminated => + containerStateTerminated.getExitCode.intValue()).getOrElse(UNKNOWN_EXIT_CODE) + ).getOrElse(UNKNOWN_EXIT_CODE) + } + + def isPodAlreadyReleased(pod: Pod): Boolean = { + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + !runningPodsToExecutors.contains(pod.getMetadata.getName) + } + } + + def handleErroredPod(pod: Pod): Unit = { + val containerExitStatus = getExecutorExitStatus(pod) + // If the pod was already released, the container was probably actively killed by the driver. + val exitReason = if (isPodAlreadyReleased(pod)) { + ExecutorExited(containerExitStatus, exitCausedByApp = false, + s"Container in pod " + pod.getMetadata.getName + + " exited from explicit termination request.") + } else { + val containerExitReason = containerExitStatus match { + case VMEM_EXCEEDED_EXIT_CODE | PMEM_EXCEEDED_EXIT_CODE => + memLimitExceededLogMessage(pod.getStatus.getReason) + case _ => + // Here we can't be sure that the exit was caused by the application but this seems + // to be the right default since we know the pod was not explicitly deleted by + // the user. + s"Pod ${pod.getMetadata.getName}'s executor container exited with exit status" + + s" code $containerExitStatus." + } + ExecutorExited(containerExitStatus, exitCausedByApp = true, containerExitReason) + } + podsWithKnownExitReasons.put(pod.getMetadata.getName, exitReason) + } + + def handleDeletedPod(pod: Pod): Unit = { + val exitMessage = if (isPodAlreadyReleased(pod)) { + s"Container in pod ${pod.getMetadata.getName} exited from explicit termination request." + } else { + s"Pod ${pod.getMetadata.getName} deleted or lost." + } + val exitReason = ExecutorExited( + getExecutorExitStatus(pod), exitCausedByApp = false, exitMessage) + podsWithKnownExitReasons.put(pod.getMetadata.getName, exitReason) + } + } + + override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = { + new KubernetesDriverEndpoint(rpcEnv, properties) + } + + private class KubernetesDriverEndpoint( + rpcEnv: RpcEnv, + sparkProperties: Seq[(String, String)]) + extends DriverEndpoint(rpcEnv, sparkProperties) { + + override def onDisconnected(rpcAddress: RpcAddress): Unit = { + addressToExecutorId.get(rpcAddress).foreach { executorId => + if (disableExecutor(executorId)) { + RUNNING_EXECUTOR_PODS_LOCK.synchronized { + runningExecutorsToPods.get(executorId).foreach { pod => + disconnectedPodsByExecutorIdPendingRemoval.put(executorId, pod) + } + } + } + } + } + } +} + +private object KubernetesClusterSchedulerBackend { + private val VMEM_EXCEEDED_EXIT_CODE = -103 + private val PMEM_EXCEEDED_EXIT_CODE = -104 + private val UNKNOWN_EXIT_CODE = -111 + // Number of times we are allowed to check for the loss reason for an executor before we give up + // and assume the executor failed for good, and attribute it to a framework fault. + val MAX_EXECUTOR_LOST_REASON_CHECKS = 10 + + def memLimitExceededLogMessage(diagnostics: String): String = { + s"Pod/Container killed for exceeding memory limits. $diagnostics" + + " Consider boosting spark executor memory overhead."
+ } +} + diff --git a/resource-managers/kubernetes/core/src/test/resources/log4j.properties b/resource-managers/kubernetes/core/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..ad95fadb7c0c0 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/resources/log4j.properties @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=INFO, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from a few verbose libraries. +log4j.logger.com.sun.jersey=WARN +log4j.logger.org.apache.hadoop=WARN +log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.mortbay=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/kubernetes/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/kubernetes/ExecutorPodFactorySuite.scala new file mode 100644 index 0000000000000..515ec413f4e31 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/kubernetes/ExecutorPodFactorySuite.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.scheduler.cluster.kubernetes + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model._ +import io.fabric8.kubernetes.client.KubernetesClient +import org.mockito.MockitoAnnotations +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.deploy.kubernetes.config._ +import org.apache.spark.deploy.kubernetes.constants +import org.apache.spark.network.netty.SparkTransportConf + +class ExecutorPodFactoryImplSuite extends SparkFunSuite with BeforeAndAfter { + private val driverPodName: String = "driver-pod" + private val driverPodUid: String = "driver-uid" + private val driverUrl: String = "driver-url" + private val executorPrefix: String = "base" + private val executorImage: String = "executor-image" + private val driverPod = new PodBuilder() + .withNewMetadata() + .withName(driverPodName) + .withUid(driverPodUid) + .endMetadata() + .withNewSpec() + .withNodeName("some-node") + .endSpec() + .withNewStatus() + .withHostIP("192.168.99.100") + .endStatus() + .build() + private var baseConf: SparkConf = _ + private var sc: SparkContext = _ + + before { + SparkContext.clearActiveContext() + MockitoAnnotations.initMocks(this) + baseConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, driverPodName) + .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) + .set(EXECUTOR_DOCKER_IMAGE, executorImage) + sc = new SparkContext("local", "test") + } + private var kubernetesClient: KubernetesClient = _ + + test("basic executor pod has reasonable defaults") { + val factory = new ExecutorPodFactoryImpl(baseConf) + val executor = factory.createExecutorPod("1", "dummy", "dummy", + Seq[(String, String)](), driverPod, Map[String, Int]()) + + // The executor pod name and default labels. + assert(executor.getMetadata.getName == s"$executorPrefix-exec-1") + assert(executor.getMetadata.getLabels.size() == 3) + + // There is exactly 1 container with no volume mounts and default memory limits. + // Default memory limit is 1024M + 384M (minimum overhead constant). + assert(executor.getSpec.getContainers.size() == 1) + assert(executor.getSpec.getContainers.get(0).getImage == executorImage) + assert(executor.getSpec.getContainers.get(0).getVolumeMounts.size() == 0) + assert(executor.getSpec.getContainers.get(0).getResources.getLimits.size() == 1) + assert(executor.getSpec.getContainers.get(0).getResources. + getLimits.get("memory").getAmount == "1408Mi") + + // The pod has no node selector, volumes. 
+ assert(executor.getSpec.getNodeSelector.size() == 0) + assert(executor.getSpec.getVolumes.size() == 0) + + checkEnv(executor, Set()) + checkOwnerReferences(executor, driverPodUid) + } + + test("executor pod hostnames get truncated to 63 characters") { + val conf = baseConf.clone() + conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, + "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple") + + val factory = new ExecutorPodFactoryImpl(conf) + val executor = factory.createExecutorPod("1", + "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + + assert(executor.getSpec.getHostname.length == 63) + } + + test("classpath and extra java options get translated into environment variables") { + val conf = baseConf.clone() + conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") + conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") + + val factory = new ExecutorPodFactoryImpl(conf) + val executor = factory.createExecutorPod("1", + "dummy", "dummy", Seq[(String, String)]("qux" -> "quux"), driverPod, Map[String, Int]()) + + checkEnv(executor, Set("SPARK_JAVA_OPT_0", "SPARK_EXECUTOR_EXTRA_CLASSPATH", "qux")) + checkOwnerReferences(executor, driverPodUid) + } + + // There is always exactly one controller reference, and it points to the driver pod. + private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { + assert(executor.getMetadata.getOwnerReferences.size() == 1) + assert(executor.getMetadata.getOwnerReferences.get(0).getUid == driverPodUid) + assert(executor.getMetadata.getOwnerReferences.get(0).getController == true) + } + + // Check that the expected environment variables are present. + private def checkEnv(executor: Pod, additionalEnvVars: Set[String]): Unit = { + val defaultEnvs = Set(constants.ENV_EXECUTOR_ID, + constants.ENV_DRIVER_URL, constants.ENV_EXECUTOR_CORES, + constants.ENV_EXECUTOR_MEMORY, constants.ENV_APPLICATION_ID, + constants.ENV_MOUNTED_CLASSPATH, constants.ENV_EXECUTOR_POD_IP, + constants.ENV_EXECUTOR_PORT) ++ additionalEnvVars + + assert(executor.getSpec.getContainers.size() == 1) + assert(executor.getSpec.getContainers.get(0).getEnv().size() == defaultEnvs.size) + val setEnvs = executor.getSpec.getContainers.get(0).getEnv.asScala.map { + x => x.getName + }.toSet + assert(defaultEnvs == setEnvs) + } +} \ No newline at end of file diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackendSuite.scala new file mode 100644 index 0000000000000..25e590ee35f0c --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackendSuite.scala @@ -0,0 +1,378 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.kubernetes + +import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit} + +import scala.collection.JavaConverters._ +import scala.concurrent.Future + +import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder, PodList} +import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} +import io.fabric8.kubernetes.client.Watcher.Action +import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NonNamespaceOperation, PodResource} +import org.mockito.{AdditionalAnswers, ArgumentCaptor, Mock, MockitoAnnotations} +import org.mockito.Matchers.{any, eq => mockitoEq} +import org.mockito.Mockito.{mock => _, _} +import org.scalatest.BeforeAndAfter +import org.scalatest.mock.MockitoSugar._ + +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.deploy.kubernetes.config._ +import org.apache.spark.deploy.kubernetes.constants._ +import org.apache.spark.rpc._ +import org.apache.spark.scheduler.{ExecutorExited, LiveListenerBus, SlaveLost, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RemoveExecutor} +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend + +private[spark] class KubernetesClusterSchedulerBackendSuite + extends SparkFunSuite with BeforeAndAfter { + + private val APP_ID = "test-spark-app" + private val DRIVER_POD_NAME = "spark-driver-pod" + private val NAMESPACE = "test-namespace" + private val SPARK_DRIVER_HOST = "localhost" + private val SPARK_DRIVER_PORT = 7077 + private val POD_ALLOCATION_INTERVAL = 60L + private val DRIVER_URL = RpcEndpointAddress( + SPARK_DRIVER_HOST, SPARK_DRIVER_PORT, CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString + private val FIRST_EXECUTOR_POD = new PodBuilder() + .withNewMetadata() + .withName("pod1") + .endMetadata() + .withNewSpec() + .withNodeName("node1") + .endSpec() + .withNewStatus() + .withHostIP("192.168.99.100") + .endStatus() + .build() + private val SECOND_EXECUTOR_POD = new PodBuilder() + .withNewMetadata() + .withName("pod2") + .endMetadata() + .withNewSpec() + .withNodeName("node2") + .endSpec() + .withNewStatus() + .withHostIP("192.168.99.101") + .endStatus() + .build() + + private type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] + private type LABELLED_PODS = FilterWatchListDeletable[ + Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]] + private type IN_NAMESPACE_PODS = NonNamespaceOperation[ + Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]] + + @Mock + private var sparkContext: SparkContext = _ + + @Mock + private var listenerBus: LiveListenerBus = _ + + @Mock + private var taskSchedulerImpl: TaskSchedulerImpl = _ + + @Mock + private var allocatorExecutor: ScheduledExecutorService = _ + + @Mock + private var requestExecutorsService: ExecutorService = _ + + @Mock + private var executorPodFactory: ExecutorPodFactory = _ + + @Mock + private var kubernetesClient: KubernetesClient = _ + + @Mock + private var podOperations: PODS = _ + + @Mock + private 
var podsWithLabelOperations: LABELLED_PODS = _ + + @Mock + private var podsInNamespace: IN_NAMESPACE_PODS = _ + + @Mock + private var podsWithDriverName: PodResource[Pod, DoneablePod] = _ + + @Mock + private var rpcEnv: RpcEnv = _ + + @Mock + private var driverEndpointRef: RpcEndpointRef = _ + + @Mock + private var executorPodsWatch: Watch = _ + + private var sparkConf: SparkConf = _ + private var executorPodsWatcherArgument: ArgumentCaptor[Watcher[Pod]] = _ + private var allocatorRunnable: ArgumentCaptor[Runnable] = _ + private var requestExecutorRunnable: ArgumentCaptor[Runnable] = _ + private var driverEndpoint: ArgumentCaptor[RpcEndpoint] = _ + + private val driverPod = new PodBuilder() + .withNewMetadata() + .withName(DRIVER_POD_NAME) + .addToLabels(SPARK_APP_ID_LABEL, APP_ID) + .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE) + .endMetadata() + .build() + + before { + MockitoAnnotations.initMocks(this) + sparkConf = new SparkConf() + .set("spark.app.id", APP_ID) + .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME) + .set(KUBERNETES_NAMESPACE, NAMESPACE) + .set("spark.driver.host", SPARK_DRIVER_HOST) + .set("spark.driver.port", SPARK_DRIVER_PORT.toString) + .set(KUBERNETES_ALLOCATION_BATCH_DELAY, POD_ALLOCATION_INTERVAL) + executorPodsWatcherArgument = ArgumentCaptor.forClass(classOf[Watcher[Pod]]) + allocatorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) + requestExecutorRunnable = ArgumentCaptor.forClass(classOf[Runnable]) + driverEndpoint = ArgumentCaptor.forClass(classOf[RpcEndpoint]) + when(sparkContext.conf).thenReturn(sparkConf) + when(sparkContext.listenerBus).thenReturn(listenerBus) + when(taskSchedulerImpl.sc).thenReturn(sparkContext) + when(kubernetesClient.pods()).thenReturn(podOperations) + when(podOperations.withLabel(SPARK_APP_ID_LABEL, APP_ID)).thenReturn(podsWithLabelOperations) + when(podsWithLabelOperations.watch(executorPodsWatcherArgument.capture())) + .thenReturn(executorPodsWatch) + when(podOperations.inNamespace(NAMESPACE)).thenReturn(podsInNamespace) + when(podsInNamespace.withName(DRIVER_POD_NAME)).thenReturn(podsWithDriverName) + when(podsWithDriverName.get()).thenReturn(driverPod) + when(allocatorExecutor.scheduleWithFixedDelay( + allocatorRunnable.capture(), + mockitoEq(0L), + mockitoEq(POD_ALLOCATION_INTERVAL), + mockitoEq(TimeUnit.SECONDS))).thenReturn(null) + // Creating Futures in Scala backed by a Java executor service resolves to running + // ExecutorService#execute (as opposed to submit) + doNothing().when(requestExecutorsService).execute(requestExecutorRunnable.capture()) + when(rpcEnv.setupEndpoint( + mockitoEq(CoarseGrainedSchedulerBackend.ENDPOINT_NAME), driverEndpoint.capture())) + .thenReturn(driverEndpointRef) + when(driverEndpointRef.ask[Boolean] + (any(classOf[Any])) + (any())).thenReturn(mock[Future[Boolean]]) + } + + test("Basic lifecycle expectations when starting and stopping the scheduler.") { + val scheduler = newSchedulerBackend() + scheduler.start() + assert(executorPodsWatcherArgument.getValue != null) + assert(allocatorRunnable.getValue != null) + scheduler.stop() + verify(executorPodsWatch).close() + } + + test("Static allocation should request executors upon first allocator run.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) + val scheduler = newSchedulerBackend() + scheduler.start() + requestExecutorRunnable.getValue.run() + expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + 
when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) + allocatorRunnable.getValue.run() + verify(podOperations).create(FIRST_EXECUTOR_POD) + verify(podOperations).create(SECOND_EXECUTOR_POD) + } + + test("Killing executors deletes the executor pods") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) + val scheduler = newSchedulerBackend() + scheduler.start() + requestExecutorRunnable.getValue.run() + expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + when(podOperations.create(any(classOf[Pod]))) + .thenAnswer(AdditionalAnswers.returnsFirstArg()) + allocatorRunnable.getValue.run() + scheduler.doKillExecutors(Seq("2")) + requestExecutorRunnable.getAllValues.asScala.last.run() + verify(podOperations).delete(SECOND_EXECUTOR_POD) + verify(podOperations, never()).delete(FIRST_EXECUTOR_POD) + } + + test("Executors should be requested in batches.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2) + val scheduler = newSchedulerBackend() + scheduler.start() + requestExecutorRunnable.getValue.run() + when(podOperations.create(any(classOf[Pod]))) + .thenAnswer(AdditionalAnswers.returnsFirstArg()) + expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + allocatorRunnable.getValue.run() + verify(podOperations).create(FIRST_EXECUTOR_POD) + verify(podOperations, never()).create(SECOND_EXECUTOR_POD) + val registerFirstExecutorMessage = RegisterExecutor( + "1", mock[RpcEndpointRef], "localhost", 1, Map.empty[String, String]) + when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) + driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) + .apply(registerFirstExecutorMessage) + allocatorRunnable.getValue.run() + verify(podOperations).create(SECOND_EXECUTOR_POD) + } + + test("Deleting executors and then running an allocator pass after finding the loss reason" + + " should only delete the pod once.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) + val scheduler = newSchedulerBackend() + scheduler.start() + requestExecutorRunnable.getValue.run() + when(podOperations.create(any(classOf[Pod]))) + .thenAnswer(AdditionalAnswers.returnsFirstArg()) + expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + allocatorRunnable.getValue.run() + val executorEndpointRef = mock[RpcEndpointRef] + when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) + val registerFirstExecutorMessage = RegisterExecutor( + "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) + when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) + driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) + .apply(registerFirstExecutorMessage) + scheduler.doRequestTotalExecutors(0) + requestExecutorRunnable.getAllValues.asScala.last.run() + scheduler.doKillExecutors(Seq("1")) + requestExecutorRunnable.getAllValues.asScala.last.run() + verify(podOperations, times(1)).delete(FIRST_EXECUTOR_POD) + driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) + + val exitedPod = exitPod(FIRST_EXECUTOR_POD, 0) + executorPodsWatcherArgument.getValue.eventReceived(Action.DELETED, exitedPod) + allocatorRunnable.getValue.run() + verify(podOperations, times(1)).delete(FIRST_EXECUTOR_POD) + verify(driverEndpointRef, times(1)).ask[Boolean]( + 
RemoveExecutor("1", ExecutorExited( + 0, + exitCausedByApp = false, + s"Container in pod ${exitedPod.getMetadata.getName} exited from" + + s" explicit termination request."))) + } + + test("Executors that disconnect from application errors are noted as exits caused by app.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) + val scheduler = newSchedulerBackend() + scheduler.start() + expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) + requestExecutorRunnable.getValue.run() + allocatorRunnable.getValue.run() + val executorEndpointRef = mock[RpcEndpointRef] + when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) + val registerFirstExecutorMessage = RegisterExecutor( + "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) + when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) + driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) + .apply(registerFirstExecutorMessage) + driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) + executorPodsWatcherArgument.getValue.eventReceived( + Action.ERROR, exitPod(FIRST_EXECUTOR_POD, 1)) + + expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + scheduler.doRequestTotalExecutors(1) + requestExecutorRunnable.getValue.run() + allocatorRunnable.getAllValues.asScala.last.run() + verify(driverEndpointRef).ask[Boolean]( + RemoveExecutor("1", ExecutorExited( + 1, + exitCausedByApp = true, + s"Pod ${FIRST_EXECUTOR_POD.getMetadata.getName}'s executor container exited with" + + " exit status code 1."))) + verify(podOperations, never()).delete(FIRST_EXECUTOR_POD) + } + + test("Executors should only try to get the loss reason a number of times before giving up and" + + " removing the executor.") { + sparkConf + .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1) + .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1) + val scheduler = newSchedulerBackend() + scheduler.start() + expectPodCreationWithId(1, FIRST_EXECUTOR_POD) + when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg()) + requestExecutorRunnable.getValue.run() + allocatorRunnable.getValue.run() + val executorEndpointRef = mock[RpcEndpointRef] + when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000)) + val registerFirstExecutorMessage = RegisterExecutor( + "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String]) + when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty) + driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext]) + .apply(registerFirstExecutorMessage) + driverEndpoint.getValue.onDisconnected(executorEndpointRef.address) + 1 to KubernetesClusterSchedulerBackend.MAX_EXECUTOR_LOST_REASON_CHECKS foreach { _ => + allocatorRunnable.getValue.run() + verify(podOperations, never()).delete(FIRST_EXECUTOR_POD) + } + expectPodCreationWithId(2, SECOND_EXECUTOR_POD) + allocatorRunnable.getValue.run() + verify(podOperations).delete(FIRST_EXECUTOR_POD) + verify(driverEndpointRef).ask[Boolean]( + RemoveExecutor("1", SlaveLost("Executor lost for unknown reasons."))) + } + + private def newSchedulerBackend(): KubernetesClusterSchedulerBackend = { + new KubernetesClusterSchedulerBackend( + taskSchedulerImpl, + rpcEnv, + executorPodFactory, + kubernetesClient, + allocatorExecutor, + requestExecutorsService) + } + + private def exitPod(basePod: Pod, exitCode: Int): Pod = { + new 
PodBuilder(basePod) + .editStatus() + .addNewContainerStatus() + .withNewState() + .withNewTerminated() + .withExitCode(exitCode) + .endTerminated() + .endState() + .endContainerStatus() + .endStatus() + .build() + } + + private def expectPodCreationWithId(executorId: Int, expectedPod: Pod): Unit = { + when(executorPodFactory.createExecutorPod( + executorId.toString, + APP_ID, + DRIVER_URL, + sparkConf.getExecutorEnv, + driverPod, + Map.empty)).thenReturn(expectedPod) + } + +} diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 33bc79a92b9e7..d0a54288780ea 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -580,6 +580,8 @@ primaryExpression | '(' query ')' #subqueryExpression | qualifiedName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')' (OVER windowSpec)? #functionCall + | qualifiedName '(' trimOption=(BOTH | LEADING | TRAILING) argument+=expression + FROM argument+=expression ')' #functionCall | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference | base=primaryExpression '.' fieldName=identifier #dereference @@ -748,6 +750,7 @@ nonReserved | UNBOUNDED | WHEN | DATABASE | SELECT | FROM | WHERE | HAVING | TO | TABLE | WITH | NOT | CURRENT_DATE | CURRENT_TIMESTAMP | DIRECTORY + | BOTH | LEADING | TRAILING ; SELECT: 'SELECT'; @@ -861,6 +864,9 @@ COMMIT: 'COMMIT'; ROLLBACK: 'ROLLBACK'; MACRO: 'MACRO'; IGNORE: 'IGNORE'; +BOTH: 'BOTH'; +LEADING: 'LEADING'; +TRAILING: 'TRAILING'; IF: 'IF'; POSITION: 'POSITION'; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java index 4776617043878..5d9515c0725da 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java @@ -109,22 +109,6 @@ public void setOffsetAndSize(int ordinal, long currentCursor, long size) { Platform.putLong(holder.buffer, fieldOffset, offsetAndSize); } - // Do word alignment for this row and grow the row buffer if needed. - // todo: remove this after we make unsafe array data word align.
- public void alignToWords(int numBytes) { - final int remainder = numBytes & 0x07; - - if (remainder > 0) { - final int paddingBytes = 8 - remainder; - holder.grow(paddingBytes); - - for (int i = 0; i < paddingBytes; i++) { - Platform.putByte(holder.buffer, holder.cursor, (byte) 0); - holder.cursor++; - } - } - } - public void write(int ordinal, boolean value) { final long offset = getFieldOffset(ordinal); Platform.putLong(holder.buffer, offset, 0L); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0880bd66ea4c4..db276fbc9d53a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2256,7 +2256,10 @@ object CleanupAliases extends Rule[LogicalPlan] { def trimNonTopLevelAliases(e: Expression): Expression = e match { case a: Alias => - a.withNewChildren(trimAliases(a.child) :: Nil) + a.copy(child = trimAliases(a.child))( + exprId = a.exprId, + qualifier = a.qualifier, + explicitMetadata = Some(a.metadata)) case other => trimAliases(other) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 0908d68d25649..9407b727bca4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -377,6 +377,8 @@ class SessionCatalog( requireDbExists(db) requireTableExists(tableIdentifier) externalCatalog.alterTableStats(db, table, newStats) + // Invalidate the table relation cache + refreshTable(identifier) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 134163187b7c6..18b4fed597447 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -618,13 +618,13 @@ case class JsonToStructs( {"time":"26/08/2015"} > SELECT _FUNC_(array(named_struct('a', 1, 'b', 2)); [{"a":1,"b":2}] - > SELECT _FUNC_(map('a',named_struct('b',1))); + > SELECT _FUNC_(map('a', named_struct('b', 1))); {"a":{"b":1}} - > SELECT _FUNC_(map(named_struct('a',1),named_struct('b',2))); + > SELECT _FUNC_(map(named_struct('a', 1),named_struct('b', 2))); {"[1]":{"b":2}} - > SELECT _FUNC_(map('a',1)); + > SELECT _FUNC_(map('a', 1)); {"a":1} - > SELECT _FUNC_(array((map('a',1)))); + > SELECT _FUNC_(array((map('a', 1)))); [{"a":1}] """, since = "2.2.0") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 7ab45a6ee8737..83de515079eea 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -24,6 +24,7 @@ import java.util.regex.Pattern import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -503,69 +504,304 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi override def prettyName: String = "find_in_set" } +trait String2TrimExpression extends Expression with ImplicitCastInputTypes { + + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) + + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) +} + +object StringTrim { + def apply(str: Expression, trimStr: Expression) : StringTrim = StringTrim(str, Some(trimStr)) + def apply(str: Expression) : StringTrim = StringTrim(str, None) +} + /** - * A function that trim the spaces from both ends for the specified string. + * A function that takes a character string, removes the leading and trailing characters that match any character + * in the trim string, and returns the new string. + * If the BOTH and trimStr keywords are not specified, it defaults to removing the space character from both ends. + * The trim function will have one argument, which contains the source string. + * If the BOTH and trimStr keywords are specified, it trims the characters from both ends, and the trim function will + * have two arguments: the first argument contains trimStr and the second argument contains the source string. + * trimStr: a character string to be trimmed from the source string. If it has multiple characters, the function + * checks each character of the source string against the trim string and keeps removing matching characters until it + * encounters the first non-matching character. + * BOTH: removes any character from both ends of the source string that matches characters in the trim string. */ @ExpressionDescription( - usage = "_FUNC_(str) - Removes the leading and trailing space characters from `str`.", + usage = """ + _FUNC_(str) - Removes the leading and trailing space characters from `str`.
+ _FUNC_(BOTH trimStr FROM str) - Remove the leading and trailing trimString from `str` + """, + arguments = """ + Arguments: + * str - a string expression + * trimString - the trim string + * BOTH, FROM - these are keyword to specify for trim string from both ends of the string + """, examples = """ Examples: > SELECT _FUNC_(' SparkSQL '); SparkSQL + > SELECT _FUNC_(BOTH 'SL' FROM 'SSparkSQLS'); + parkSQ """) -case class StringTrim(child: Expression) - extends UnaryExpression with String2StringExpression { +case class StringTrim( + srcStr: Expression, + trimStr: Option[Expression] = None) + extends String2TrimExpression { + + def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr)) - def convert(v: UTF8String): UTF8String = v.trim() + def this(srcStr: Expression) = this(srcStr, None) override def prettyName: String = "trim" + override def children: Seq[Expression] = if (trimStr.isDefined) { + srcStr :: trimStr.get :: Nil + } else { + srcStr :: Nil + } + override def eval(input: InternalRow): Any = { + val srcString = srcStr.eval(input).asInstanceOf[UTF8String] + if (srcString == null) { + null + } else { + if (trimStr.isDefined) { + srcString.trim(trimStr.get.eval(input).asInstanceOf[UTF8String]) + } else { + srcString.trim() + } + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).trim()") + val evals = children.map(_.genCode(ctx)) + val srcString = evals(0) + + if (evals.length == 1) { + ev.copy(evals.map(_.code).mkString + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trim(); + }""") + } else { + val trimString = evals(1) + val getTrimFunction = + s""" + if (${trimString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trim(${trimString.value}); + }""" + ev.copy(evals.map(_.code).mkString + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + $getTrimFunction + }""") + } } } +object StringTrimLeft { + def apply(str: Expression, trimStr: Expression) : StringTrimLeft = StringTrimLeft(str, Some(trimStr)) + def apply(str: Expression) : StringTrimLeft = StringTrimLeft(str, None) +} + /** - * A function that trim the spaces from left end for given string. + * A function that trims the characters from left end for a given string. + * If LEADING and trimStr keywords are not specified, it defaults to remove space character from the left end. The ltrim + * function will have one argument, which contains the source string. + * If LEADING and trimStr keywords are not specified, it trims the characters from left end. The ltrim function will + * have two arguments, the first argument contains trimStr, the second argument contains the source string. + * trimStr: the function removes any character from the left end of the source string which matches with the characters + * from trimStr, it stops at the first non-match character. + * LEADING: removes any character from the left end of the source string that matches characters in the trim string. */ @ExpressionDescription( - usage = "_FUNC_(str) - Removes the leading and trailing space characters from `str`.", + usage = """ + _FUNC_(str) - Removes the leading space characters from `str`. 
+ _FUNC_(trimStr, str) - Removes the leading string contains the characters from the trim string + """, + arguments = """ + Arguments: + * str - a string expression + * trimString - the trim string + * BOTH, FROM - these are keyword to specify for trim string from both ends of the string + """, examples = """ Examples: - > SELECT _FUNC_(' SparkSQL'); + > SELECT _FUNC_(' SparkSQL '); SparkSQL + > SELECT _FUNC_('Sp', 'SSparkSQLS'); + arkSQLS """) -case class StringTrimLeft(child: Expression) - extends UnaryExpression with String2StringExpression { +case class StringTrimLeft( + srcStr: Expression, + trimStr: Option[Expression] = None) + extends String2TrimExpression { - def convert(v: UTF8String): UTF8String = v.trimLeft() + def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr)) + + def this(srcStr: Expression) = this(srcStr, None) override def prettyName: String = "ltrim" - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).trimLeft()") + override def children: Seq[Expression] = if (trimStr.isDefined) { + srcStr :: trimStr.get :: Nil + } else { + srcStr :: Nil } + + override def eval(input: InternalRow): Any = { + val srcString = srcStr.eval(input).asInstanceOf[UTF8String] + if (srcString == null) { + null + } else { + if (trimStr.isDefined) { + srcString.trimLeft(trimStr.get.eval(input).asInstanceOf[UTF8String]) + } else { + srcString.trimLeft() + } + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val evals = children.map(_.genCode(ctx)) + val srcString = evals(0) + + if (evals.length == 1) { + ev.copy(evals.map(_.code).mkString + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trimLeft(); + }""") + } else { + val trimString = evals(1) + val getTrimLeftFunction = + s""" + if (${trimString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trimLeft(${trimString.value}); + }""" + ev.copy(evals.map(_.code).mkString + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + $getTrimLeftFunction + }""") + } + } +} + +object StringTrimRight { + def apply(str: Expression, trimStr: Expression) : StringTrimRight = StringTrimRight(str, Some(trimStr)) + def apply(str: Expression) : StringTrimRight = StringTrimRight(str, None) } /** - * A function that trim the spaces from right end for given string. + * A function that trims the characters from right end for a given string. + * If TRAILING and trimStr keywords are not specified, it defaults to remove space character from the right end. The + * rtrim function will have one argument, which contains the source string. + * If TRAILING and trimStr keywords are specified, it trims the characters from right end. The rtrim function will + * have two arguments, the first argument contains trimStr, the second argument contains the source string. + * trimStr: the function removes any character from the right end of source string which matches with the characters + * from trimStr, it stops at the first non-match character. + * TRAILING: removes any character from the right end of the source string that matches characters in the trim string. 
*/ @ExpressionDescription( - usage = "_FUNC_(str) - Removes the trailing space characters from `str`.", + usage = """ + _FUNC_(str) - Removes the trailing space characters from `str`. + _FUNC_(trimStr, str) - Removes the trailing string which contains the characters from the trim string from the `str` + """, + arguments = """ + Arguments: + * str - a string expression + * trimString - the trim string + * BOTH, FROM - these are keyword to specify for trim string from both ends of the string + """, examples = """ Examples: > SELECT _FUNC_(' SparkSQL '); - SparkSQL + SparkSQL + > SELECT _FUNC_('LQSa', 'SSparkSQLS'); + SSpark """) -case class StringTrimRight(child: Expression) - extends UnaryExpression with String2StringExpression { +case class StringTrimRight( + srcStr: Expression, + trimStr: Option[Expression] = None) + extends String2TrimExpression { - def convert(v: UTF8String): UTF8String = v.trimRight() + def this(trimStr: Expression, srcStr: Expression) = this(srcStr, Option(trimStr)) + + def this(srcStr: Expression) = this(srcStr, None) override def prettyName: String = "rtrim" - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, c => s"($c).trimRight()") + override def children: Seq[Expression] = if (trimStr.isDefined) { + srcStr :: trimStr.get :: Nil + } else { + srcStr :: Nil + } + + override def eval(input: InternalRow): Any = { + val srcString = srcStr.eval(input).asInstanceOf[UTF8String] + if (srcString == null) { + null + } else { + if (trimStr.isDefined) { + srcString.trimRight(trimStr.get.eval(input).asInstanceOf[UTF8String]) + } else { + srcString.trimRight() + } + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val evals = children.map(_.genCode(ctx)) + val srcString = evals(0) + + if (evals.length == 1) { + ev.copy(evals.map(_.code).mkString + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trimRight(); + }""") + } else { + val trimString = evals(1) + val getTrimRightFunction = + s""" + if (${trimString.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${srcString.value}.trimRight(${trimString.value}); + }""" + ev.copy(evals.map(_.code).mkString + s""" + boolean ${ev.isNull} = false; + UTF8String ${ev.value} = null; + if (${srcString.isNull}) { + ${ev.isNull} = true; + } else { + $getTrimRightFunction + }""") + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index dfe7e28121943..eb06e4f304f0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -43,7 +43,7 @@ private[sql] class JacksonGenerator( private type ValueWriter = (SpecializedGetters, Int) => Unit // `JackGenerator` can only be initialized with a `StructType` or a `MapType`. 
- require(dataType.isInstanceOf[StructType] | dataType.isInstanceOf[MapType], + require(dataType.isInstanceOf[StructType] || dataType.isInstanceOf[MapType], "JacksonGenerator only supports to be initialized with a StructType " + s"or MapType but got ${dataType.simpleString}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 891f61698f177..85b492e83446e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1179,6 +1179,26 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * Create a (windowed) Function expression. */ override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) { + def replaceFunctions( + funcID: FunctionIdentifier, + ctx: FunctionCallContext): FunctionIdentifier = { + val opt = ctx.trimOption + if (opt != null) { + if (ctx.qualifiedName.getText.toLowerCase(Locale.ROOT) != "trim") { + throw new ParseException(s"The specified function ${ctx.qualifiedName.getText} " + + s"does not support the trim option ${opt.getText}.", ctx) + } + opt.getType match { + case SqlBaseParser.BOTH => funcID + case SqlBaseParser.LEADING => funcID.copy(funcName = "ltrim") + case SqlBaseParser.TRAILING => funcID.copy(funcName = "rtrim") + case _ => throw new ParseException("Function trim doesn't support the trim " + + s"type ${opt.getType}. Please use BOTH, LEADING or TRAILING as the trim type.", ctx) + } + } else { + funcID + } + } // Create the function call. + val name = ctx.qualifiedName.getText val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) @@ -1190,7 +1210,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case expressions => expressions } - val function = UnresolvedFunction(visitFunctionName(ctx.qualifiedName), arguments, isDistinct) + val funcId = replaceFunctions(visitFunctionName(ctx.qualifiedName), ctx) + val function = UnresolvedFunction(funcId, arguments, isDistinct) + + // Check if the function is evaluated in a windowed context.
ctx.windowSpec match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 4f08031153ab0..18ef4bc37c2b5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -21,7 +21,6 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ - class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("concat") { @@ -406,26 +405,78 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // scalastyle:on } - test("TRIM/LTRIM/RTRIM") { + test("TRIM") { val s = 'a.string.at(0) checkEvaluation(StringTrim(Literal(" aa ")), "aa", create_row(" abdef ")) + checkEvaluation(StringTrim("aa", "a"), "", create_row(" abdef ")) + checkEvaluation(StringTrim(Literal(" aabbtrimccc"), "ab cd"), "trim", create_row("bdef")) + checkEvaluation(StringTrim(Literal("a@>.,>"), "a.,@<>"), " ", create_row(" abdef ")) checkEvaluation(StringTrim(s), "abdef", create_row(" abdef ")) + checkEvaluation(StringTrim(s, "abd"), "ef", create_row("abdefa")) + checkEvaluation(StringTrim(s, "a"), "bdef", create_row("aaabdefaaaa")) + checkEvaluation(StringTrim(s, "SLSQ"), "park", create_row("SSparkSQLS")) + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. + checkEvaluation(StringTrim(s), "花花世界", create_row(" 花花世界 ")) + checkEvaluation(StringTrim(s, "花世界"), "", create_row("花花世界花花")) + checkEvaluation(StringTrim(s, "花 "), "世界", create_row(" 花花世界花花")) + checkEvaluation(StringTrim(s, "花 "), "世界", create_row(" 花 花 世界 花 花 ")) + checkEvaluation(StringTrim(s, "a花世"), "界", create_row("aa花花世界花花aa")) + checkEvaluation(StringTrim(s, "a@#( )"), "花花世界花花", create_row("aa()花花世界花花@ #")) + checkEvaluation(StringTrim(Literal("花trim"), "花 "), "trim", create_row(" abdef ")) + // scalastyle:on + checkEvaluation(StringTrim(Literal("a"), Literal.create(null, StringType)), null) + checkEvaluation(StringTrim(Literal.create(null, StringType), Literal("a")), null) + } + + test("LTRIM") { + val s = 'a.string.at(0) checkEvaluation(StringTrimLeft(Literal(" aa ")), "aa ", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(Literal("aa"), "a"), "", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(Literal("aa "), "a "), "", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(Literal("aabbcaaaa"), "ab"), "caaaa", create_row(" abdef ")) checkEvaluation(StringTrimLeft(s), "abdef ", create_row(" abdef ")) + checkEvaluation(StringTrimLeft(s, "a"), "bdefa", create_row("abdefa")) + checkEvaluation(StringTrimLeft(s, "a "), "bdefaaaa", create_row(" aaabdefaaaa")) + checkEvaluation(StringTrimLeft(s, "Spk"), "arkSQLS", create_row("SSparkSQLS")) + // scalastyle:off + // non ascii characters are not allowed in the source code, so we disable the scalastyle. 
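+ // The following cases use multi-byte (non-ASCII) trim characters to check that trimming operates on whole characters rather than bytes.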
+ checkEvaluation(StringTrimLeft(s), "花花世界 ", create_row(" 花花世界 ")) + checkEvaluation(StringTrimLeft(s, "花"), "世界花花", create_row("花花世界花花")) + checkEvaluation(StringTrimLeft(s, "花 世"), "界花花", create_row(" 花花世界花花")) + checkEvaluation(StringTrimLeft(s, "花"), "a花花世界花花 ", create_row("a花花世界花花 ")) + checkEvaluation(StringTrimLeft(s, "a花界"), "世界花花aa", create_row("aa花花世界花花aa")) + checkEvaluation(StringTrimLeft(s, "a世界"), "花花世界花花", create_row("花花世界花花")) + // scalastyle:on + checkEvaluation(StringTrimLeft(Literal.create(null, StringType), Literal("a")), null) + checkEvaluation(StringTrimLeft(Literal("a"), Literal.create(null, StringType)), null) + } + + test("RTRIM") { + val s = 'a.string.at(0) checkEvaluation(StringTrimRight(Literal(" aa ")), " aa", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Literal("a"), "a"), "", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Literal("ab"), "ab"), "", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Literal("aabbaaaa %"), "a %"), "aabb", create_row("def")) checkEvaluation(StringTrimRight(s), " abdef", create_row(" abdef ")) + checkEvaluation(StringTrimRight(s, "a"), "abdef", create_row("abdefa")) + checkEvaluation(StringTrimRight(s, "abf de"), "", create_row(" aaabdefaaaa")) + checkEvaluation(StringTrimRight(s, "S*&"), "SSparkSQL", create_row("SSparkSQLS*")) // scalastyle:off // non ascii characters are not allowed in the source code, so we disable the scalastyle. + checkEvaluation(StringTrimRight(Literal("a"), "花"), "a", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Literal("花"), "a"), "花", create_row(" abdef ")) + checkEvaluation(StringTrimRight(Literal("花花世界"), "界花世"), "", create_row(" abdef ")) checkEvaluation(StringTrimRight(s), " 花花世界", create_row(" 花花世界 ")) - checkEvaluation(StringTrimLeft(s), "花花世界 ", create_row(" 花花世界 ")) - checkEvaluation(StringTrim(s), "花花世界", create_row(" 花花世界 ")) + checkEvaluation(StringTrimRight(s, "花a#"), "花花世界", create_row("花花世界花花###aa花")) + checkEvaluation(StringTrimRight(s, "花"), "", create_row("花花花花")) + checkEvaluation(StringTrimRight(s, "花 界b@"), " 花花世", create_row(" 花花世 b界@花花 ")) // scalastyle:on - checkEvaluation(StringTrim(Literal.create(null, StringType)), null) - checkEvaluation(StringTrimLeft(Literal.create(null, StringType)), null) - checkEvaluation(StringTrimRight(Literal.create(null, StringType)), null) + checkEvaluation(StringTrimRight(Literal("a"), Literal.create(null, StringType)), null) + checkEvaluation(StringTrimRight(Literal.create(null, StringType), Literal("a")), null) } test("FORMAT") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala index 587437e9aa81d..e7a5bcee420f5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.Rand +import org.apache.spark.sql.catalyst.expressions.{Alias, Rand} import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import 
org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.MetadataBuilder class CollapseProjectSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { @@ -119,4 +120,22 @@ class CollapseProjectSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("preserve top-level alias metadata while collapsing projects") { + def hasMetadata(logicalPlan: LogicalPlan): Boolean = { + logicalPlan.asInstanceOf[Project].projectList.exists(_.metadata.contains("key")) + } + + val metadata = new MetadataBuilder().putLong("key", 1).build() + val analyzed = + Project(Seq(Alias('a_with_metadata, "b")()), + Project(Seq(Alias('a, "a_with_metadata")(explicitMetadata = Some(metadata))), + testRelation.logicalPlan)).analyze + require(hasMetadata(analyzed)) + + val optimized = Optimize.execute(analyzed) + val projects = optimized.collect { case p: Project => p } + assert(projects.size === 1) + assert(hasMetadata(optimized)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index b0d2fb26a6006..306e6f2cfbd37 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -651,4 +651,22 @@ class PlanParserSuite extends AnalysisTest { ) ) } + + test("TRIM function") { + intercept("select ltrim(both 'S' from 'SS abc S'", "missing ')' at ''") + intercept("select rtrim(trailing 'S' from 'SS abc S'", "missing ')' at ''") + + assertEqual( + "SELECT TRIM(BOTH '@$%&( )abc' FROM '@ $ % & ()abc ' )", + OneRowRelation().select('TRIM.function("@$%&( )abc", "@ $ % & ()abc ")) + ) + assertEqual( + "SELECT TRIM(LEADING 'c []' FROM '[ ccccbcc ')", + OneRowRelation().select('ltrim.function("c []", "[ ccccbcc ")) + ) + assertEqual( + "SELECT TRIM(TRAILING 'c&^,.' 
FROM 'bc...,,,&&&ccc')", + OneRowRelation().select('rtrim.function("c&^,.", "bc...,,,&&&ccc")) + ) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index 76be6ee3f50bc..cc80a41df998d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -51,7 +51,7 @@ class TableIdentifierParserSuite extends SparkFunSuite { "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger", "true", "truncate", "update", "user", "values", "with", "regexp", "rlike", "bigint", "binary", "boolean", "current_date", "current_timestamp", "date", "double", "float", - "int", "smallint", "timestamp", "at", "position") + "int", "smallint", "timestamp", "at", "position", "both", "leading", "trailing") val hiveStrictNonReservedKeyword = Seq("anti", "full", "inner", "left", "semi", "right", "natural", "union", "intersect", "except", "database", "on", "join", "cross", "select", "from", diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index a69dd9718fe33..c4b519f0b153f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -100,72 +100,16 @@ public ArrayData copy() { public Object[] array() { DataType dt = data.dataType(); Object[] list = new Object[length]; - - if (dt instanceof BooleanType) { - for (int i = 0; i < length; i++) { - if (!data.isNullAt(offset + i)) { - list[i] = data.getBoolean(offset + i); - } - } - } else if (dt instanceof ByteType) { - for (int i = 0; i < length; i++) { - if (!data.isNullAt(offset + i)) { - list[i] = data.getByte(offset + i); - } - } - } else if (dt instanceof ShortType) { - for (int i = 0; i < length; i++) { - if (!data.isNullAt(offset + i)) { - list[i] = data.getShort(offset + i); - } - } - } else if (dt instanceof IntegerType) { - for (int i = 0; i < length; i++) { - if (!data.isNullAt(offset + i)) { - list[i] = data.getInt(offset + i); - } - } - } else if (dt instanceof FloatType) { - for (int i = 0; i < length; i++) { - if (!data.isNullAt(offset + i)) { - list[i] = data.getFloat(offset + i); - } - } - } else if (dt instanceof DoubleType) { + try { for (int i = 0; i < length; i++) { if (!data.isNullAt(offset + i)) { - list[i] = data.getDouble(offset + i); + list[i] = get(i, dt); } } - } else if (dt instanceof LongType) { - for (int i = 0; i < length; i++) { - if (!data.isNullAt(offset + i)) { - list[i] = data.getLong(offset + i); - } - } - } else if (dt instanceof DecimalType) { - DecimalType decType = (DecimalType)dt; - for (int i = 0; i < length; i++) { - if (!data.isNullAt(offset + i)) { - list[i] = getDecimal(i, decType.precision(), decType.scale()); - } - } - } else if (dt instanceof StringType) { - for (int i = 0; i < length; i++) { - if (!data.isNullAt(offset + i)) { - list[i] = getUTF8String(i).toString(); - } - } - } else if (dt instanceof CalendarIntervalType) { - for (int i = 0; i < length; i++) { - if (!data.isNullAt(offset + i)) { - list[i] = getInterval(i); - } - } - } else { - throw new UnsupportedOperationException("Type " + dt); + return list; + } catch(Exception e) 
{ + throw new RuntimeException("Could not get the array", e); } - return list; } @Override @@ -237,7 +181,42 @@ public MapData getMap(int ordinal) { @Override public Object get(int ordinal, DataType dataType) { - throw new UnsupportedOperationException(); + if (dataType instanceof BooleanType) { + return getBoolean(ordinal); + } else if (dataType instanceof ByteType) { + return getByte(ordinal); + } else if (dataType instanceof ShortType) { + return getShort(ordinal); + } else if (dataType instanceof IntegerType) { + return getInt(ordinal); + } else if (dataType instanceof LongType) { + return getLong(ordinal); + } else if (dataType instanceof FloatType) { + return getFloat(ordinal); + } else if (dataType instanceof DoubleType) { + return getDouble(ordinal); + } else if (dataType instanceof StringType) { + return getUTF8String(ordinal); + } else if (dataType instanceof BinaryType) { + return getBinary(ordinal); + } else if (dataType instanceof DecimalType) { + DecimalType t = (DecimalType) dataType; + return getDecimal(ordinal, t.precision(), t.scale()); + } else if (dataType instanceof DateType) { + return getInt(ordinal); + } else if (dataType instanceof TimestampType) { + return getLong(ordinal); + } else if (dataType instanceof ArrayType) { + return getArray(ordinal); + } else if (dataType instanceof StructType) { + return getStruct(ordinal, ((StructType)dataType).fields().length); + } else if (dataType instanceof MapType) { + return getMap(ordinal); + } else if (dataType instanceof CalendarIntervalType) { + return getInterval(ordinal); + } else { + throw new UnsupportedOperationException("Datatype not supported " + dataType); + } } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java new file mode 100644 index 0000000000000..dbcbe326a7510 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * The base interface for data source v2. Implementations must have a public, no arguments + * constructor. + * + * Note that this is an empty interface, data source implementations should mix-in at least one of + * the plug-in interfaces like {@link ReadSupport}. Otherwise it's just a dummy data source which is + * un-readable/writable. 
+ */ +@InterfaceStability.Evolving +public interface DataSourceV2 {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java new file mode 100644 index 0000000000000..9a89c8193dd6e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2Options.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * An immutable string-to-string map in which keys are case-insensitive. This is used to represent + * data source options. + */ +@InterfaceStability.Evolving +public class DataSourceV2Options { + private final Map keyLowerCasedMap; + + private String toLowerCase(String key) { + return key.toLowerCase(Locale.ROOT); + } + + public DataSourceV2Options(Map originalMap) { + keyLowerCasedMap = new HashMap<>(originalMap.size()); + for (Map.Entry entry : originalMap.entrySet()) { + keyLowerCasedMap.put(toLowerCase(entry.getKey()), entry.getValue()); + } + } + + /** + * Returns the option value to which the specified key is mapped, case-insensitively. + */ + public Optional get(String key) { + return Optional.ofNullable(keyLowerCasedMap.get(toLowerCase(key))); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java new file mode 100644 index 0000000000000..ab5254a688d5a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
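
For illustration only, a minimal Scala sketch (not part of this patch) of the case-insensitive lookup behaviour that DataSourceV2Options provides; the option key and value here are made up, and the constructor is assumed to take a java.util.Map of strings as the erased signature above suggests:

    import java.util.Optional
    import scala.collection.JavaConverters._
    import org.apache.spark.sql.sources.v2.DataSourceV2Options

    // Keys are lower-cased on insertion, so lookups ignore the caller's casing.
    val options = new DataSourceV2Options(Map("path" -> "/tmp/data").asJava)
    assert(options.get("PATH") == Optional.of("/tmp/data"))
    assert(!options.get("missing").isPresent)
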
+ */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data reading ability and scan the data from the data source. + */ +@InterfaceStability.Evolving +public interface ReadSupport { + + /** + * Creates a {@link DataSourceV2Reader} to scan the data from this data source. + * + * @param options the options for this data source reader, which is an immutable case-insensitive + * string-to-string map. + * @return a reader that implements the actual read logic. + */ + DataSourceV2Reader createReader(DataSourceV2Options options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java new file mode 100644 index 0000000000000..c13aeca2ef36f --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data reading ability and scan the data from the data source. + * + * This is a variant of {@link ReadSupport} that accepts user-specified schema when reading data. + * A data source can implement both {@link ReadSupport} and {@link ReadSupportWithSchema} if it + * supports both schema inference and user-specified schema. + */ +@InterfaceStability.Evolving +public interface ReadSupportWithSchema { + + /** + * Create a {@link DataSourceV2Reader} to scan the data from this data source. + * + * @param schema the full schema of this data source reader. Full schema usually maps to the + * physical schema of the underlying storage of this data source reader, e.g. + * CSV files, JSON files, etc, while this reader may not read data with full + * schema, as column pruning or other optimizations may happen. + * @param options the options for this data source reader, which is an immutable case-insensitive + * string-to-string map. + * @return a reader that implements the actual read logic. 
+ */ + DataSourceV2Reader createReader(StructType schema, DataSourceV2Options options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java new file mode 100644 index 0000000000000..cfafc1a576793 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import java.io.Closeable; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A data reader returned by {@link ReadTask#createReader()} and is responsible for outputting data + * for a RDD partition. + */ +@InterfaceStability.Evolving +public interface DataReader extends Closeable { + + /** + * Proceed to next record, returns false if there is no more records. + */ + boolean next(); + + /** + * Return the current record. This method should return same value until `next` is called. + */ + T get(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java new file mode 100644 index 0000000000000..48feb049c1de9 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
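
A hedged sketch of the schema-accepting variant (hypothetical class name, Scala; the erased list element type above is assumed to be ReadTask<Row>). A source that cannot infer its own schema would implement only ReadSupportWithSchema, and the DataFrameReader changes later in this patch then reject a load() call that does not supply a user schema:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options, ReadSupportWithSchema}
    import org.apache.spark.sql.sources.v2.reader.{DataSourceV2Reader, ReadTask}
    import org.apache.spark.sql.types.StructType

    // Hypothetical source that requires a user-specified schema.
    class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema {
      override def createReader(
          schema: StructType, options: DataSourceV2Options): DataSourceV2Reader = {
        new DataSourceV2Reader {
          override def readSchema(): StructType = schema
          override def createReadTasks(): java.util.List[ReadTask[Row]] =
            java.util.Collections.emptyList() // scan logic omitted in this sketch
        }
      }
    }
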
+ */ + +package org.apache.spark.sql.sources.v2.reader; + +import java.util.List; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.ReadSupportWithSchema; +import org.apache.spark.sql.types.StructType; + +/** + * A data source reader that is returned by + * {@link ReadSupport#createReader(DataSourceV2Options)} or + * {@link ReadSupportWithSchema#createReader(StructType, DataSourceV2Options)}. + * It can mix in various query optimization interfaces to speed up the data scan. The actual scan + * logic should be delegated to {@link ReadTask}s that are returned by {@link #createReadTasks()}. + * + * There are mainly 3 kinds of query optimizations: + * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column + * pruning), etc. These push-down interfaces are named like `SupportsPushDownXXX`. + * 2. Information Reporting. E.g., statistics reporting, ordering reporting, etc. These + * reporting interfaces are named like `SupportsReportingXXX`. + * 3. Special scans. E.g, columnar scan, unsafe row scan, etc. These scan interfaces are named + * like `SupportsScanXXX`. + * + * Spark first applies all operator push-down optimizations that this data source supports. Then + * Spark collects information this data source reported for further optimizations. Finally Spark + * issues the scan request and does the actual data reading. + */ +@InterfaceStability.Evolving +public interface DataSourceV2Reader { + + /** + * Returns the actual schema of this data source reader, which may be different from the physical + * schema of the underlying storage, as column pruning or other optimizations may happen. + */ + StructType readSchema(); + + /** + * Returns a list of read tasks. Each task is responsible for outputting data for one RDD + * partition. That means the number of tasks returned here is same as the number of RDD + * partitions this scan outputs. + * + * Note that, this may not be a full scan if the data source reader mixes in other optimization + * interfaces like column pruning, filter push-down, etc. These optimizations are applied before + * Spark issues the scan request. + */ + List> createReadTasks(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java new file mode 100644 index 0000000000000..7885bfcdd49e4 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources.v2.reader; + +import java.io.Serializable; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A read task returned by {@link DataSourceV2Reader#createReadTasks()} and is responsible for + * creating the actual data reader. The relationship between {@link ReadTask} and {@link DataReader} + * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. + * + * Note that, the read task will be serialized and sent to executors, then the data reader will be + * created on executors and do the actual reading. + */ +@InterfaceStability.Evolving +public interface ReadTask extends Serializable { + + /** + * The preferred locations where this read task can run faster, but Spark does not guarantee that + * this task will always run on these locations. The implementations should make sure that it can + * be run on any location. The location is a string representing the host name of an executor. + */ + default String[] preferredLocations() { + return new String[0]; + } + + /** + * Returns a data reader to do the actual reading work for this read task. + */ + DataReader createReader(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java new file mode 100644 index 0000000000000..e8cd7adbca071 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import java.util.OptionalLong; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * An interface to represent statistics for a data source, which is returned by + * {@link SupportsReportStatistics#getStatistics()}. + */ +@InterfaceStability.Evolving +public interface Statistics { + OptionalLong sizeInBytes(); + OptionalLong numRows(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java new file mode 100644 index 0000000000000..19d706238ec8e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
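
To make the relationship between DataSourceV2, ReadSupport, DataSourceV2Reader, ReadTask and DataReader concrete, here is a minimal end-to-end Scala sketch. SimpleDataSourceV2, SimpleReader and SimpleReadTask are hypothetical names, not part of this patch, and the generic parameters assume the erased signatures above are List<ReadTask<Row>> and DataReader<Row>:

    import java.util.{Arrays => JArrays, List => JList}

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options, ReadSupport}
    import org.apache.spark.sql.sources.v2.reader.{DataReader, DataSourceV2Reader, ReadTask}
    import org.apache.spark.sql.types.StructType

    // A toy source that always returns the rows (0, 0), (1, -1), ..., (9, -9) in two partitions.
    class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport {
      override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new SimpleReader
    }

    class SimpleReader extends DataSourceV2Reader {
      override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")

      // One ReadTask per RDD partition; each task is serialized and sent to an executor.
      override def createReadTasks(): JList[ReadTask[Row]] =
        JArrays.asList(new SimpleReadTask(0, 5), new SimpleReadTask(5, 10))
    }

    class SimpleReadTask(start: Int, end: Int) extends ReadTask[Row] with DataReader[Row] {
      private var current = start - 1

      // The task acts as a factory for the per-partition reader.
      override def createReader(): DataReader[Row] = new SimpleReadTask(start, end)

      override def next(): Boolean = { current += 1; current < end }
      override def get(): Row = Row(current, -current)
      override def close(): Unit = {}
    }

With the DataFrameReader changes later in this patch, such a source could presumably be loaded as spark.read.format(classOf[SimpleDataSourceV2].getName).load().
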
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.catalyst.expressions.Expression; + +/** + * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * interface to push down arbitrary expressions as predicates to the data source. + * This is an experimental and unstable interface as {@link Expression} is not public and may get + * changed in the future Spark versions. + * + * Note that, if data source readers implement both this interface and + * {@link SupportsPushDownFilters}, Spark will ignore {@link SupportsPushDownFilters} and only + * process this interface. + */ +@InterfaceStability.Evolving +@Experimental +@InterfaceStability.Unstable +public interface SupportsPushDownCatalystFilters { + + /** + * Pushes down filters, and returns unsupported filters. + */ + Expression[] pushCatalystFilters(Expression[] filters); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java new file mode 100644 index 0000000000000..d4b509e7080f2 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.sources.Filter; + +/** + * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * interface to push down filters to the data source and reduce the size of the data to be read. + * + * Note that, if data source readers implement both this interface and + * {@link SupportsPushDownCatalystFilters}, Spark will ignore this interface and only process + * {@link SupportsPushDownCatalystFilters}. + */ +@InterfaceStability.Evolving +public interface SupportsPushDownFilters { + + /** + * Pushes down filters, and returns unsupported filters. 
+ */ + Filter[] pushFilters(Filter[] filters); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java new file mode 100644 index 0000000000000..fe0ac8ee0ee32 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * interface to push down required columns to the data source and only read these columns during + * scan to reduce the size of the data to be read. + */ +@InterfaceStability.Evolving +public interface SupportsPushDownRequiredColumns { + + /** + * Applies column pruning w.r.t. the given requiredSchema. + * + * Implementation should try its best to prune the unnecessary columns or nested fields, but it's + * also OK to do the pruning partially, e.g., a data source may not be able to prune nested + * fields, and only prune top-level columns. + * + * Note that, data source readers should update {@link DataSourceV2Reader#readSchema()} after + * applying column pruning. + */ + void pruneColumns(StructType requiredSchema); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java new file mode 100644 index 0000000000000..c019d2f819ab7 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
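
As a hedged sketch of how a reader might combine the two pushdown mix-ins above (hypothetical class, Scala; only GreaterThan on a made-up column "i" is treated as supported here):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.sources.{Filter, GreaterThan}
    import org.apache.spark.sql.sources.v2.reader.{DataSourceV2Reader, ReadTask,
      SupportsPushDownFilters, SupportsPushDownRequiredColumns}
    import org.apache.spark.sql.types.StructType

    class PrunedFilteredReader extends DataSourceV2Reader
      with SupportsPushDownFilters with SupportsPushDownRequiredColumns {

      private var requiredSchema = new StructType().add("i", "int").add("j", "int")
      private var pushed = Array.empty[Filter]

      // Spark calls pruneColumns before the scan; readSchema must reflect the result.
      override def pruneColumns(required: StructType): Unit = { requiredSchema = required }

      override def readSchema(): StructType = requiredSchema

      // Keep the filters this source can evaluate itself and hand the rest back,
      // so Spark plans an extra FilterExec on top of the scan.
      override def pushFilters(filters: Array[Filter]): Array[Filter] = {
        val (supported, unsupported) = filters.partition {
          case GreaterThan("i", _) => true
          case _ => false
        }
        pushed = supported
        unsupported
      }

      override def createReadTasks(): java.util.List[ReadTask[Row]] =
        java.util.Collections.emptyList() // actual scan omitted in this sketch
    }
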
+ */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A mix in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * interface to report statistics to Spark. + */ +@InterfaceStability.Evolving +public interface SupportsReportStatistics { + + /** + * Returns the basic statistics of this data source. + */ + Statistics getStatistics(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java new file mode 100644 index 0000000000000..829f9a078760b --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import java.util.List; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader; +import org.apache.spark.sql.sources.v2.reader.ReadTask; + +/** + * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * interface to output {@link UnsafeRow} directly and avoid the row copy at Spark side. + * This is an experimental and unstable interface, as {@link UnsafeRow} is not public and may get + * changed in the future Spark versions. + */ +@InterfaceStability.Evolving +@Experimental +@InterfaceStability.Unstable +public interface SupportsScanUnsafeRow extends DataSourceV2Reader { + + @Override + default List> createReadTasks() { + throw new IllegalStateException("createReadTasks should not be called with SupportsScanUnsafeRow."); + } + + /** + * Similar to {@link DataSourceV2Reader#createReadTasks()}, but returns data in unsafe row format. 
+ */ + List> createUnsafeRowReadTasks(); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index c69acc413e87f..78b668c04fd5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -32,6 +32,8 @@ import org.apache.spark.sql.execution.datasources.{DataSource, FailureSafeParser import org.apache.spark.sql.execution.datasources.csv._ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options, ReadSupport, ReadSupportWithSchema} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -180,13 +182,43 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { "read files of Hive data source directly.") } - sparkSession.baseRelationToDataFrame( - DataSource.apply( - sparkSession, - paths = paths, - userSpecifiedSchema = userSpecifiedSchema, - className = source, - options = extraOptions.toMap).resolveRelation()) + val cls = DataSource.lookupDataSource(source) + if (classOf[DataSourceV2].isAssignableFrom(cls)) { + val dataSource = cls.newInstance() + val options = new DataSourceV2Options(extraOptions.asJava) + + val reader = (cls.newInstance(), userSpecifiedSchema) match { + case (ds: ReadSupportWithSchema, Some(schema)) => + ds.createReader(schema, options) + + case (ds: ReadSupport, None) => + ds.createReader(options) + + case (_: ReadSupportWithSchema, None) => + throw new AnalysisException(s"A schema needs to be specified when using $dataSource.") + + case (ds: ReadSupport, Some(schema)) => + val reader = ds.createReader(options) + if (reader.readSchema() != schema) { + throw new AnalysisException(s"$ds does not allow user-specified schemas.") + } + reader + + case _ => + throw new AnalysisException(s"$cls does not support data reading.") + } + + Dataset.ofRows(sparkSession, DataSourceV2Relation(reader)) + } else { + // Code path for data source v1. + sparkSession.baseRelationToDataFrame( + DataSource.apply( + sparkSession, + paths = paths, + userSpecifiedSchema = userSpecifiedSchema, + className = source, + options = extraOptions.toMap).resolveRelation()) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index 2118b9118a22f..2a2315896831c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import com.fasterxml.jackson.annotation.JsonIgnoreProperties + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.metric.SQLMetricInfo @@ -26,6 +28,7 @@ import org.apache.spark.sql.execution.metric.SQLMetricInfo * Stores information about a SQL SparkPlan. */ @DeveloperApi +@JsonIgnoreProperties(Array("metadata")) // The metadata field was removed in Spark 2.3. 
class SparkPlanInfo( val nodeName: String, val simpleString: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 4e718d609c921..b143d44eae17b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, FileSourceStrategy} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy import org.apache.spark.sql.internal.SQLConf class SparkPlanner( @@ -35,6 +36,7 @@ class SparkPlanner( def strategies: Seq[Strategy] = experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ ( + DataSourceV2Strategy :: FileSourceStrategy :: DataSourceStrategy(conf) :: SpecialLimits :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 11ba04d2ce9a7..0b740735ffe19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -234,8 +234,9 @@ private[arrow] class StringWriter(val valueVector: NullableVarCharVector) extend override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { val utf8 = input.getUTF8String(ordinal) + val utf8ByteBuffer = utf8.getByteBuffer // todo: for off-heap UTF8String, how to pass in to arrow without copy? - valueMutator.setSafe(count, utf8.getByteBuffer, 0, utf8.numBytes()) + valueMutator.setSafe(count, utf8ByteBuffer, utf8ByteBuffer.position(), utf8.numBytes()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index 6588993ef9ad9..caf12ad745bb8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -56,9 +56,6 @@ case class AnalyzeColumnCommand( sessionState.catalog.alterTableStats(tableIdentWithDB, Some(statistics)) - // Refresh the cached data source table in the catalog. - sessionState.catalog.refreshTable(tableIdentWithDB) - Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala index 04715bd314d4d..58b53e8b1c551 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala @@ -48,8 +48,6 @@ case class AnalyzeTableCommand( val newStats = CommandUtils.compareAndGetNewStats(tableMeta.stats, newTotalSize, newRowCount) if (newStats.isDefined) { sessionState.catalog.alterTableStats(tableIdentWithDB, newStats) - // Refresh the cached data source table in the catalog. 
- sessionState.catalog.refreshTable(tableIdentWithDB) } Seq.empty[Row] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 75327f0d38c2e..71133666b3249 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -301,12 +301,11 @@ object JdbcUtils extends Logging { } else { rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls } - val metadata = new MetadataBuilder() - .putLong("scale", fieldScale) + val metadata = new MetadataBuilder().putLong("scale", fieldScale) val columnType = dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse( getCatalystType(dataType, fieldSize, fieldScale, isSigned)) - fields(i) = StructField(columnName, columnType, nullable, metadata.build()) + fields(i) = StructField(columnName, columnType, nullable) i = i + 1 } new StructType(fields) @@ -768,31 +767,30 @@ object JdbcUtils extends Logging { } /** - * Parses the user specified customSchema option value to DataFrame schema, - * and returns it if it's all columns are equals to default schema's. + * Parses the user specified customSchema option value to DataFrame schema, and + * returns a schema that is replaced by the custom schema's dataType if column name is matched. */ def getCustomSchema( tableSchema: StructType, customSchema: String, nameEquality: Resolver): StructType = { - val userSchema = CatalystSqlParser.parseTableSchema(customSchema) + if (null != customSchema && customSchema.nonEmpty) { + val userSchema = CatalystSqlParser.parseTableSchema(customSchema) - SchemaUtils.checkColumnNameDuplication( - userSchema.map(_.name), "in the customSchema option value", nameEquality) - - val colNames = tableSchema.fieldNames.mkString(",") - val errorMsg = s"Please provide all the columns, all columns are: $colNames" - if (userSchema.size != tableSchema.size) { - throw new AnalysisException(errorMsg) - } + SchemaUtils.checkColumnNameDuplication( + userSchema.map(_.name), "in the customSchema option value", nameEquality) - // This is resolved by names, only check the column names. - userSchema.fieldNames.foreach { col => - tableSchema.find(f => nameEquality(f.name, col)).getOrElse { - throw new AnalysisException(errorMsg) + // This is resolved by names, use the custom filed dataType to replace the default dataType. + val newSchema = tableSchema.map { col => + userSchema.find(f => nameEquality(f.name, col.name)) match { + case Some(c) => col.copy(dataType = c.dataType) + case None => col + } } + StructType(newSchema) + } else { + tableSchema } - userSchema } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala new file mode 100644 index 0000000000000..b8fe5ac8e3d94 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
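
For context, a hedged usage sketch of the relaxed customSchema handling above: only the columns named in the option get their types overridden, and any column not mentioned keeps the type derived from the JDBC metadata. The connection details and table layout below are made up:

    // Hypothetical JDBC table "schema.people" with columns id, name, age.
    val df = spark.read
      .format("jdbc")
      .option("url", "jdbc:postgresql:dbserver")
      .option("dbtable", "schema.people")
      .option("customSchema", "id DECIMAL(38, 0), name STRING")
      .load()
    // Resulting schema: id DECIMAL(38,0), name STRING, age <type chosen by the dialect>
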
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import scala.collection.JavaConverters._ + +import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.sources.v2.reader.ReadTask + +class DataSourceRDDPartition(val index: Int, val readTask: ReadTask[UnsafeRow]) + extends Partition with Serializable + +class DataSourceRDD( + sc: SparkContext, + @transient private val readTasks: java.util.List[ReadTask[UnsafeRow]]) + extends RDD[UnsafeRow](sc, Nil) { + + override protected def getPartitions: Array[Partition] = { + readTasks.asScala.zipWithIndex.map { + case (readTask, index) => new DataSourceRDDPartition(index, readTask) + }.toArray + } + + override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = { + val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createReader() + context.addTaskCompletionListener(_ => reader.close()) + val iter = new Iterator[UnsafeRow] { + private[this] var valuePrepared = false + + override def hasNext: Boolean = { + if (!valuePrepared) { + valuePrepared = reader.next() + } + valuePrepared + } + + override def next(): UnsafeRow = { + if (!hasNext) { + throw new java.util.NoSuchElementException("End of stream") + } + valuePrepared = false + reader.get() + } + } + new InterruptibleIterator(context, iter) + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + split.asInstanceOf[DataSourceRDDPartition].readTask.preferredLocations() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala new file mode 100644 index 0000000000000..3c9b598fd07c9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} +import org.apache.spark.sql.sources.v2.reader.{DataSourceV2Reader, SupportsReportStatistics} + +case class DataSourceV2Relation( + output: Seq[AttributeReference], + reader: DataSourceV2Reader) extends LeafNode { + + override def computeStats(): Statistics = reader match { + case r: SupportsReportStatistics => + Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + case _ => + Statistics(sizeInBytes = conf.defaultSizeInBytes) + } +} + +object DataSourceV2Relation { + def apply(reader: DataSourceV2Reader): DataSourceV2Relation = { + new DataSourceV2Relation(reader.readSchema().toAttributes, reader) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala new file mode 100644 index 0000000000000..7999c0ceb5749 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import scala.collection.JavaConverters._ + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.LeafExecNode +import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.types.StructType + +case class DataSourceV2ScanExec( + fullOutput: Array[AttributeReference], + @transient reader: DataSourceV2Reader, + // TODO: these 3 parameters are only used to determine the equality of the scan node, however, + // the reader also have this information, and ideally we can just rely on the equality of the + // reader. The only concern is, the reader implementation is outside of Spark and we have no + // control. 
+ readSchema: StructType, + @transient filters: ExpressionSet, + hashPartitionKeys: Seq[String]) extends LeafExecNode { + + def output: Seq[Attribute] = readSchema.map(_.name).map { name => + fullOutput.find(_.name == name).get + } + + override def references: AttributeSet = AttributeSet.empty + + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + + override protected def doExecute(): RDD[InternalRow] = { + val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match { + case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks() + case _ => + reader.createReadTasks().asScala.map { + new RowToUnsafeRowReadTask(_, reader.readSchema()): ReadTask[UnsafeRow] + }.asJava + } + + val inputRDD = new DataSourceRDD(sparkContext, readTasks) + .asInstanceOf[RDD[InternalRow]] + val numOutputRows = longMetric("numOutputRows") + inputRDD.map { r => + numOutputRows += 1 + r + } + } +} + +class RowToUnsafeRowReadTask(rowReadTask: ReadTask[Row], schema: StructType) + extends ReadTask[UnsafeRow] { + + override def preferredLocations: Array[String] = rowReadTask.preferredLocations + + override def createReader: DataReader[UnsafeRow] = { + new RowToUnsafeDataReader(rowReadTask.createReader, RowEncoder.apply(schema)) + } +} + +class RowToUnsafeDataReader(rowReader: DataReader[Row], encoder: ExpressionEncoder[Row]) + extends DataReader[UnsafeRow] { + + override def next: Boolean = rowReader.next + + override def get: UnsafeRow = encoder.toRow(rowReader.get).asInstanceOf[UnsafeRow] + + override def close(): Unit = rowReader.close() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala new file mode 100644 index 0000000000000..b80f695b2a87f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
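
A hedged sketch of how statistics reporting feeds the relation above, reusing the hypothetical SimpleReader from the earlier sketch; DataSourceV2Relation.computeStats only consumes sizeInBytes, falling back to the session default when it is absent:

    import java.util.OptionalLong

    import org.apache.spark.sql.sources.v2.reader.{Statistics, SupportsReportStatistics}

    class ReportingReader extends SimpleReader with SupportsReportStatistics {
      // The reported size feeds planner decisions such as broadcast-join thresholds.
      override def getStatistics(): Statistics = new Statistics {
        override def sizeInBytes(): OptionalLong = OptionalLong.of(80)
        override def numRows(): OptionalLong = OptionalLong.of(10)
      }
    }
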
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.Strategy +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.sources.v2.reader._ + +object DataSourceV2Strategy extends Strategy { + // TODO: write path + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalOperation(projects, filters, DataSourceV2Relation(output, reader)) => + val stayUpFilters: Seq[Expression] = reader match { + case r: SupportsPushDownCatalystFilters => + r.pushCatalystFilters(filters.toArray) + + case r: SupportsPushDownFilters => + // A map from original Catalyst expressions to corresponding translated data source + // filters. If a predicate is not in this map, it means it cannot be pushed down. + val translatedMap: Map[Expression, Filter] = filters.flatMap { p => + DataSourceStrategy.translateFilter(p).map(f => p -> f) + }.toMap + + // Catalyst predicate expressions that cannot be converted to data source filters. + val nonConvertiblePredicates = filters.filterNot(translatedMap.contains) + + // Data source filters that cannot be pushed down. An unhandled filter means + // the data source cannot guarantee the rows returned can pass the filter. + // As a result we must return it so Spark can plan an extra filter operator. + val unhandledFilters = r.pushFilters(translatedMap.values.toArray).toSet + val unhandledPredicates = translatedMap.filter { case (_, f) => + unhandledFilters.contains(f) + }.keys + + nonConvertiblePredicates ++ unhandledPredicates + + case _ => filters + } + + val attrMap = AttributeMap(output.zip(output)) + val projectSet = AttributeSet(projects.flatMap(_.references)) + val filterSet = AttributeSet(stayUpFilters.flatMap(_.references)) + + // Match original case of attributes. 
+ // TODO: nested fields pruning + val requiredColumns = (projectSet ++ filterSet).toSeq.map(attrMap) + reader match { + case r: SupportsPushDownRequiredColumns => + r.pruneColumns(requiredColumns.toStructType) + case _ => + } + + val scan = DataSourceV2ScanExec( + output.toArray, + reader, + reader.readSchema(), + ExpressionSet(filters), + Nil) + + val filterCondition = stayUpFilters.reduceLeftOption(And) + val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) + + val withProject = if (projects == withFilter.output) { + withFilter + } else { + ProjectExec(projects, withFilter) + } + + withProject :: Nil + + case _ => Nil + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 258a64216136f..027222e1119c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -21,11 +21,13 @@ import java.util.UUID import java.util.concurrent.atomic.AtomicInteger import org.apache.spark.internal.Logging -import org.apache.spark.sql.{SparkSession, Strategy} +import org.apache.spark.sql.{AnalysisException, SparkSession, Strategy} import org.apache.spark.sql.catalyst.expressions.CurrentBatchTimestamp import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, HashPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} +import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.streaming.OutputMode /** @@ -39,7 +41,7 @@ class IncrementalExecution( val checkpointLocation: String, val runId: UUID, val currentBatchId: Long, - offsetSeqMetadata: OffsetSeqMetadata) + val offsetSeqMetadata: OffsetSeqMetadata) extends QueryExecution(sparkSession, logicalPlan) with Logging { // Modified planner with stateful operations. 
@@ -89,7 +91,7 @@ class IncrementalExecution( override def apply(plan: SparkPlan): SparkPlan = plan transform { case StateStoreSaveExec(keys, None, None, None, UnaryExecNode(agg, - StateStoreRestoreExec(keys2, None, child))) => + StateStoreRestoreExec(_, None, child))) => val aggStateInfo = nextStatefulOperationStateInfo StateStoreSaveExec( keys, @@ -117,8 +119,34 @@ class IncrementalExecution( } } - override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations + override def preparations: Seq[Rule[SparkPlan]] = + Seq(state, EnsureStatefulOpPartitioning) ++ super.preparations /** No need assert supported, as this check has already been done */ override def assertSupported(): Unit = { } } + +object EnsureStatefulOpPartitioning extends Rule[SparkPlan] { + // Needs to be transformUp to avoid extra shuffles + override def apply(plan: SparkPlan): SparkPlan = plan transformUp { + case so: StatefulOperator => + val numPartitions = plan.sqlContext.sessionState.conf.numShufflePartitions + val distributions = so.requiredChildDistribution + val children = so.children.zip(distributions).map { case (child, reqDistribution) => + val expectedPartitioning = reqDistribution match { + case AllTuples => SinglePartition + case ClusteredDistribution(keys) => HashPartitioning(keys, numPartitions) + case _ => throw new AnalysisException("Unexpected distribution expected for " + + s"Stateful Operator: $so. Expect AllTuples or ClusteredDistribution but got " + + s"$reqDistribution.") + } + if (child.outputPartitioning.guarantees(expectedPartitioning) && + child.execute().getNumPartitions == expectedPartitioning.numPartitions) { + child + } else { + ShuffleExchange(expectedPartitioning, child) + } + } + so.withNewChildren(children) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala index 5551d12fa8ad2..b84e6ce64c611 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetricsReporter.scala @@ -40,7 +40,7 @@ class MetricsReporter( // Metric names should not have . in them, so that all the metrics of a query are identified // together in Ganglia as a single metric group registerGauge("inputRate-total", () => stream.lastProgress.inputRowsPerSecond) - registerGauge("processingRate-total", () => stream.lastProgress.inputRowsPerSecond) + registerGauge("processingRate-total", () => stream.lastProgress.processedRowsPerSecond) registerGauge("latency", () => stream.lastProgress.durationMs.get("triggerExecution").longValue()) private def registerGauge[T](name: String, f: () => T)(implicit num: Numeric[T]): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 952e431fb19d3..18385f5fc1975 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -130,6 +130,16 @@ class StreamExecution( protected var offsetSeqMetadata = OffsetSeqMetadata( batchWatermarkMs = 0, batchTimestampMs = 0, sparkSession.conf) + /** + * A map of current watermarks, keyed by the position of the watermark operator in the + * physical plan. 
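+   * For example, a query that unions two streams, each with its own `withWatermark` call,
+   * has one entry here per `EventTimeWatermarkExec` node in the plan.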
+ * + * This state is 'soft state', which does not affect the correctness and semantics of watermarks + * and is not persisted across query restarts. + * The fault-tolerant watermark state is in offsetSeqMetadata. + */ + protected val watermarkMsMap: MutableMap[Int, Long] = MutableMap() + override val id: UUID = UUID.fromString(streamMetadata.id) override val runId: UUID = UUID.randomUUID @@ -560,13 +570,32 @@ class StreamExecution( } if (hasNewData) { var batchWatermarkMs = offsetSeqMetadata.batchWatermarkMs - // Update the eventTime watermark if we find one in the plan. + // Update the eventTime watermarks if we find any in the plan. if (lastExecution != null) { lastExecution.executedPlan.collect { - case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 => - logDebug(s"Observed event time stats: ${e.eventTimeStats.value}") - e.eventTimeStats.value.max - e.delayMs - }.headOption.foreach { newWatermarkMs => + case e: EventTimeWatermarkExec => e + }.zipWithIndex.foreach { + case (e, index) if e.eventTimeStats.value.count > 0 => + logDebug(s"Observed event time stats $index: ${e.eventTimeStats.value}") + val newWatermarkMs = e.eventTimeStats.value.max - e.delayMs + val prevWatermarkMs = watermarkMsMap.get(index) + if (prevWatermarkMs.isEmpty || newWatermarkMs > prevWatermarkMs.get) { + watermarkMsMap.put(index, newWatermarkMs) + } + + // Populate 0 if we haven't seen any data yet for this watermark node. + case (_, index) => + if (!watermarkMsMap.isDefinedAt(index)) { + watermarkMsMap.put(index, 0) + } + } + + // Update the global watermark to the minimum of all watermark nodes. + // This is the safest option, because only the global watermark is fault-tolerant. Making + // it the minimum of all individual watermarks guarantees it will never advance past where + // any individual watermark operator would be if it were in a plan by itself. 
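+        // For example, if one watermark operator has advanced to 15000 ms while another is
+        // still at 8000 ms, the query-level watermark stays at 8000 ms.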
+ if(!watermarkMsMap.isEmpty) { + val newWatermarkMs = watermarkMsMap.minBy(_._2)._2 if (newWatermarkMs > batchWatermarkMs) { logInfo(s"Updating eventTime watermark to: $newWatermarkMs ms") batchWatermarkMs = newWatermarkMs @@ -800,6 +829,7 @@ class StreamExecution( if (streamDeathCause != null) { throw streamDeathCause } + if (!isActive) return awaitBatchLock.lock() try { noNewData = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala index 4207013c3f75d..07e39023c8366 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingQueryListenerBus.scala @@ -40,7 +40,7 @@ class StreamingQueryListenerBus(sparkListenerBus: LiveListenerBus) import StreamingQueryListener._ - sparkListenerBus.addListener(this) + sparkListenerBus.addToSharedQueue(this) /** * RunIds of active queries whose events are supposed to be forwarded by this ListenerBus diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index e46356392c51b..d6566b8e6b54f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -200,11 +200,20 @@ case class StateStoreRestoreExec( sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) - iter.flatMap { row => - val key = getKey(row) - val savedState = store.get(key) - numOutputRows += 1 - row +: Option(savedState).toSeq + val hasInput = iter.hasNext + if (!hasInput && keyExpressions.isEmpty) { + // If our `keyExpressions` are empty, we're getting a global aggregation. In that case + // the `HashAggregateExec` will output a 0 value for the partial merge. We need to + // restore the value, so that we don't overwrite our state with a 0 value, but rather + // merge the 0 with existing state. 
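+          // Emitting the stored row(s) here lets the merging aggregate above this operator
+          // combine them with the 0-valued partial result instead of replacing the state.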
+ store.iterator().map(_.value) + } else { + iter.flatMap { row => + val key = getKey(row) + val savedState = store.get(key) + numOutputRows += 1 + row +: Option(savedState).toSeq + } } } } @@ -212,6 +221,14 @@ case class StateStoreRestoreExec( override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = { + if (keyExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(keyExpressions) :: Nil + } + } } /** @@ -351,6 +368,14 @@ case class StateStoreSaveExec( override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = { + if (keyExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(keyExpressions) :: Nil + } + } } /** Physical operator for executing streaming Deduplicate. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 47324ed9f2fb8..c6d0d86384b75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2333,6 +2333,15 @@ object functions { */ def ltrim(e: Column): Column = withExpr {StringTrimLeft(e.expr) } + /** + * Trim the specified character string from left end for the specified string column. + * @group string_funcs + * @since 2.3.0 + */ + def ltrim(e: Column, trimString: String): Column = withExpr { + StringTrimLeft(e.expr, Literal(trimString)) + } + /** * Extract a specific group matched by a Java regex, from the specified string column. * If the regex did not match, or the specified group did not match, an empty string is returned. @@ -2410,6 +2419,15 @@ object functions { */ def rtrim(e: Column): Column = withExpr { StringTrimRight(e.expr) } + /** + * Trim the specified character string from right end for the specified string column. + * @group string_funcs + * @since 2.3.0 + */ + def rtrim(e: Column, trimString: String): Column = withExpr { + StringTrimRight(e.expr, Literal(trimString)) + } + /** * Returns the soundex code for the specified expression. * @@ -2477,6 +2495,15 @@ object functions { */ def trim(e: Column): Column = withExpr { StringTrim(e.expr) } + /** + * Trim the specified character from both ends for the specified string column. + * @group string_funcs + * @since 2.3.0 + */ + def trim(e: Column, trimString: String): Column = withExpr { + StringTrim(e.expr, Literal(trimString)) + } + /** * Converts a string column to upper case. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index 7202f1222d10f..ad9db308b2627 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.fs.FsUrlStreamHandlerFactory import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.execution.CacheManager @@ -148,7 +149,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { if (SparkSession.sqlListener.get() == null) { val listener = new SQLListener(sc.conf) if (SparkSession.sqlListener.compareAndSet(null, listener)) { - sc.addSparkListener(listener) + sc.listenerBus.addToStatusQueue(listener) sc.ui.foreach(new SQLTab(listener, _)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala index 467d8d62d1b7f..7432a1538ce97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -41,4 +41,8 @@ private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect override def getJDBCType(dt: DataType): Option[JdbcType] = { dialects.flatMap(_.getJDBCType(dt)).headOption } + + override def isCascadingTruncateTable(): Option[Boolean] = { + dialects.flatMap(_.isCascadingTruncateTable()).reduceOption(_ || _) + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java new file mode 100644 index 0000000000000..50900e98dedb6 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package test.org.apache.spark.sql.sources.v2;
+
+import java.io.IOException;
+import java.util.*;
+
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.expressions.GenericRow;
+import org.apache.spark.sql.sources.Filter;
+import org.apache.spark.sql.sources.GreaterThan;
+import org.apache.spark.sql.sources.v2.DataSourceV2;
+import org.apache.spark.sql.sources.v2.DataSourceV2Options;
+import org.apache.spark.sql.sources.v2.ReadSupport;
+import org.apache.spark.sql.sources.v2.reader.*;
+import org.apache.spark.sql.types.StructType;
+
+public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport {
+
+  class Reader implements DataSourceV2Reader, SupportsPushDownRequiredColumns, SupportsPushDownFilters {
+    private StructType requiredSchema = new StructType().add("i", "int").add("j", "int");
+    private Filter[] filters = new Filter[0];
+
+    @Override
+    public StructType readSchema() {
+      return requiredSchema;
+    }
+
+    @Override
+    public void pruneColumns(StructType requiredSchema) {
+      this.requiredSchema = requiredSchema;
+    }
+
+    @Override
+    public Filter[] pushFilters(Filter[] filters) {
+      this.filters = filters;
+      return new Filter[0];
+    }
+
+    @Override
+    public List<ReadTask<Row>> createReadTasks() {
+      List<ReadTask<Row>> res = new ArrayList<>();
+
+      Integer lowerBound = null;
+      for (Filter filter : filters) {
+        if (filter instanceof GreaterThan) {
+          GreaterThan f = (GreaterThan) filter;
+          if ("i".equals(f.attribute()) && f.value() instanceof Integer) {
+            lowerBound = (Integer) f.value();
+            break;
+          }
+        }
+      }
+
+      if (lowerBound == null) {
+        res.add(new JavaAdvancedReadTask(0, 5, requiredSchema));
+        res.add(new JavaAdvancedReadTask(5, 10, requiredSchema));
+      } else if (lowerBound < 4) {
+        res.add(new JavaAdvancedReadTask(lowerBound + 1, 5, requiredSchema));
+        res.add(new JavaAdvancedReadTask(5, 10, requiredSchema));
+      } else if (lowerBound < 9) {
+        res.add(new JavaAdvancedReadTask(lowerBound + 1, 10, requiredSchema));
+      }
+
+      return res;
+    }
+  }
+
+  static class JavaAdvancedReadTask implements ReadTask<Row>, DataReader<Row> {
+    private int start;
+    private int end;
+    private StructType requiredSchema;
+
+    JavaAdvancedReadTask(int start, int end, StructType requiredSchema) {
+      this.start = start;
+      this.end = end;
+      this.requiredSchema = requiredSchema;
+    }
+
+    @Override
+    public DataReader<Row> createReader() {
+      return new JavaAdvancedReadTask(start - 1, end, requiredSchema);
+    }
+
+    @Override
+    public boolean next() {
+      start += 1;
+      return start < end;
+    }
+
+    @Override
+    public Row get() {
+      Object[] values = new Object[requiredSchema.size()];
+      for (int i = 0; i < values.length; i++) {
+        if ("i".equals(requiredSchema.apply(i).name())) {
+          values[i] = start;
+        } else if ("j".equals(requiredSchema.apply(i).name())) {
+          values[i] = -start;
+        }
+      }
+      return new GenericRow(values);
+    }
+
+    @Override
+    public void close() throws IOException {
+
+    }
+  }
+
+
+  @Override
+  public DataSourceV2Reader createReader(DataSourceV2Options options) {
+    return new Reader();
+  }
+}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java
new file mode 100644
index 0000000000000..a174bd8092cbd
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.sources.v2;
+
+import java.util.List;
+
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.sources.v2.DataSourceV2;
+import org.apache.spark.sql.sources.v2.DataSourceV2Options;
+import org.apache.spark.sql.sources.v2.ReadSupportWithSchema;
+import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader;
+import org.apache.spark.sql.sources.v2.reader.ReadTask;
+import org.apache.spark.sql.types.StructType;
+
+public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema {
+
+  class Reader implements DataSourceV2Reader {
+    private final StructType schema;
+
+    Reader(StructType schema) {
+      this.schema = schema;
+    }
+
+    @Override
+    public StructType readSchema() {
+      return schema;
+    }
+
+    @Override
+    public List<ReadTask<Row>> createReadTasks() {
+      return java.util.Collections.emptyList();
+    }
+  }
+
+  @Override
+  public DataSourceV2Reader createReader(StructType schema, DataSourceV2Options options) {
+    return new Reader(schema);
+  }
+}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
new file mode 100644
index 0000000000000..08469f14c257a
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.sources.v2;
+
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.expressions.GenericRow;
+import org.apache.spark.sql.sources.v2.DataSourceV2;
+import org.apache.spark.sql.sources.v2.DataSourceV2Options;
+import org.apache.spark.sql.sources.v2.ReadSupport;
+import org.apache.spark.sql.sources.v2.reader.DataReader;
+import org.apache.spark.sql.sources.v2.reader.ReadTask;
+import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader;
+import org.apache.spark.sql.types.StructType;
+
+public class JavaSimpleDataSourceV2 implements DataSourceV2, ReadSupport {
+
+  class Reader implements DataSourceV2Reader {
+    private final StructType schema = new StructType().add("i", "int").add("j", "int");
+
+    @Override
+    public StructType readSchema() {
+      return schema;
+    }
+
+    @Override
+    public List<ReadTask<Row>> createReadTasks() {
+      return java.util.Arrays.asList(
+        new JavaSimpleReadTask(0, 5),
+        new JavaSimpleReadTask(5, 10));
+    }
+  }
+
+  static class JavaSimpleReadTask implements ReadTask<Row>, DataReader<Row> {
+    private int start;
+    private int end;
+
+    JavaSimpleReadTask(int start, int end) {
+      this.start = start;
+      this.end = end;
+    }
+
+    @Override
+    public DataReader<Row> createReader() {
+      return new JavaSimpleReadTask(start - 1, end);
+    }
+
+    @Override
+    public boolean next() {
+      start += 1;
+      return start < end;
+    }
+
+    @Override
+    public Row get() {
+      return new GenericRow(new Object[] {start, -start});
+    }
+
+    @Override
+    public void close() throws IOException {
+
+    }
+  }
+
+  @Override
+  public DataSourceV2Reader createReader(DataSourceV2Options options) {
+    return new Reader();
+  }
+}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java
new file mode 100644
index 0000000000000..9efe7c791a936
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaUnsafeRowDataSourceV2.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.sources.v2;
+
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.sources.v2.DataSourceV2;
+import org.apache.spark.sql.sources.v2.DataSourceV2Options;
+import org.apache.spark.sql.sources.v2.ReadSupport;
+import org.apache.spark.sql.sources.v2.reader.*;
+import org.apache.spark.sql.types.StructType;
+
+public class JavaUnsafeRowDataSourceV2 implements DataSourceV2, ReadSupport {
+
+  class Reader implements DataSourceV2Reader, SupportsScanUnsafeRow {
+    private final StructType schema = new StructType().add("i", "int").add("j", "int");
+
+    @Override
+    public StructType readSchema() {
+      return schema;
+    }
+
+    @Override
+    public List<ReadTask<UnsafeRow>> createUnsafeRowReadTasks() {
+      return java.util.Arrays.asList(
+        new JavaUnsafeRowReadTask(0, 5),
+        new JavaUnsafeRowReadTask(5, 10));
+    }
+  }
+
+  static class JavaUnsafeRowReadTask implements ReadTask<UnsafeRow>, DataReader<UnsafeRow> {
+    private int start;
+    private int end;
+    private UnsafeRow row;
+
+    JavaUnsafeRowReadTask(int start, int end) {
+      this.start = start;
+      this.end = end;
+      this.row = new UnsafeRow(2);
+      row.pointTo(new byte[8 * 3], 8 * 3);
+    }
+
+    @Override
+    public DataReader<UnsafeRow> createReader() {
+      return new JavaUnsafeRowReadTask(start - 1, end);
+    }
+
+    @Override
+    public boolean next() {
+      start += 1;
+      return start < end;
+    }
+
+    @Override
+    public UnsafeRow get() {
+      row.setInt(0, start);
+      row.setInt(1, -start);
+      return row;
+    }
+
+    @Override
+    public void close() throws IOException {
+
+    }
+  }
+
+  @Override
+  public DataSourceV2Reader createReader(DataSourceV2Options options) {
+    return new Reader();
+  }
+}
diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
index dcced79d315f3..d9dc728a18e8d 100644
--- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
@@ -26,13 +26,13 @@ Extended Usage:
       {"time":"26/08/2015"}
       > SELECT to_json(array(named_struct('a', 1, 'b', 2));
        [{"a":1,"b":2}]
-      > SELECT to_json(map('a',named_struct('b',1)));
+      > SELECT to_json(map('a', named_struct('b', 1)));
       {"a":{"b":1}}
-      > SELECT to_json(map(named_struct('a',1),named_struct('b',2)));
+      > SELECT to_json(map(named_struct('a', 1),named_struct('b', 2)));
       {"[1]":{"b":2}}
-      > SELECT to_json(map('a',1));
+      > SELECT to_json(map('a', 1));
       {"a":1}
-      > SELECT to_json(array((map('a',1))));
+      > SELECT to_json(array((map('a', 1))));
       [{"a":1}]
 
   Since: 2.2.0
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
index 9e459ed00c8d5..2fc92f4aff92e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
@@ -261,6 +261,10 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
         assert(fetched1.get.sizeInBytes == 0)
         assert(fetched1.get.colStats.size == 2)
 
+        // table lookup will make the table cached
+        spark.table(table)
+        assert(isTableInCatalogCache(table))
+
         // insert into command
         sql(s"INSERT INTO TABLE $table SELECT 1, 'abc'")
         if (autoUpdate) {
@@ -270,9 +274,78 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
         } else {
          checkTableStats(table, hasSizeInBytes = false, expectedRowCounts = None)
        }
+ + // check that tableRelationCache inside the catalog was invalidated after insert + assert(!isTableInCatalogCache(table)) + } + } + } + } + + test("invalidation of tableRelationCache after inserts") { + val table = "invalidate_catalog_cache_table" + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTable(table) { + spark.range(100).write.saveAsTable(table) + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") + spark.table(table) + val initialSizeInBytes = getTableFromCatalogCache(table).stats.sizeInBytes + spark.range(100).write.mode(SaveMode.Append).saveAsTable(table) + spark.table(table) + assert(getTableFromCatalogCache(table).stats.sizeInBytes == 2 * initialSizeInBytes) + } + } + } + } + + test("invalidation of tableRelationCache after table truncation") { + val table = "invalidate_catalog_cache_table" + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTable(table) { + spark.range(100).write.saveAsTable(table) + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") + spark.table(table) + sql(s"TRUNCATE TABLE $table") + spark.table(table) + assert(getTableFromCatalogCache(table).stats.sizeInBytes == 0) } } } } + test("invalidation of tableRelationCache after alter table add partition") { + val table = "invalidate_catalog_cache_table" + Seq(false, true).foreach { autoUpdate => + withSQLConf(SQLConf.AUTO_UPDATE_SIZE.key -> autoUpdate.toString) { + withTempDir { dir => + withTable(table) { + val path = dir.getCanonicalPath + sql(s""" + |CREATE TABLE $table (col1 int, col2 int) + |USING PARQUET + |PARTITIONED BY (col2) + |LOCATION '${dir.toURI}'""".stripMargin) + sql(s"ANALYZE TABLE $table COMPUTE STATISTICS") + spark.table(table) + assert(getTableFromCatalogCache(table).stats.sizeInBytes == 0) + spark.catalog.recoverPartitions(table) + val df = Seq((1, 2), (1, 2)).toDF("col2", "col1") + df.write.parquet(s"$path/col2=1") + sql(s"ALTER TABLE $table ADD PARTITION (col2=1) LOCATION '${dir.toURI}'") + spark.table(table) + val cachedTable = getTableFromCatalogCache(table) + val cachedTableSizeInBytes = cachedTable.stats.sizeInBytes + val defaultSizeInBytes = conf.defaultSizeInBytes + if (autoUpdate) { + assert(cachedTableSizeInBytes != defaultSizeInBytes && cachedTableSizeInBytes > 0) + } else { + assert(cachedTableSizeInBytes == defaultSizeInBytes) + } + } + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala index 5916cd76b8789..a2f63edd786bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala @@ -23,9 +23,9 @@ import java.sql.{Date, Timestamp} import scala.collection.mutable import scala.util.Random -import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, HiveTableRelation} -import org.apache.spark.sql.catalyst.plans.logical.ColumnStat +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.internal.StaticSQLConf @@ -85,6 +85,16 @@ abstract class StatisticsCollectionTestBase extends 
QueryTest with SQLTestUtils spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)) } + def getTableFromCatalogCache(tableName: String): LogicalPlan = { + val catalog = spark.sessionState.catalog + val qualifiedTableName = QualifiedTableName(catalog.getCurrentDatabase, tableName) + catalog.getCachedTable(qualifiedTableName) + } + + def isTableInCatalogCache(tableName: String): Boolean = { + getTableFromCatalogCache(tableName) != null + } + def getCatalogStatistics(tableName: String): CatalogStatistics = { getCatalogTable(tableName).stats.get } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index a12efc835691b..3d76b9ac33e57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -161,12 +161,24 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { } test("string trim functions") { - val df = Seq((" example ", "")).toDF("a", "b") + val df = Seq((" example ", "", "example")).toDF("a", "b", "c") checkAnswer( df.select(ltrim($"a"), rtrim($"a"), trim($"a")), Row("example ", " example", "example")) + checkAnswer( + df.select(ltrim($"c", "e"), rtrim($"c", "e"), trim($"c", "e")), + Row("xample", "exampl", "xampl")) + + checkAnswer( + df.select(ltrim($"c", "xe"), rtrim($"c", "emlp"), trim($"c", "elxp")), + Row("ample", "exa", "am")) + + checkAnswer( + df.select(trim($"c", "xyz")), + Row("example")) + checkAnswer( df.selectExpr("ltrim(a)", "rtrim(a)", "trim(a)"), Row("example ", " example", "example")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala new file mode 100644 index 0000000000000..c2e62b987e0cc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.json4s.jackson.JsonMethods.parse + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart +import org.apache.spark.util.JsonProtocol + +class SQLJsonProtocolSuite extends SparkFunSuite { + + test("SparkPlanGraph backward compatibility: metadata") { + val SQLExecutionStartJsonString = + """ + |{ + | "Event":"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart", + | "executionId":0, + | "description":"test desc", + | "details":"test detail", + | "physicalPlanDescription":"test plan", + | "sparkPlanInfo": { + | "nodeName":"TestNode", + | "simpleString":"test string", + | "children":[], + | "metadata":{}, + | "metrics":[] + | }, + | "time":0 + |} + """.stripMargin + val reconstructedEvent = JsonProtocol.sparkEventFromJson(parse(SQLExecutionStartJsonString)) + val expectedEvent = SparkListenerSQLExecutionStart(0, "test desc", "test detail", "test plan", + new SparkPlanInfo("TestNode", "test string", Nil, Nil), 0) + assert(reconstructedEvent == expectedEvent) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala index 1255f262bce94..7d277c1ffaffe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala @@ -30,57 +30,38 @@ class JdbcUtilsSuite extends SparkFunSuite { val caseInsensitive = org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution test("Parse user specified column types") { - assert( - JdbcUtils.getCustomSchema(tableSchema, "C1 DATE, C2 STRING", caseInsensitive) === - StructType(Seq(StructField("C1", DateType, true), StructField("C2", StringType, true)))) - assert(JdbcUtils.getCustomSchema(tableSchema, "C1 DATE, C2 STRING", caseSensitive) === - StructType(Seq(StructField("C1", DateType, true), StructField("C2", StringType, true)))) + assert(JdbcUtils.getCustomSchema(tableSchema, null, caseInsensitive) === tableSchema) + assert(JdbcUtils.getCustomSchema(tableSchema, "", caseInsensitive) === tableSchema) + + assert(JdbcUtils.getCustomSchema(tableSchema, "c1 DATE", caseInsensitive) === + StructType(Seq(StructField("C1", DateType, false), StructField("C2", IntegerType, false)))) + assert(JdbcUtils.getCustomSchema(tableSchema, "c1 DATE", caseSensitive) === + StructType(Seq(StructField("C1", StringType, false), StructField("C2", IntegerType, false)))) + assert( JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, C2 STRING", caseInsensitive) === - StructType(Seq(StructField("c1", DateType, true), StructField("C2", StringType, true)))) - assert(JdbcUtils.getCustomSchema( - tableSchema, "c1 DECIMAL(38, 0), C2 STRING", caseInsensitive) === - StructType(Seq(StructField("c1", DecimalType(38, 0), true), - StructField("C2", StringType, true)))) + StructType(Seq(StructField("C1", DateType, false), StructField("C2", StringType, false)))) + assert(JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, C2 STRING", caseSensitive) === + StructType(Seq(StructField("C1", StringType, false), StructField("C2", StringType, false)))) // Throw AnalysisException val duplicate = intercept[AnalysisException]{ JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, c1 STRING", caseInsensitive) === - StructType(Seq(StructField("c1", DateType, true), StructField("c1", StringType, true))) + 
StructType(Seq(StructField("c1", DateType, false), StructField("c1", StringType, false))) } assert(duplicate.getMessage.contains( "Found duplicate column(s) in the customSchema option value")) - val allColumns = intercept[AnalysisException]{ - JdbcUtils.getCustomSchema(tableSchema, "C1 STRING", caseSensitive) === - StructType(Seq(StructField("C1", DateType, true))) - } - assert(allColumns.getMessage.contains("Please provide all the columns,")) - - val caseSensitiveColumnNotFound = intercept[AnalysisException]{ - JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, C2 STRING", caseSensitive) === - StructType(Seq(StructField("c1", DateType, true), StructField("C2", StringType, true))) - } - assert(caseSensitiveColumnNotFound.getMessage.contains( - "Please provide all the columns, all columns are: C1,C2;")) - - val caseInsensitiveColumnNotFound = intercept[AnalysisException]{ - JdbcUtils.getCustomSchema(tableSchema, "c3 DATE, C2 STRING", caseInsensitive) === - StructType(Seq(StructField("c3", DateType, true), StructField("C2", StringType, true))) - } - assert(caseInsensitiveColumnNotFound.getMessage.contains( - "Please provide all the columns, all columns are: C1,C2;")) - // Throw ParseException val dataTypeNotSupported = intercept[ParseException]{ JdbcUtils.getCustomSchema(tableSchema, "c3 DATEE, C2 STRING", caseInsensitive) === - StructType(Seq(StructField("c3", DateType, true), StructField("C2", StringType, true))) + StructType(Seq(StructField("c3", DateType, false), StructField("C2", StringType, false))) } assert(dataTypeNotSupported.getMessage.contains("DataType datee is not supported")) val mismatchedInput = intercept[ParseException]{ JdbcUtils.getCustomSchema(tableSchema, "c3 DATE. C2 STRING", caseInsensitive) === - StructType(Seq(StructField("c3", DateType, true), StructField("C2", StringType, true))) + StructType(Seq(StructField("c3", DateType, false), StructField("C2", StringType, false))) } assert(mismatchedInput.getMessage.contains("mismatched input '.' expecting")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala new file mode 100644 index 0000000000000..998067a47033b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.vectorized + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { + + var testVector: WritableColumnVector = _ + + private def allocate(capacity: Int, dt: DataType): WritableColumnVector = { + new OnHeapColumnVector(capacity, dt) + } + + override def afterEach(): Unit = { + testVector.close() + } + + test("boolean") { + testVector = allocate(10, BooleanType) + (0 until 10).foreach { i => + testVector.appendBoolean(i % 2 == 0) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, BooleanType) === (i % 2 == 0)) + } + } + + test("byte") { + testVector = allocate(10, ByteType) + (0 until 10).foreach { i => + testVector.appendByte(i.toByte) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, ByteType) === (i.toByte)) + } + } + + test("short") { + testVector = allocate(10, ShortType) + (0 until 10).foreach { i => + testVector.appendShort(i.toShort) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, ShortType) === (i.toShort)) + } + } + + test("int") { + testVector = allocate(10, IntegerType) + (0 until 10).foreach { i => + testVector.appendInt(i) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, IntegerType) === i) + } + } + + test("long") { + testVector = allocate(10, LongType) + (0 until 10).foreach { i => + testVector.appendLong(i) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, LongType) === i) + } + } + + test("float") { + testVector = allocate(10, FloatType) + (0 until 10).foreach { i => + testVector.appendFloat(i.toFloat) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, FloatType) === i.toFloat) + } + } + + test("double") { + testVector = allocate(10, DoubleType) + (0 until 10).foreach { i => + testVector.appendDouble(i.toDouble) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, DoubleType) === i.toDouble) + } + } + + test("string") { + testVector = allocate(10, StringType) + (0 until 10).map { i => + val utf8 = s"str$i".getBytes("utf8") + testVector.appendByteArray(utf8, 0, utf8.length) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + assert(array.get(i, StringType) === UTF8String.fromString(s"str$i")) + } + } + + test("binary") { + testVector = allocate(10, BinaryType) + (0 until 10).map { i => + val utf8 = s"str$i".getBytes("utf8") + testVector.appendByteArray(utf8, 0, utf8.length) + } + + val array = new ColumnVector.Array(testVector) + + (0 until 10).foreach { i => + val utf8 = s"str$i".getBytes("utf8") + assert(array.get(i, BinaryType) === utf8) + } + } + + test("array") { + val arrayType = ArrayType(IntegerType, true) + testVector = allocate(10, arrayType) + + val data = testVector.arrayData() + var i = 0 + while (i < 6) { + data.putInt(i, i) + i += 1 + } + + // Populate it with arrays [0], [1, 2], [], [3, 4, 5] + testVector.putArray(0, 0, 1) + testVector.putArray(1, 1, 2) + testVector.putArray(2, 3, 0) + testVector.putArray(3, 3, 3) + + val array = new 
ColumnVector.Array(testVector) + + assert(array.get(0, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(0)) + assert(array.get(1, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(1, 2)) + assert(array.get(2, arrayType).asInstanceOf[ArrayData].toIntArray() === Array.empty[Int]) + assert(array.get(3, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(3, 4, 5)) + } + + test("struct") { + val schema = new StructType().add("int", IntegerType).add("double", DoubleType) + testVector = allocate(10, schema) + val c1 = testVector.getChildColumn(0) + val c2 = testVector.getChildColumn(1) + c1.putInt(0, 123) + c2.putDouble(0, 3.45) + c1.putInt(1, 456) + c2.putDouble(1, 5.67) + + val array = new ColumnVector.Array(testVector) + + assert(array.get(0, schema).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 123) + assert(array.get(0, schema).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 3.45) + assert(array.get(1, schema).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 456) + assert(array.get(1, schema).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 5.67) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 40179261ab200..fd12bb9e530b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.command.ExplainCommand @@ -739,11 +740,13 @@ class JDBCSuite extends SparkFunSuite } else { None } + override def isCascadingTruncateTable(): Option[Boolean] = Some(true) }, testH2Dialect)) assert(agg.canHandle("jdbc:h2:xxx")) assert(!agg.canHandle("jdbc:h2")) assert(agg.getCatalystType(0, "", 1, null) === Some(LongType)) assert(agg.getCatalystType(1, "", 1, null) === Some(StringType)) + assert(agg.isCascadingTruncateTable() === Some(true)) } test("DB2Dialect type mapping") { @@ -970,30 +973,28 @@ class JDBCSuite extends SparkFunSuite test("jdbc API support custom schema") { val parts = Array[String]("THEID < 2", "THEID >= 2") + val customSchema = "NAME STRING, THEID INT" val props = new Properties() - props.put("customSchema", "NAME STRING, THEID BIGINT") - val schema = StructType(Seq( - StructField("NAME", StringType, true), StructField("THEID", LongType, true))) + props.put("customSchema", customSchema) val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, props) assert(df.schema.size === 2) - assert(df.schema === schema) + assert(df.schema === CatalystSqlParser.parseTableSchema(customSchema)) assert(df.count() === 3) } test("jdbc API custom schema DDL-like strings.") { withTempView("people_view") { + val customSchema = "NAME STRING, THEID INT" sql( s""" |CREATE TEMPORARY VIEW people_view |USING org.apache.spark.sql.jdbc |OPTIONS (uRl '$url', DbTaBlE 'TEST.PEOPLE', User 'testUser', PassWord 'testPass', - |customSchema 'NAME STRING, THEID INT') + |customSchema '$customSchema') """.stripMargin.replaceAll("\n", " ")) - val schema = StructType( - Seq(StructField("NAME", StringType, true), StructField("THEID", IntegerType, true))) val df = 
sql("select * from people_view") - assert(df.schema.size === 2) - assert(df.schema === schema) + assert(df.schema.length === 2) + assert(df.schema === CatalystSqlParser.parseTableSchema(customSchema)) assert(df.count() === 3) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala new file mode 100644 index 0000000000000..933f4075bcc8a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2OptionsSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2 + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkFunSuite + +/** + * A simple test suite to verify `DataSourceV2Options`. + */ +class DataSourceV2OptionsSuite extends SparkFunSuite { + + test("key is case-insensitive") { + val options = new DataSourceV2Options(Map("foo" -> "bar").asJava) + assert(options.get("foo").get() == "bar") + assert(options.get("FoO").get() == "bar") + assert(!options.get("abc").isPresent) + } + + test("value is case-sensitive") { + val options = new DataSourceV2Options(Map("foo" -> "bAr").asJava) + assert(options.get("foo").get == "bAr") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala new file mode 100644 index 0000000000000..9ce93d7ae926c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources.v2 + +import java.util.{ArrayList, List => JList} + +import test.org.apache.spark.sql.sources.v2._ + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.sources.{Filter, GreaterThan} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StructType + +class DataSourceV2Suite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("simplest implementation") { + Seq(classOf[SimpleDataSourceV2], classOf[JavaSimpleDataSourceV2]).foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, (0 until 10).map(i => Row(i, -i))) + checkAnswer(df.select('j), (0 until 10).map(i => Row(-i))) + checkAnswer(df.filter('i > 5), (6 until 10).map(i => Row(i, -i))) + } + } + } + + test("advanced implementation") { + Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, (0 until 10).map(i => Row(i, -i))) + checkAnswer(df.select('j), (0 until 10).map(i => Row(-i))) + checkAnswer(df.filter('i > 3), (4 until 10).map(i => Row(i, -i))) + checkAnswer(df.select('j).filter('i > 6), (7 until 10).map(i => Row(-i))) + checkAnswer(df.select('i).filter('i > 10), Nil) + } + } + } + + test("unsafe row implementation") { + Seq(classOf[UnsafeRowDataSourceV2], classOf[JavaUnsafeRowDataSourceV2]).foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, (0 until 10).map(i => Row(i, -i))) + checkAnswer(df.select('j), (0 until 10).map(i => Row(-i))) + checkAnswer(df.filter('i > 5), (6 until 10).map(i => Row(i, -i))) + } + } + } + + test("schema required data source") { + Seq(classOf[SchemaRequiredDataSource], classOf[JavaSchemaRequiredDataSource]).foreach { cls => + withClue(cls.getName) { + val e = intercept[AnalysisException](spark.read.format(cls.getName).load()) + assert(e.message.contains("A schema needs to be specified")) + + val schema = new StructType().add("i", "int").add("s", "string") + val df = spark.read.format(cls.getName).schema(schema).load() + + assert(df.schema == schema) + assert(df.collect().isEmpty) + } + } + } +} + +class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceV2Reader { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + + override def createReadTasks(): JList[ReadTask[Row]] = { + java.util.Arrays.asList(new SimpleReadTask(0, 5), new SimpleReadTask(5, 10)) + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader +} + +class SimpleReadTask(start: Int, end: Int) extends ReadTask[Row] with DataReader[Row] { + private var current = start - 1 + + override def createReader(): DataReader[Row] = new SimpleReadTask(start, end) + + override def next(): Boolean = { + current += 1 + current < end + } + + override def get(): Row = Row(current, -current) + + override def close(): Unit = {} +} + + + +class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceV2Reader + with SupportsPushDownRequiredColumns with SupportsPushDownFilters { + + var requiredSchema = new StructType().add("i", "int").add("j", "int") + var filters = Array.empty[Filter] + + override def 
pruneColumns(requiredSchema: StructType): Unit = { + this.requiredSchema = requiredSchema + } + + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + this.filters = filters + Array.empty + } + + override def readSchema(): StructType = { + requiredSchema + } + + override def createReadTasks(): JList[ReadTask[Row]] = { + val lowerBound = filters.collect { + case GreaterThan("i", v: Int) => v + }.headOption + + val res = new ArrayList[ReadTask[Row]] + + if (lowerBound.isEmpty) { + res.add(new AdvancedReadTask(0, 5, requiredSchema)) + res.add(new AdvancedReadTask(5, 10, requiredSchema)) + } else if (lowerBound.get < 4) { + res.add(new AdvancedReadTask(lowerBound.get + 1, 5, requiredSchema)) + res.add(new AdvancedReadTask(5, 10, requiredSchema)) + } else if (lowerBound.get < 9) { + res.add(new AdvancedReadTask(lowerBound.get + 1, 10, requiredSchema)) + } + + res + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader +} + +class AdvancedReadTask(start: Int, end: Int, requiredSchema: StructType) + extends ReadTask[Row] with DataReader[Row] { + + private var current = start - 1 + + override def createReader(): DataReader[Row] = new AdvancedReadTask(start, end, requiredSchema) + + override def close(): Unit = {} + + override def next(): Boolean = { + current += 1 + current < end + } + + override def get(): Row = { + val values = requiredSchema.map(_.name).map { + case "i" => current + case "j" => -current + } + Row.fromSeq(values) + } +} + + +class UnsafeRowDataSourceV2 extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceV2Reader with SupportsScanUnsafeRow { + override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + + override def createUnsafeRowReadTasks(): JList[ReadTask[UnsafeRow]] = { + java.util.Arrays.asList(new UnsafeRowReadTask(0, 5), new UnsafeRowReadTask(5, 10)) + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader +} + +class UnsafeRowReadTask(start: Int, end: Int) + extends ReadTask[UnsafeRow] with DataReader[UnsafeRow] { + + private val row = new UnsafeRow(2) + row.pointTo(new Array[Byte](8 * 3), 8 * 3) + + private var current = start - 1 + + override def createReader(): DataReader[UnsafeRow] = new UnsafeRowReadTask(start, end) + + override def next(): Boolean = { + current += 1 + current < end + } + override def get(): UnsafeRow = { + row.setInt(0, current) + row.setInt(1, -current) + row + } + + override def close(): Unit = {} +} + +class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema { + + class Reader(val readSchema: StructType) extends DataSourceV2Reader { + override def createReadTasks(): JList[ReadTask[Row]] = + java.util.Collections.emptyList() + } + + override def createReader(schema: StructType, options: DataSourceV2Options): DataSourceV2Reader = + new Reader(schema) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala new file mode 100644 index 0000000000000..66c0263e872b9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import java.util.UUID + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode} +import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchange} +import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, StatefulOperator, StatefulOperatorStateInfo} +import org.apache.spark.sql.test.SharedSQLContext + +class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLContext { + + import testImplicits._ + super.beforeAll() + + private val baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char") + + testEnsureStatefulOpPartitioning( + "ClusteredDistribution generates Exchange with HashPartitioning", + baseDf.queryExecution.sparkPlan, + requiredDistribution = keys => ClusteredDistribution(keys), + expectedPartitioning = + keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions), + expectShuffle = true) + + testEnsureStatefulOpPartitioning( + "ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning", + baseDf.coalesce(1).queryExecution.sparkPlan, + requiredDistribution = keys => ClusteredDistribution(keys), + expectedPartitioning = + keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions), + expectShuffle = true) + + testEnsureStatefulOpPartitioning( + "AllTuples generates Exchange with SinglePartition", + baseDf.queryExecution.sparkPlan, + requiredDistribution = _ => AllTuples, + expectedPartitioning = _ => SinglePartition, + expectShuffle = true) + + testEnsureStatefulOpPartitioning( + "AllTuples with coalesce(1) doesn't need Exchange", + baseDf.coalesce(1).queryExecution.sparkPlan, + requiredDistribution = _ => AllTuples, + expectedPartitioning = _ => SinglePartition, + expectShuffle = false) + + /** + * For `StatefulOperator` with the given `requiredChildDistribution`, and child SparkPlan + * `inputPlan`, ensures that the incremental planner adds exchanges, if required, in order to + * ensure the expected partitioning. 
+ */ + private def testEnsureStatefulOpPartitioning( + testName: String, + inputPlan: SparkPlan, + requiredDistribution: Seq[Attribute] => Distribution, + expectedPartitioning: Seq[Attribute] => Partitioning, + expectShuffle: Boolean): Unit = { + test(testName) { + val operator = TestStatefulOperator(inputPlan, requiredDistribution(inputPlan.output.take(1))) + val executed = executePlan(operator, OutputMode.Complete()) + if (expectShuffle) { + val exchange = executed.children.find(_.isInstanceOf[Exchange]) + if (exchange.isEmpty) { + fail(s"Was expecting an exchange but didn't get one in:\n$executed") + } + assert(exchange.get === + ShuffleExchange(expectedPartitioning(inputPlan.output.take(1)), inputPlan), + s"Exchange didn't have expected properties:\n${exchange.get}") + } else { + assert(!executed.children.exists(_.isInstanceOf[Exchange]), + s"Unexpected exchange found in:\n$executed") + } + } + } + + /** Executes a SparkPlan using the IncrementalPlanner used for Structured Streaming. */ + private def executePlan( + p: SparkPlan, + outputMode: OutputMode = OutputMode.Append()): SparkPlan = { + val execution = new IncrementalExecution( + spark, + null, + OutputMode.Complete(), + "chk", + UUID.randomUUID(), + 0L, + OffsetSeqMetadata()) { + override lazy val sparkPlan: SparkPlan = p transform { + case plan: SparkPlan => + val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap + plan transformExpressions { + case UnresolvedAttribute(Seq(u)) => + inputMap.getOrElse(u, + sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) + } + } + } + execution.executedPlan + } +} + +/** Used to emulate a `StatefulOperator` with the given requiredDistribution. */ +case class TestStatefulOperator( + child: SparkPlan, + requiredDist: Distribution) extends UnaryExecNode with StatefulOperator { + override def output: Seq[Attribute] = child.output + override def doExecute(): RDD[InternalRow] = child.execute() + override def requiredChildDistribution: Seq[Distribution] = requiredDist :: Nil + override def stateInfo: Option[StatefulOperatorStateInfo] = None +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index 4f19fa0bb4a97..f3e8cf950a5a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -300,6 +300,84 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche ) } + test("watermark with 2 streams") { + import org.apache.spark.sql.functions.sum + val first = MemoryStream[Int] + + val firstDf = first.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "10 seconds") + .select('value) + + val second = MemoryStream[Int] + + val secondDf = second.toDF() + .withColumn("eventTime", $"value".cast("timestamp")) + .withWatermark("eventTime", "5 seconds") + .select('value) + + withTempDir { checkpointDir => + val unionWriter = firstDf.union(secondDf).agg(sum('value)) + .writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .format("memory") + .outputMode("complete") + .queryName("test") + + val union = unionWriter.start() + + def getWatermarkAfterData( + firstData: Seq[Int] = Seq.empty, + secondData: Seq[Int] = Seq.empty, + query: StreamingQuery = union): Long = { + if (firstData.nonEmpty) first.addData(firstData) + if 
(secondData.nonEmpty) second.addData(secondData) + query.processAllAvailable() + // add a dummy batch so lastExecution has the new watermark + first.addData(0) + query.processAllAvailable() + // get last watermark + val lastExecution = query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastExecution + lastExecution.offsetSeqMetadata.batchWatermarkMs + } + + // Global watermark starts at 0 until we get data from both sides + assert(getWatermarkAfterData(firstData = Seq(11)) == 0) + assert(getWatermarkAfterData(secondData = Seq(6)) == 1000) + // Global watermark stays at left watermark 1 when right watermark moves to 2 + assert(getWatermarkAfterData(secondData = Seq(8)) == 1000) + // Global watermark switches to right side value 2 when left watermark goes higher + assert(getWatermarkAfterData(firstData = Seq(21)) == 3000) + // Global watermark goes back to left + assert(getWatermarkAfterData(secondData = Seq(17, 28, 39)) == 11000) + // Global watermark stays on left as long as it's below right + assert(getWatermarkAfterData(firstData = Seq(31)) == 21000) + assert(getWatermarkAfterData(firstData = Seq(41)) == 31000) + // Global watermark switches back to right again + assert(getWatermarkAfterData(firstData = Seq(51)) == 34000) + + // Global watermark is updated correctly with simultaneous data from both sides + assert(getWatermarkAfterData(firstData = Seq(100), secondData = Seq(100)) == 90000) + assert(getWatermarkAfterData(firstData = Seq(120), secondData = Seq(110)) == 105000) + assert(getWatermarkAfterData(firstData = Seq(130), secondData = Seq(125)) == 120000) + + // Global watermark doesn't decrement with simultaneous data + assert(getWatermarkAfterData(firstData = Seq(100), secondData = Seq(100)) == 120000) + assert(getWatermarkAfterData(firstData = Seq(140), secondData = Seq(100)) == 120000) + assert(getWatermarkAfterData(firstData = Seq(100), secondData = Seq(135)) == 130000) + + // Global watermark recovers after restart, but left side watermark ahead of it does not. + assert(getWatermarkAfterData(firstData = Seq(200), secondData = Seq(190)) == 185000) + union.stop() + val union2 = unionWriter.start() + assert(getWatermarkAfterData(query = union2) == 185000) + // Even though the left side was ahead of 185000 in the last execution, the watermark won't + // increment until it gets past it in this execution. + assert(getWatermarkAfterData(secondData = Seq(200), query = union2) == 185000) + assert(getWatermarkAfterData(firstData = Seq(200), query = union2) == 190000) + } + } + test("complete mode") { val inputData = MemoryStream[Int] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 4f8764060d922..70b39b934071a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -167,7 +167,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be case class StartStream( trigger: Trigger = Trigger.ProcessingTime(0), triggerClock: Clock = new SystemClock, - additionalConfs: Map[String, String] = Map.empty) + additionalConfs: Map[String, String] = Map.empty, + checkpointLocation: String = null) extends StreamAction /** Advance the trigger clock's time manually. 
*/ @@ -349,13 +350,14 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be """.stripMargin) } - val metadataRoot = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath var manualClockExpectedTime = -1L + val defaultCheckpointLocation = + Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath try { startedTest.foreach { action => logInfo(s"Processing test stream action: $action") action match { - case StartStream(trigger, triggerClock, additionalConfs) => + case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) => verify(currentStream == null, "stream already running") verify(triggerClock.isInstanceOf[SystemClock] || triggerClock.isInstanceOf[StreamManualClock], @@ -363,6 +365,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be if (triggerClock.isInstanceOf[StreamManualClock]) { manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis() } + val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation) additionalConfs.foreach(pair => { val value = @@ -479,7 +482,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be verify(currentStream != null || lastStream != null, "cannot assert when no stream has been started") val streamToAssert = Option(currentStream).getOrElse(lastStream) - verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") + try { + verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") + } catch { + case NonFatal(e) => + failTest(s"Assert on query failed: ${a.message}", e) + } case a: Assert => val streamToAssert = Option(currentStream).getOrElse(lastStream) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index e0979ce296c3a..995cea3b37d4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -22,20 +22,24 @@ import java.util.{Locale, TimeZone} import org.scalatest.Assertions import org.scalatest.BeforeAndAfterAll -import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, DataFrame} +import org.apache.spark.{SparkEnv, SparkException} +import org.apache.spark.rdd.BlockRDD +import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.exchange.Exchange import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming.OutputMode._ -import org.apache.spark.sql.streaming.util.StreamManualClock +import org.apache.spark.sql.streaming.util.{MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType +import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} -object FailureSinglton { +object FailureSingleton { var firstTime = true } @@ -226,12 +230,12 @@ class StreamingAggregationSuite 
extends StateStoreMetricsTest testQuietly("midbatch failure") { val inputData = MemoryStream[Int] - FailureSinglton.firstTime = true + FailureSingleton.firstTime = true val aggregated = inputData.toDS() .map { i => - if (i == 4 && FailureSinglton.firstTime) { - FailureSinglton.firstTime = false + if (i == 4 && FailureSingleton.firstTime) { + FailureSingleton.firstTime = false sys.error("injected failure") } @@ -381,4 +385,180 @@ class StreamingAggregationSuite extends StateStoreMetricsTest AddData(streamInput, 0, 1, 2, 3), CheckLastBatch((0, 0, 2), (1, 1, 3))) } + + /** + * This method verifies certain properties in the SparkPlan of a streaming aggregation. + * First of all, it checks that the child of a `StateStoreRestoreExec` creates the desired + * data distribution, where the child could be an Exchange, or a `HashAggregateExec` which already + * provides the expected data distribution. + * + * The second thing it checks is that the child provides the expected number of partitions. + * + * The third thing it checks is that we don't add an unnecessary shuffle in-between + * `StateStoreRestoreExec` and `StateStoreSaveExec`. + */ + private def checkAggregationChain( + se: StreamExecution, + expectShuffling: Boolean, + expectedPartition: Int): Boolean = { + val executedPlan = se.lastExecution.executedPlan + val restore = executedPlan + .collect { case ss: StateStoreRestoreExec => ss } + .head + restore.child match { + case node: UnaryExecNode => + assert(node.outputPartitioning.numPartitions === expectedPartition, + "Didn't get the expected number of partitions.") + if (expectShuffling) { + assert(node.isInstanceOf[Exchange], s"Expected a shuffle, got: ${node.child}") + } else { + assert(!node.isInstanceOf[Exchange], "Didn't expect a shuffle") + } + + case _ => fail("Expected no shuffling") + } + var reachedRestore = false + // Check that there are no exchanges after `StateStoreRestoreExec` + executedPlan.foreachUp { p => + if (reachedRestore) { + assert(!p.isInstanceOf[Exchange], "There should be no further exchanges") + } else { + reachedRestore = p.isInstanceOf[StateStoreRestoreExec] + } + } + true + } + + test("SPARK-21977: coalesce(1) with 0 partition RDD should be repartitioned to 1") { + val inputSource = new BlockRDDBackedSource(spark) + MockSourceProvider.withMockSources(inputSource) { + // `coalesce(1)` changes the partitioning of data to `SinglePartition` which by default + // satisfies the required distributions of all aggregations. Therefore in our SparkPlan, we + // don't have any shuffling. However, `coalesce(1)` only guarantees that the RDD has at most 1 + // partition, which means that if we have an input RDD with 0 partitions, nothing gets + // executed. Therefore the StateStores don't save any delta files for a given trigger. This + // then leads to `FileNotFoundException`s in the subsequent batch. + // This isn't the only problem though. Once we introduce a shuffle before + // `StateStoreRestoreExec`, the input to the operator is an empty iterator. When performing + // `groupBy().agg(...)`, `HashAggregateExec` returns a `0` value for all aggregations. If + // we fail to restore the previous state in `StateStoreRestoreExec`, we save the 0 value in + // `StateStoreSaveExec`, losing all previous state.
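// [Editor's aside, not part of this patch] A minimal sketch of the coalesce-vs-repartition
// behaviour the comment above relies on, assuming a local SparkSession named `spark`:
//   val empty = spark.sparkContext.emptyRDD[Int]
//   assert(empty.coalesce(1).getNumPartitions == 0)    // coalesce only caps the partition count
//   assert(empty.repartition(1).getNumPartitions == 1) // repartition shuffles to exactly 1 partition
// With a 0-partition input and no shuffle, the stateful operators never run for that trigger,
// which is why the planner must add the exchange that the assertions below check for.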
+ val aggregated: Dataset[Long] = + spark.readStream.format((new MockSourceProvider).getClass.getCanonicalName) + .load().coalesce(1).groupBy().count().as[Long] + + testStream(aggregated, Complete())( + AddBlockData(inputSource, Seq(1)), + CheckLastBatch(1), + AssertOnQuery("Verify no shuffling") { se => + checkAggregationChain(se, expectShuffling = false, 1) + }, + AddBlockData(inputSource), // create an empty trigger + CheckLastBatch(1), + AssertOnQuery("Verify addition of exchange operator") { se => + checkAggregationChain(se, expectShuffling = true, 1) + }, + AddBlockData(inputSource, Seq(2, 3)), + CheckLastBatch(3), + AddBlockData(inputSource), + CheckLastBatch(3), + StopStream + ) + } + } + + test("SPARK-21977: coalesce(1) with aggregation should still be repartitioned when it " + + "has non-empty grouping keys") { + val inputSource = new BlockRDDBackedSource(spark) + MockSourceProvider.withMockSources(inputSource) { + withTempDir { tempDir => + + // `coalesce(1)` changes the partitioning of data to `SinglePartition` which by default + // satisfies the required distributions of all aggregations. However, when we have + // non-empty grouping keys, in streaming, we must repartition to + // `spark.sql.shuffle.partitions`, otherwise only a single StateStore is used to process + // all keys. This may be fine; however, if the user removes the coalesce(1) or changes it to + // a `coalesce(2)`, for example, then the default behavior is to shuffle to + // `spark.sql.shuffle.partitions` many StateStores. When this happens, all StateStores + // except 1 will be missing their previous delta files, which causes the stream to fail + // with FileNotFoundException. + def createDf(partitions: Int): Dataset[(Long, Long)] = { + spark.readStream + .format((new MockSourceProvider).getClass.getCanonicalName) + .load().coalesce(partitions).groupBy('a % 1).count().as[(Long, Long)] + } + + testStream(createDf(1), Complete())( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddBlockData(inputSource, Seq(1)), + CheckLastBatch((0L, 1L)), + AssertOnQuery("Verify addition of exchange operator") { se => + checkAggregationChain( + se, + expectShuffling = true, + spark.sessionState.conf.numShufflePartitions) + }, + StopStream + ) + + testStream(createDf(2), Complete())( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + Execute(se => se.processAllAvailable()), + AddBlockData(inputSource, Seq(2), Seq(3), Seq(4)), + CheckLastBatch((0L, 4L)), + AssertOnQuery("Verify no exchange added") { se => + checkAggregationChain( + se, + expectShuffling = false, + spark.sessionState.conf.numShufflePartitions) + }, + AddBlockData(inputSource), + CheckLastBatch((0L, 4L)), + StopStream + ) + } + } + } + + /** Add blocks of data to the `BlockRDDBackedSource`. */ + case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { + override def addData(query: Option[StreamExecution]): (Source, Offset) = { + source.addBlocks(data: _*) + (source, LongOffset(source.counter)) + } + } + + /** + * A Streaming Source that is backed by a BlockRDD and that can create RDDs with 0 blocks at will.
+ */ + class BlockRDDBackedSource(spark: SparkSession) extends Source { + var counter = 0L + private val blockMgr = SparkEnv.get.blockManager + private var blocks: Seq[BlockId] = Seq.empty + + def addBlocks(dataBlocks: Seq[Int]*): Unit = synchronized { + dataBlocks.foreach { data => + val id = TestBlockId(counter.toString) + blockMgr.putIterator(id, data.iterator, StorageLevel.MEMORY_ONLY) + blocks ++= id :: Nil + counter += 1 + } + counter += 1 + } + + override def getOffset: Option[Offset] = synchronized { + if (counter == 0) None else Some(LongOffset(counter)) + } + + override def getBatch(start: Option[Offset], end: Offset): DataFrame = synchronized { + val rdd = new BlockRDD[Int](spark.sparkContext, blocks.toArray) + .map(i => InternalRow(i)) // we don't really care about the values in this test + blocks = Seq.empty + spark.internalCreateDataFrame(rdd, schema, isStreaming = true).toDF() + } + override def schema: StructType = MockSourceProvider.fakeSchema + override def stop(): Unit = { + blockMgr.getMatchingBlockIds(_.isInstanceOf[TestBlockId]).foreach(blockMgr.removeBlock(_)) + } + } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 761e832ed14b8..832a15d09599f 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -37,6 +37,8 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.log4j.{Level, Logger} import org.apache.thrift.transport.TSocket +import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.hive.HiveUtils @@ -81,11 +83,17 @@ private[hive] object SparkSQLCLIDriver extends Logging { System.exit(1) } + val sparkConf = new SparkConf(loadDefaults = true) + val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf) + val extraConfigs = HiveUtils.formatTimeVarsForHiveClient(hadoopConf) + val cliConf = new HiveConf(classOf[SessionState]) - // Override the location of the metastore since this is only used for local execution. - HiveUtils.newTemporaryConfiguration(useInMemoryDerby = false).foreach { - case (key, value) => cliConf.set(key, value) + (hadoopConf.iterator().asScala.map(kv => kv.getKey -> kv.getValue) + ++ sparkConf.getAll.toMap ++ extraConfigs).foreach { + case (k, v) => + cliConf.set(k, v) } + val sessionState = new CliSessionState(cliConf) sessionState.in = System.in diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 561c127a40bb6..80b9a3dc9605d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -176,9 +176,9 @@ private[spark] object HiveUtils extends Logging { } /** - * Configurations needed to create a [[HiveClient]]. + * Change time configurations needed to create a [[HiveClient]] into unified [[Long]] format. 
*/ - private[hive] def hiveClientConfigurations(hadoopConf: Configuration): Map[String, String] = { + private[hive] def formatTimeVarsForHiveClient(hadoopConf: Configuration): Map[String, String] = { // Hive 0.14.0 introduces timeout operations in HiveConf, and changes default values of a bunch // of time `ConfVar`s by adding time suffixes (`s`, `ms`, and `d` etc.). This breaks backwards- // compatibility when users are trying to connect to a Hive metastore of a lower version, @@ -280,7 +280,7 @@ private[spark] object HiveUtils extends Logging { protected[hive] def newClientForMetadata( conf: SparkConf, hadoopConf: Configuration): HiveClient = { - val configurations = hiveClientConfigurations(hadoopConf) + val configurations = formatTimeVarsForHiveClient(hadoopConf) newClientForMetadata(conf, hadoopConf, configurations) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 426db6a4e1c12..c4e48c9360db7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.metastore.{TableType => HiveTableType} import org.apache.hadoop.hive.metastore.api.{Database => HiveDatabase, FieldSchema, Order} import org.apache.hadoop.hive.metastore.api.{SerDeInfo, StorageDescriptor} @@ -132,14 +133,24 @@ private[hive] class HiveClientImpl( // in hive jars, which will turn off isolation, if SessionState.detachSession is // called to remove the current state after that, a hive client created later will initialize // its own state by newState() - Option(SessionState.get).getOrElse(newState()) + val ret = SessionState.get + if (ret != null) { + // hive.metastore.warehouse.dir is determined in SharedState after the CliSessionState + // instance is constructed, so we need to follow that change here. + Option(hadoopConf.get(ConfVars.METASTOREWAREHOUSE.varname)).foreach { dir => + ret.getConf.setVar(ConfVars.METASTOREWAREHOUSE, dir) + } + ret + } else { + newState() + } } } // Log the default warehouse location.
logInfo( s"Warehouse location for Hive client " + - s"(version ${version.fullVersion}) is ${conf.get("hive.metastore.warehouse.dir")}") + s"(version ${version.fullVersion}) is ${conf.getVar(ConfVars.METASTOREWAREHOUSE)}") private def newState(): SessionState = { val hiveConf = new HiveConf(classOf[SessionState]) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 2928a734a7e36..305f5b533d592 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -38,12 +38,15 @@ import org.apache.spark.util.Utils class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private val wareHousePath = Utils.createTempDir(namePrefix = "warehouse") private val tmpDataDir = Utils.createTempDir(namePrefix = "test-data") - private val sparkTestingDir = "/tmp/spark-test" + // For local testing, you can set `sparkTestingDir` to a static value like `/tmp/test-spark` to + // avoid re-downloading different versions of Spark on each run. + private val sparkTestingDir = Utils.createTempDir(namePrefix = "test-spark") private val unusedJar = TestUtils.createJarWithClasses(Seq.empty) override def afterAll(): Unit = { Utils.deleteRecursively(wareHousePath) Utils.deleteRecursively(tmpDataDir) + Utils.deleteRecursively(sparkTestingDir) super.afterAll() } @@ -52,7 +55,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { val url = s"https://d3kbcqa49mib13.cloudfront.net/spark-$version-bin-hadoop2.7.tgz" - Seq("wget", url, "-q", "-P", sparkTestingDir).! + Seq("wget", url, "-q", "-P", sparkTestingDir.getCanonicalPath).!
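// [Editor's aside, not part of this patch] `Seq(...).!` comes from scala.sys.process and returns
// the process exit code, which the test currently ignores. A hedged variant that fails fast on a
// bad download could look like this (same `url`, `version` and `sparkTestingDir` as above):
//   import scala.sys.process._
//   val exitCode = Seq("wget", url, "-q", "-P", sparkTestingDir.getCanonicalPath).!
//   assert(exitCode == 0, s"Failed to download Spark $version from $url")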
val downloaded = new File(sparkTestingDir, s"spark-$version-bin-hadoop2.7.tgz").getCanonicalPath val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala index ed475a0261b0b..951ebfad4590e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala @@ -36,7 +36,7 @@ private[client] abstract class HiveVersionSuite(version: String) extends SparkFu hadoopConf.set("hive.metastore.schema.verification", "false") } HiveClientBuilder - .buildClient(version, hadoopConf, HiveUtils.hiveClientConfigurations(hadoopConf)) + .buildClient(version, hadoopConf, HiveUtils.formatTimeVarsForHiveClient(hadoopConf)) } override def suiteName: String = s"${super.suiteName}($version)" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 1d9c8da996fea..edb9a9ffbaaf6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -127,7 +127,7 @@ class VersionsSuite extends SparkFunSuite with Logging { hadoopConf.set("datanucleus.schema.autoCreateAll", "true") hadoopConf.set("hive.metastore.schema.verification", "false") } - client = buildClient(version, hadoopConf, HiveUtils.hiveClientConfigurations(hadoopConf)) + client = buildClient(version, hadoopConf, HiveUtils.formatTimeVarsForHiveClient(hadoopConf)) if (versionSpark != null) versionSpark.reset() versionSpark = TestHiveVersion(client) assert(versionSpark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index f3b4ff2d1d80c..8c7418ec7ac10 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -659,8 +659,7 @@ class StreamingContext private[streaming] ( def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = { var shutdownHookRefToRemove: AnyRef = null if (LiveListenerBus.withinListenerThread.value) { - throw new SparkException( - s"Cannot stop StreamingContext within listener thread of ${LiveListenerBus.name}") + throw new SparkException(s"Cannot stop StreamingContext within listener bus thread.") } synchronized { // The state should always be Stopped after calling `stop()`, even if we haven't started yet diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala index 5fb0bd057d0f1..6a70bf7406b3c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala @@ -76,7 +76,7 @@ private[streaming] class StreamingListenerBus(sparkListenerBus: LiveListenerBus) * forward them to StreamingListeners. 
*/ def start(): Unit = { - sparkListenerBus.addListener(this) // for getting callbacks on spark events + sparkListenerBus.addToStatusQueue(this) } /** diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 96ab5a2080b8e..5810e73f4098b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -575,8 +575,6 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL test("getActive and getActiveOrCreate") { require(StreamingContext.getActive().isEmpty, "context exists from before") - sc = new SparkContext(conf) - var newContextCreated = false def creatingFunc(): StreamingContext = { @@ -603,6 +601,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL // getActiveOrCreate should create new context and getActive should return it only // after starting the context testGetActiveOrCreate { + sc = new SparkContext(conf) ssc = StreamingContext.getActiveOrCreate(creatingFunc _) assert(ssc != null, "no context created") assert(newContextCreated === true, "new context not created") @@ -622,6 +621,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeL // getActiveOrCreate and getActive should return independently created context after activating testGetActiveOrCreate { + sc = new SparkContext(conf) ssc = creatingFunc() // Create assert(StreamingContext.getActive().isEmpty, "new initialized context returned before starting")
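// [Editor's aside, not part of this patch] The getActiveOrCreate contract exercised above,
// sketched in isolation (assuming a local master and a 1-second batch interval):
//   val ssc = StreamingContext.getActiveOrCreate(() =>
//     new StreamingContext(new SparkConf().setMaster("local[2]").setAppName("demo"), Seconds(1)))
//   ssc.start()
//   assert(StreamingContext.getActive().contains(ssc)) // reported as active only after start()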