Skip to content

Commit 92c2f00

Browse files
mn-mikkeueshin
authored andcommitted
[SPARK-23934][SQL] Adding map_from_entries function
## What changes were proposed in this pull request? The PR adds the `map_from_entries` function that returns a map created from the given array of entries. ## How was this patch tested? New tests added into: - `CollectionExpressionSuite` - `DataFrameFunctionSuite` ## CodeGen Examples ### Primitive-type Keys and Values ``` val idf = Seq( Seq((1, 10), (2, 20), (3, 10)), Seq((1, 10), null, (2, 20)) ).toDF("a") idf.filter('a.isNotNull).select(map_from_entries('a)).debugCodegen ``` Result: ``` /* 042 */ boolean project_isNull_0 = false; /* 043 */ MapData project_value_0 = null; /* 044 */ /* 045 */ for (int project_idx_2 = 0; !project_isNull_0 && project_idx_2 < inputadapter_value_0.numElements(); project_idx_2++) { /* 046 */ project_isNull_0 |= inputadapter_value_0.isNullAt(project_idx_2); /* 047 */ } /* 048 */ if (!project_isNull_0) { /* 049 */ final int project_numEntries_0 = inputadapter_value_0.numElements(); /* 050 */ /* 051 */ final long project_keySectionSize_0 = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(project_numEntries_0, 4); /* 052 */ final long project_valueSectionSize_0 = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(project_numEntries_0, 4); /* 053 */ final long project_byteArraySize_0 = 8 + project_keySectionSize_0 + project_valueSectionSize_0; /* 054 */ if (project_byteArraySize_0 > 2147483632) { /* 055 */ final Object[] project_keys_0 = new Object[project_numEntries_0]; /* 056 */ final Object[] project_values_0 = new Object[project_numEntries_0]; /* 057 */ /* 058 */ for (int project_idx_1 = 0; project_idx_1 < project_numEntries_0; project_idx_1++) { /* 059 */ InternalRow project_entry_1 = inputadapter_value_0.getStruct(project_idx_1, 2); /* 060 */ /* 061 */ project_keys_0[project_idx_1] = project_entry_1.getInt(0); /* 062 */ project_values_0[project_idx_1] = project_entry_1.getInt(1); /* 063 */ } /* 064 */ /* 065 */ project_value_0 = org.apache.spark.sql.catalyst.util.ArrayBasedMapData.apply(project_keys_0, project_values_0); /* 066 */ /* 067 */ } else { /* 068 */ final byte[] project_byteArray_0 = new byte[(int)project_byteArraySize_0]; /* 069 */ UnsafeMapData project_unsafeMapData_0 = new UnsafeMapData(); /* 070 */ Platform.putLong(project_byteArray_0, 16, project_keySectionSize_0); /* 071 */ Platform.putLong(project_byteArray_0, 24, project_numEntries_0); /* 072 */ Platform.putLong(project_byteArray_0, 24 + project_keySectionSize_0, project_numEntries_0); /* 073 */ project_unsafeMapData_0.pointTo(project_byteArray_0, 16, (int)project_byteArraySize_0); /* 074 */ ArrayData project_keyArrayData_0 = project_unsafeMapData_0.keyArray(); /* 075 */ ArrayData project_valueArrayData_0 = project_unsafeMapData_0.valueArray(); /* 076 */ /* 077 */ for (int project_idx_0 = 0; project_idx_0 < project_numEntries_0; project_idx_0++) { /* 078 */ InternalRow project_entry_0 = inputadapter_value_0.getStruct(project_idx_0, 2); /* 079 */ /* 080 */ project_keyArrayData_0.setInt(project_idx_0, project_entry_0.getInt(0)); /* 081 */ project_valueArrayData_0.setInt(project_idx_0, project_entry_0.getInt(1)); /* 082 */ } /* 083 */ /* 084 */ project_value_0 = project_unsafeMapData_0; /* 085 */ } /* 086 */ /* 087 */ } ``` ### Non-primitive-type Keys and Values ``` val sdf = Seq( Seq(("a", null), ("b", "bb"), ("c", "aa")), Seq(("a", "aa"), null, (null, "bb")) ).toDF("a") sdf.filter('a.isNotNull).select(map_from_entries('a)).debugCodegen ``` Result: ``` /* 042 */ boolean project_isNull_0 = false; /* 043 */ MapData project_value_0 = null; /* 044 */ /* 045 */ for (int project_idx_1 = 0; !project_isNull_0 && project_idx_1 < inputadapter_value_0.numElements(); project_idx_1++) { /* 046 */ project_isNull_0 |= inputadapter_value_0.isNullAt(project_idx_1); /* 047 */ } /* 048 */ if (!project_isNull_0) { /* 049 */ final int project_numEntries_0 = inputadapter_value_0.numElements(); /* 050 */ /* 051 */ final Object[] project_keys_0 = new Object[project_numEntries_0]; /* 052 */ final Object[] project_values_0 = new Object[project_numEntries_0]; /* 053 */ /* 054 */ for (int project_idx_0 = 0; project_idx_0 < project_numEntries_0; project_idx_0++) { /* 055 */ InternalRow project_entry_0 = inputadapter_value_0.getStruct(project_idx_0, 2); /* 056 */ /* 057 */ if (project_entry_0.isNullAt(0)) { /* 058 */ throw new RuntimeException("The first field from a struct (key) can't be null."); /* 059 */ } /* 060 */ /* 061 */ project_keys_0[project_idx_0] = project_entry_0.getUTF8String(0); /* 062 */ project_values_0[project_idx_0] = project_entry_0.getUTF8String(1); /* 063 */ } /* 064 */ /* 065 */ project_value_0 = org.apache.spark.sql.catalyst.util.ArrayBasedMapData.apply(project_keys_0, project_values_0); /* 066 */ /* 067 */ } ``` Author: Marek Novotny <[email protected]> Closes apache#21282 from mn-mikke/feature/array-api-map_from_entries-to-master.
1 parent dc8a6be commit 92c2f00

