diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index b513b3858bbdd..f661bafc054bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -381,28 +381,74 @@ abstract class ExplodeBase extends UnaryExpression with CollectionGenerator with if (inputArray == null) { Nil } else { - val rows = new Array[InternalRow](inputArray.numElements()) - inputArray.foreach(et, (i, e) => { - rows(i) = if (position) InternalRow(i, e) else InternalRow(e) - }) - rows + new ArrayExplodeIterator(inputArray, et, position) } case MapType(kt, vt, _) => val inputMap = child.eval(input).asInstanceOf[MapData] if (inputMap == null) { Nil } else { - val rows = new Array[InternalRow](inputMap.numElements()) - var i = 0 - inputMap.foreach(kt, vt, (k, v) => { - rows(i) = if (position) InternalRow(i, k, v) else InternalRow(k, v) - i += 1 - }) - rows + new MapExplodeIterator(inputMap, kt, vt, position) } } } + private class ArrayExplodeIterator( + array: ArrayData, + elementType: DataType, + includePosition: Boolean) + extends IterableOnce[InternalRow] { + + override def iterator: Iterator[InternalRow] = new Iterator[InternalRow] { + private var currentIndex = 0 + private val numElements = array.numElements() + + override def hasNext: Boolean = currentIndex < numElements + + override def next(): InternalRow = { + if (!hasNext) throw new NoSuchElementException("No more elements") + val element = array.get(currentIndex, elementType) + val row = if (includePosition) { + InternalRow(currentIndex, element) + } else { + InternalRow(element) + } + currentIndex += 1 + row + } + } + } + + private class MapExplodeIterator( + mapData: MapData, + keyType: DataType, + valueType: DataType, + includePosition: Boolean) + extends IterableOnce[InternalRow] { + + override def iterator: Iterator[InternalRow] = new Iterator[InternalRow] { + private var currentIndex = 0 + private val numElements = mapData.numElements() + private val keyArray = mapData.keyArray() + private val valueArray = mapData.valueArray() + + override def hasNext: Boolean = currentIndex < numElements + + override def next(): InternalRow = { + if (!hasNext) throw new NoSuchElementException("No more elements") + val key = keyArray.get(currentIndex, keyType) + val value = valueArray.get(currentIndex, valueType) + val row = if (includePosition) { + InternalRow(currentIndex, key, value) + } else { + InternalRow(key, value) + } + currentIndex += 1 + row + } + } + } + override def collectionType: DataType = child.dataType override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {