Skip to content

Commit 99b899f

Browse files
committed
comments and style
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
1 parent a3d0e70 commit 99b899f

File tree

4 files changed

+69
-14
lines changed

4 files changed

+69
-14
lines changed

integration_tests/src/main/python/protobuf_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1607,6 +1607,31 @@ def run_on_spark(spark):
16071607
assert_gpu_and_cpu_error(run_on_spark, conf={}, error_message="Malformed")
16081608

16091609

1610+
@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
1611+
@ignore_order(local=True)
1612+
def test_from_protobuf_required_field_missing_permissive(spark_tmp_path, from_protobuf_fn):
1613+
"""Required-field violations should null the whole row in PERMISSIVE mode."""
1614+
desc_path, desc_bytes = _setup_protobuf_desc(
1615+
spark_tmp_path, "required.desc", _build_required_field_descriptor_set_bytes)
1616+
message_name = "test.WithRequired"
1617+
1618+
missing_required_row = _encode_tag(2, 2) + _encode_varint(5) + b"hello"
1619+
1620+
def run_on_spark(spark):
1621+
df = spark.createDataFrame([(missing_required_row,)], schema="bin binary")
1622+
decoded = _call_from_protobuf(
1623+
from_protobuf_fn, f.col("bin"), message_name, desc_path, desc_bytes,
1624+
options={"mode": "PERMISSIVE"})
1625+
return df.select(
1626+
decoded.isNull().alias("decoded_is_null"),
1627+
decoded.getField("id").alias("id"),
1628+
decoded.getField("name").alias("name"),
1629+
decoded.getField("count").alias("count")
1630+
)
1631+
1632+
assert_gpu_and_cpu_are_equal_collect(run_on_spark)
1633+
1634+
16101635
@pytest.mark.skipif(is_before_spark_340(), reason="from_protobuf is Spark 3.4.0+")
16111636
@ignore_order(local=True)
16121637
def test_from_protobuf_nested_required_field_missing_permissive(

sql-plugin/src/main/scala/org/apache/spark/sql/rapids/protobuf/ProtobufSchemaValidator.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.apache.spark.sql.rapids.protobuf
1818

1919
import ai.rapids.cudf.DType
20+
2021
import org.apache.spark.sql.rapids.GpuFromProtobuf
2122
import org.apache.spark.sql.types._
2223

sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/SparkProtobufCompat.scala

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -230,13 +230,18 @@ private[shims] object SparkProtobufCompat extends Logging {
230230
override lazy val defaultValueResult: Either[String, Option[ProtobufDefaultValue]] =
231231
Try {
232232
if (PbReflect.hasDefaultValue(raw)) {
233-
PbReflect.getDefaultValue(raw).map(toDefaultValue(_, protoTypeName, enumMetadata))
233+
PbReflect.getDefaultValue(raw) match {
234+
case Some(default) =>
235+
toDefaultValue(default, protoTypeName, enumMetadata).map(Some(_))
236+
case None =>
237+
Right(None)
238+
}
234239
} else {
235-
None
240+
Right(None)
236241
}
237242
}.toEither.left.map { t =>
238243
s"Failed to read protobuf default value for field '$name': ${t.getMessage}"
239-
}
244+
}.flatMap(identity)
240245
override lazy val messageDescriptor: Option[ProtobufMessageDescriptor] =
241246
if (protoTypeName == "MESSAGE") {
242247
Some(new ReflectiveMessageDescriptor(PbReflect.getMessageType(raw)))
@@ -248,26 +253,31 @@ private[shims] object SparkProtobufCompat extends Logging {
248253
private def toDefaultValue(
249254
rawDefault: AnyRef,
250255
protoTypeName: String,
251-
enumMetadata: Option[ProtobufEnumMetadata]): ProtobufDefaultValue = protoTypeName match {
256+
enumMetadata: Option[ProtobufEnumMetadata]): Either[String, ProtobufDefaultValue] =
257+
protoTypeName match {
252258
case "BOOL" =>
253-
ProtobufDefaultValue.BoolValue(rawDefault.asInstanceOf[java.lang.Boolean].booleanValue())
259+
Right(ProtobufDefaultValue.BoolValue(
260+
rawDefault.asInstanceOf[java.lang.Boolean].booleanValue()))
254261
case "FLOAT" =>
255-
ProtobufDefaultValue.FloatValue(rawDefault.asInstanceOf[java.lang.Float].floatValue())
262+
Right(ProtobufDefaultValue.FloatValue(
263+
rawDefault.asInstanceOf[java.lang.Float].floatValue()))
256264
case "DOUBLE" =>
257-
ProtobufDefaultValue.DoubleValue(rawDefault.asInstanceOf[java.lang.Double].doubleValue())
265+
Right(ProtobufDefaultValue.DoubleValue(
266+
rawDefault.asInstanceOf[java.lang.Double].doubleValue()))
258267
case "STRING" =>
259-
ProtobufDefaultValue.StringValue(if (rawDefault == null) null else rawDefault.toString)
268+
Right(ProtobufDefaultValue.StringValue(
269+
if (rawDefault == null) null else rawDefault.toString))
260270
case "BYTES" =>
261-
ProtobufDefaultValue.BinaryValue(extractBytes(rawDefault))
271+
Right(ProtobufDefaultValue.BinaryValue(extractBytes(rawDefault)))
262272
case "ENUM" =>
263273
val number = extractNumber(rawDefault).intValue()
264-
enumMetadata.map(_.enumDefault(number))
265-
.getOrElse(ProtobufDefaultValue.EnumValue(number, rawDefault.toString))
274+
Right(enumMetadata.map(_.enumDefault(number))
275+
.getOrElse(ProtobufDefaultValue.EnumValue(number, rawDefault.toString)))
266276
case "INT32" | "UINT32" | "SINT32" | "FIXED32" | "SFIXED32" |
267277
"INT64" | "UINT64" | "SINT64" | "FIXED64" | "SFIXED64" =>
268-
ProtobufDefaultValue.IntValue(extractNumber(rawDefault).longValue())
278+
Right(ProtobufDefaultValue.IntValue(extractNumber(rawDefault).longValue()))
269279
case other =>
270-
throw new IllegalStateException(
280+
Left(
271281
s"Unsupported protobuf default value type '$other' for value ${rawDefault.toString}")
272282
}
273283

sql-plugin/src/test/scala/com/nvidia/spark/rapids/shims/ProtobufExprShimsSuite.scala

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,9 @@ class ProtobufExprShimsSuite extends AnyFunSuite {
125125
}
126126

127127
private object FakeSpark35RetryFailureProtobufUtils {
128-
def buildDescriptor(messageName: String, binaryFileDescriptorSet: Option[Array[Byte]]): String = {
128+
def buildDescriptor(
129+
messageName: String,
130+
binaryFileDescriptorSet: Option[Array[Byte]]): String = {
129131
val bytes = binaryFileDescriptorSet.getOrElse(Array.emptyByteArray)
130132
if (bytes.sameElements(Array[Byte](1, 2, 3))) {
131133
throw new IllegalArgumentException(s"Unknown message $messageName")
@@ -280,6 +282,23 @@ class ProtobufExprShimsSuite extends AnyFunSuite {
280282
assert(SparkProtobufCompat.isGpuSupportedProtoSyntax("PROTO2"))
281283
}
282284

285+
test("compat returns Left for unsupported default value types") {
286+
val method = SparkProtobufCompat.getClass.getDeclaredMethod(
287+
"toDefaultValue",
288+
classOf[AnyRef],
289+
classOf[String],
290+
classOf[scala.Option[_]])
291+
method.setAccessible(true)
292+
293+
val result = method.invoke(
294+
SparkProtobufCompat,
295+
"opaque-default",
296+
"MESSAGE",
297+
scala.None).asInstanceOf[Either[String, ProtobufDefaultValue]]
298+
299+
assert(result.left.toOption.exists(_.contains("Unsupported protobuf default value type")))
300+
}
301+
283302
test("extractor preserves typed enum defaults") {
284303
val enumMeta = ProtobufEnumMetadata(Seq(
285304
ProtobufEnumValue(0, "UNKNOWN"),

0 commit comments

Comments
 (0)