File tree

7 files changed

+378
-16
lines changed

7 files changed

+378
-16
lines changed

python/pyspark/sql/functions.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2412,6 +2412,26 @@ def map_entries(col):
24122412
return Column(sc._jvm.functions.map_entries(_to_java_column(col)))
24132413

24142414

2415+
@since(2.4)
2416+
def map_from_entries(col):
2417+
"""
2418+
Collection function: Returns a map created from the given array of entries.
2419+
2420+
:param col: name of column or expression
2421+
2422+
>>> from pyspark.sql.functions import map_from_entries
2423+
>>> df = spark.sql("SELECT array(struct(1, 'a'), struct(2, 'b')) as data")
2424+
>>> df.select(map_from_entries("data").alias("map")).show()
2425+
+----------------+
2426+
| map|
2427+
+----------------+
2428+
|[1 -> a, 2 -> b]|
2429+
+----------------+
2430+
"""
2431+
sc = SparkContext._active_spark_context
2432+
return Column(sc._jvm.functions.map_from_entries(_to_java_column(col)))
2433+
2434+
24152435
@ignore_unicode_prefix
24162436
@since(2.4)
24172437
def array_repeat(col, count):

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ object FunctionRegistry {
421421
expression[MapKeys]("map_keys"),
422422
expression[MapValues]("map_values"),
423423
expression[MapEntries]("map_entries"),
424+
expression[MapFromEntries]("map_from_entries"),
424425
expression[Size]("size"),
425426
expression[Slice]("slice"),
426427
expression[Size]("cardinality"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,36 @@ class CodegenContext {
819819
}
820820
}
821821

822+
/**
823+
* Generates code to do null safe execution when accessing properties of complex
824+
* ArrayData elements.
825+
*
826+
* @param nullElements used to decide whether the ArrayData might contain null or not.
827+
* @param isNull a variable indicating whether the result will be evaluated to null or not.
828+
* @param arrayData a variable name representing the ArrayData.
829+
* @param execute the code that should be executed only if the ArrayData doesn't contain
830+
* any null.
831+
*/
832+
def nullArrayElementsSaveExec(
833+
nullElements: Boolean,
834+
isNull: String,
835+
arrayData: String)(
836+
execute: String): String = {
837+
val i = freshName("idx")
838+
if (nullElements) {
839+
s"""
840+
|for (int $i = 0; !$isNull && $i < $arrayData.numElements(); $i++) {
841+
| $isNull |= $arrayData.isNullAt($i);
842+
|}
843+
|if (!$isNull) {
844+
| $execute
845+
|}
846+
""".stripMargin
847+
} else {
848+
execute
849+
}
850+
}
851+
822852
/**
823853
* Splits the generated code of expressions into multiple functions, because function has
824854
* 64kb code size limit in JVM. If the class to which the function would be inlined would grow

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 219 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
2525
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
2626
import org.apache.spark.sql.catalyst.expressions.codegen._
2727
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
28-
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
28+
import org.apache.spark.sql.catalyst.util._
2929
import org.apache.spark.sql.internal.SQLConf
3030
import org.apache.spark.sql.types._
3131
import org.apache.spark.unsafe.Platform
@@ -475,6 +475,223 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
475475
override def prettyName: String = "map_entries"
476476
}
477477

478+
/**
479+
* Returns a map created from the given array of entries.
480+
*/
481+
@ExpressionDescription(
482+
usage = "_FUNC_(arrayOfEntries) - Returns a map created from the given array of entries.",
483+
examples = """
484+
Examples:
485+
> SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b')));
486+
{1:"a",2:"b"}
487+
""",
488+
since = "2.4.0")
489+
case class MapFromEntries(child: Expression) extends UnaryExpression {
490+
491+
@transient
492+
private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match {
493+
case ArrayType(
494+
StructType(Array(
495+
StructField(_, keyType, keyNullable, _),
496+
StructField(_, valueType, valueNullable, _))),
497+
containsNull) => Some((MapType(keyType, valueType, valueNullable), keyNullable, containsNull))
498+
case _ => None
499+
}
500+
501+
private def nullEntries: Boolean = dataTypeDetails.get._3
502+
503+
override def nullable: Boolean = child.nullable || nullEntries
504+
505+
override def dataType: MapType = dataTypeDetails.get._1
506+
507+
override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match {
508+
case Some(_) => TypeCheckResult.TypeCheckSuccess
509+
case None => TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " +
510+
s"${child.dataType.simpleString} type. $prettyName accepts only arrays of pair structs.")
511+
}
512+
513+
override protected def nullSafeEval(input: Any): Any = {
514+
val arrayData = input.asInstanceOf[ArrayData]
515+
val numEntries = arrayData.numElements()
516+
var i = 0
517+
if(nullEntries) {
518+
while (i < numEntries) {
519+
if (arrayData.isNullAt(i)) return null
520+
i += 1
521+
}
522+
}
523+
val keyArray = new Array[AnyRef](numEntries)
524+
val valueArray = new Array[AnyRef](numEntries)
525+
i = 0
526+
while (i < numEntries) {
527+
val entry = arrayData.getStruct(i, 2)
528+
val key = entry.get(0, dataType.keyType)
529+
if (key == null) {
530+
throw new RuntimeException("The first field from a struct (key) can't be null.")
531+
}
532+
keyArray.update(i, key)
533+
val value = entry.get(1, dataType.valueType)
534+
valueArray.update(i, value)
535+
i += 1
536+
}
537+
ArrayBasedMapData(keyArray, valueArray)
538+
}
539+
540+
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
541+
nullSafeCodeGen(ctx, ev, c => {
542+
val numEntries = ctx.freshName("numEntries")
543+
val isKeyPrimitive = CodeGenerator.isPrimitiveType(dataType.keyType)
544+
val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType)
545+
val code = if (isKeyPrimitive && isValuePrimitive) {
546+
genCodeForPrimitiveElements(ctx, c, ev.value, numEntries)
547+
} else {
548+
genCodeForAnyElements(ctx, c, ev.value, numEntries)
549+
}
550+
ctx.nullArrayElementsSaveExec(nullEntries, ev.isNull, c) {
551+
s"""
552+
|final int $numEntries = $c.numElements();
553+
|$code
554+
""".stripMargin
555+
}
556+
})
557+
}
558+
559+
private def genCodeForAssignmentLoop(
560+
ctx: CodegenContext,
561+
childVariable: String,
562+
mapData: String,
563+
numEntries: String,
564+
keyAssignment: (String, String) => String,
565+
valueAssignment: (String, String) => String): String = {
566+
val entry = ctx.freshName("entry")
567+
val i = ctx.freshName("idx")
568+
569+
val nullKeyCheck = if (dataTypeDetails.get._2) {
570+
s"""
571+
|if ($entry.isNullAt(0)) {
572+
| throw new RuntimeException("The first field from a struct (key) can't be null.");
573+
|}
574+
""".stripMargin
575+
} else {
576+
""
577+
}
578+
579+
s"""
580+
|for (int $i = 0; $i < $numEntries; $i++) {
581+
| InternalRow $entry = $childVariable.getStruct($i, 2);
582+
| $nullKeyCheck
583+
| ${keyAssignment(CodeGenerator.getValue(entry, dataType.keyType, "0"), i)}
584+
| ${valueAssignment(entry, i)}
585+
|}
586+
""".stripMargin
587+
}
588+
589+
private def genCodeForPrimitiveElements(
590+
ctx: CodegenContext,
591+
childVariable: String,
592+
mapData: String,
593+
numEntries: String): String = {
594+
val byteArraySize = ctx.freshName("byteArraySize")
595+
val keySectionSize = ctx.freshName("keySectionSize")
596+
val valueSectionSize = ctx.freshName("valueSectionSize")
597+
val data = ctx.freshName("byteArray")
598+
val unsafeMapData = ctx.freshName("unsafeMapData")
599+
val keyArrayData = ctx.freshName("keyArrayData")
600+
val valueArrayData = ctx.freshName("valueArrayData")
601+
602+
val baseOffset = Platform.BYTE_ARRAY_OFFSET
603+
val keySize = dataType.keyType.defaultSize
604+
val valueSize = dataType.valueType.defaultSize
605+
val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $keySize)"
606+
val vByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $valueSize)"
607+
val keyTypeName = CodeGenerator.primitiveTypeName(dataType.keyType)
608+
val valueTypeName = CodeGenerator.primitiveTypeName(dataType.valueType)
609+
610+
val keyAssignment = (key: String, idx: String) => s"$keyArrayData.set$keyTypeName($idx, $key);"
611+
val valueAssignment = (entry: String, idx: String) => {
612+
val value = CodeGenerator.getValue(entry, dataType.valueType, "1")
613+
val valueNullUnsafeAssignment = s"$valueArrayData.set$valueTypeName($idx, $value);"
614+
if (dataType.valueContainsNull) {
615+
s"""
616+
|if ($entry.isNullAt(1)) {
617+
| $valueArrayData.setNullAt($idx);
618+
|} else {
619+
| $valueNullUnsafeAssignment
620+
|}
621+
""".stripMargin
622+
} else {
623+
valueNullUnsafeAssignment
624+
}
625+
}
626+
val assignmentLoop = genCodeForAssignmentLoop(
627+
ctx,
628+
childVariable,
629+
mapData,
630+
numEntries,
631+
keyAssignment,
632+
valueAssignment
633+
)
634+
635+
s"""
636+
|final long $keySectionSize = $kByteSize;
637+
|final long $valueSectionSize = $vByteSize;
638+
|final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize;
639+
|if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
640+
| ${genCodeForAnyElements(ctx, childVariable, mapData, numEntries)}
641+
|} else {
642+
| final byte[] $data = new byte[(int)$byteArraySize];
643+
| UnsafeMapData $unsafeMapData = new UnsafeMapData();
644+
| Platform.putLong($data, $baseOffset, $keySectionSize);
645+
| Platform.putLong($data, ${baseOffset + 8}, $numEntries);
646+
| Platform.putLong($data, ${baseOffset + 8} + $keySectionSize, $numEntries);
647+
| $unsafeMapData.pointTo($data, $baseOffset, (int)$byteArraySize);
648+
| ArrayData $keyArrayData = $unsafeMapData.keyArray();
649+
| ArrayData $valueArrayData = $unsafeMapData.valueArray();
650+
| $assignmentLoop
651+
| $mapData = $unsafeMapData;
652+
|}
653+
""".stripMargin
654+
}
655+
656+
private def genCodeForAnyElements(
657+
ctx: CodegenContext,
658+
childVariable: String,
659+
mapData: String,
660+
numEntries: String): String = {
661+
val keys = ctx.freshName("keys")
662+
val values = ctx.freshName("values")
663+
val mapDataClass = classOf[ArrayBasedMapData].getName()
664+
665+
val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType)
666+
val valueAssignment = (entry: String, idx: String) => {
667+
val value = CodeGenerator.getValue(entry, dataType.valueType, "1")
668+
if (dataType.valueContainsNull && isValuePrimitive) {
669+
s"$values[$idx] = $entry.isNullAt(1) ? null : (Object)$value;"
670+
} else {
671+
s"$values[$idx] = $value;"
672+
}
673+
}
674+
val keyAssignment = (key: String, idx: String) => s"$keys[$idx] = $key;"
675+
val assignmentLoop = genCodeForAssignmentLoop(
676+
ctx,
677+
childVariable,
678+
mapData,
679+
numEntries,
680+
keyAssignment,
681+
valueAssignment)
682+
683+
s"""
684+
|final Object[] $keys = new Object[$numEntries];
685+
|final Object[] $values = new Object[$numEntries];
686+
|$assignmentLoop
687+
|$mapData = $mapDataClass.apply($keys, $values);
688+
""".stripMargin
689+
}
690+
691+
override def prettyName: String = "map_from_entries"
692+
}
693+
694+
478695
/**
479696
* Common base class for [[SortArray]] and [[ArraySort]].
480697
*/
@@ -1990,24 +2207,10 @@ case class Flatten(child: Expression) extends UnaryExpression {
19902207
} else {
19912208
genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value)
19922209
}
1993-
if (childDataType.containsNull) nullElementsProtection(ev, c, code) else code
2210+
ctx.nullArrayElementsSaveExec(childDataType.containsNull, ev.isNull, c)(code)
19942211
})
19952212
}
19962213

1997-
private def nullElementsProtection(
1998-
ev: ExprCode,
1999-
childVariableName: String,
2000-
coreLogic: String): String = {
2001-
s"""
2002-
|for (int z = 0; !${ev.isNull} && z < $childVariableName.numElements(); z++) {
2003-
| ${ev.isNull} |= $childVariableName.isNullAt(z);
2004-
|}
2005-
|if (!${ev.isNull}) {
2006-
| $coreLogic
2007-
|}
2008-
""".stripMargin
2009-
}
2010-
20112214
private def genCodeForNumberOfElements(
20122215
ctx: CodegenContext,
20132216
childVariableName: String) : (String, String) = {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,57 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
8080
checkEvaluation(MapEntries(ms2), null)
8181
}
8282

83+
test("MapFromEntries") {
84+
def arrayType(keyType: DataType, valueType: DataType) : DataType = {
85+
ArrayType(
86+
StructType(Seq(
87+
StructField("a", keyType),
88+
StructField("b", valueType))),
89+
true)
90+
}
91+
def r(values: Any*): InternalRow = create_row(values: _*)
92+
93+
// Primitive-type keys and values
94+
val aiType = arrayType(IntegerType, IntegerType)
95+
val ai0 = Literal.create(Seq(r(1, 10), r(2, 20), r(3, 20)), aiType)
96+
val ai1 = Literal.create(Seq(r(1, null), r(2, 20), r(3, null)), aiType)
97+
val ai2 = Literal.create(Seq.empty, aiType)
98+
val ai3 = Literal.create(null, aiType)
99+
val ai4 = Literal.create(Seq(r(1, 10), r(1, 20)), aiType)
100+
val ai5 = Literal.create(Seq(r(1, 10), r(null, 20)), aiType)
101+
val ai6 = Literal.create(Seq(null, r(2, 20), null), aiType)
102+
103+
checkEvaluation(MapFromEntries(ai0), Map(1 -> 10, 2 -> 20, 3 -> 20))
104+
checkEvaluation(MapFromEntries(ai1), Map(1 -> null, 2 -> 20, 3 -> null))
105+
checkEvaluation(MapFromEntries(ai2), Map.empty)
106+
checkEvaluation(MapFromEntries(ai3), null)
107+
checkEvaluation(MapKeys(MapFromEntries(ai4)), Seq(1, 1))
108+
checkExceptionInExpression[RuntimeException](
109+
MapFromEntries(ai5),
110+
"The first field from a struct (key) can't be null.")
111+
checkEvaluation(MapFromEntries(ai6), null)
112+
113+
// Non-primitive-type keys and values
114+
val asType = arrayType(StringType, StringType)
115+
val as0 = Literal.create(Seq(r("a", "aa"), r("b", "bb"), r("c", "bb")), asType)
116+
val as1 = Literal.create(Seq(r("a", null), r("b", "bb"), r("c", null)), asType)
117+
val as2 = Literal.create(Seq.empty, asType)
118+
val as3 = Literal.create(null, asType)
119+
val as4 = Literal.create(Seq(r("a", "aa"), r("a", "bb")), asType)
120+
val as5 = Literal.create(Seq(r("a", "aa"), r(null, "bb")), asType)
121+
val as6 = Literal.create(Seq(null, r("b", "bb"), null), asType)
122+
123+
checkEvaluation(MapFromEntries(as0), Map("a" -> "aa", "b" -> "bb", "c" -> "bb"))
124+
checkEvaluation(MapFromEntries(as1), Map("a" -> null, "b" -> "bb", "c" -> null))
125+
checkEvaluation(MapFromEntries(as2), Map.empty)
126+
checkEvaluation(MapFromEntries(as3), null)
127+
checkEvaluation(MapKeys(MapFromEntries(as4)), Seq("a", "a"))
128+
checkExceptionInExpression[RuntimeException](
129+
MapFromEntries(as5),
130+
"The first field from a struct (key) can't be null.")
131+
checkEvaluation(MapFromEntries(as6), null)
132+
}
133+
83134
test("Sort Array") {
84135
val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
85136
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))

0 commit comments

Comments
 (0)