Skip to content

Commit 7fa6bb3

Browse files
committed
[SPARK-51316][PYTHON][FOLLOW-UP] Revert unrelated changes and mark mapInPandas/mapInArrow batched in byte size
### What changes were proposed in this pull request? This PR is a followup of apache/spark#50096 that reverts unrelated changes and marks mapInPandas/mapInArrow as batched in byte size ### Why are the changes needed? To make the original change self-contained, and mark mapInPandas/mapInArrow as batched in byte size to be consistent. ### Does this PR introduce _any_ user-facing change? No, the main change has not been released out yet. ### How was this patch tested? Manually. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #50111 from HyukjinKwon/SPARK-51316-followup. Authored-by: Hyukjin Kwon <gurwls223@apache.org> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent 5f38636 commit 7fa6bb3

File tree

48 files changed

+879
-320
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

48 files changed

+879
-320
lines changed

common/variant/src/main/java/org/apache/spark/types/variant/ShreddingUtils.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import java.math.BigDecimal;
2121
import java.util.ArrayList;
22+
import java.util.UUID;
2223

2324
import static org.apache.spark.types.variant.VariantUtil.*;
2425

@@ -37,6 +38,7 @@ public interface ShreddedRow {
3738
BigDecimal getDecimal(int ordinal, int precision, int scale);
3839
String getString(int ordinal);
3940
byte[] getBinary(int ordinal);
41+
UUID getUuid(int ordinal);
4042
ShreddedRow getStruct(int ordinal, int numFields);
4143
ShreddedRow getArray(int ordinal);
4244
int numElements();
@@ -99,6 +101,8 @@ public static void rebuild(ShreddedRow row, byte[] metadata, VariantSchema schem
99101
builder.appendBoolean(row.getBoolean(typedIdx));
100102
} else if (scalar instanceof VariantSchema.BinaryType) {
101103
builder.appendBinary(row.getBinary(typedIdx));
104+
} else if (scalar instanceof VariantSchema.UuidType) {
105+
builder.appendUuid(row.getUuid(typedIdx));
102106
} else if (scalar instanceof VariantSchema.DecimalType) {
103107
VariantSchema.DecimalType dt = (VariantSchema.DecimalType) scalar;
104108
builder.appendDecimal(row.getDecimal(typedIdx, dt.precision, dt.scale));

common/variant/src/main/java/org/apache/spark/types/variant/Variant.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import java.util.Arrays;
3434
import java.util.Base64;
3535
import java.util.Locale;
36+
import java.util.UUID;
3637

3738
import static org.apache.spark.types.variant.VariantUtil.*;
3839

@@ -123,6 +124,11 @@ public Type getType() {
123124
return VariantUtil.getType(value, pos);
124125
}
125126

127+
// Get a UUID value from the variant.
128+
public UUID getUuid() {
129+
return VariantUtil.getUuid(value, pos);
130+
}
131+
126132
// Get the number of object fields in the variant.
127133
// It is only legal to call it when `getType()` is `Type.OBJECT`.
128134
public int objectSize() {
@@ -333,6 +339,9 @@ static void toJsonImpl(byte[] value, byte[] metadata, int pos, StringBuilder sb,
333339
case BINARY:
334340
appendQuoted(sb, Base64.getEncoder().encodeToString(VariantUtil.getBinary(value, pos)));
335341
break;
342+
case UUID:
343+
appendQuoted(sb, VariantUtil.getUuid(value, pos).toString());
344+
break;
336345
}
337346
}
338347
}

common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import java.math.BigDecimal;
2727
import java.math.BigInteger;
2828
import java.nio.charset.StandardCharsets;
29+
import java.nio.ByteBuffer;
30+
import java.nio.ByteOrder;
2931
import java.util.*;
3032

3133
import com.fasterxml.jackson.core.JsonFactory;
@@ -240,6 +242,18 @@ public void appendBinary(byte[] binary) {
240242
writePos += binary.length;
241243
}
242244

245+
public void appendUuid(UUID uuid) {
246+
checkCapacity(1 + 16);
247+
writeBuffer[writePos++] = primitiveHeader(UUID);
248+
249+
// UUID is stored big-endian, so don't use writeLong.
250+
ByteBuffer buffer = ByteBuffer.wrap(writeBuffer, writePos, 16);
251+
buffer.order(ByteOrder.BIG_ENDIAN);
252+
buffer.putLong(writePos, uuid.getMostSignificantBits());
253+
buffer.putLong(writePos + 8, uuid.getLeastSignificantBits());
254+
writePos += 16;
255+
}
256+
243257
// Add a key to the variant dictionary. If the key already exists, the dictionary is not modified.
244258
// In either case, return the id of the key.
245259
public int addKey(String key) {

common/variant/src/main/java/org/apache/spark/types/variant/VariantSchema.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ public static final class TimestampType extends ScalarType {
9999
public static final class TimestampNTZType extends ScalarType {
100100
}
101101

102+
public static final class UuidType extends ScalarType {
103+
}
104+
102105
// The index of the typed_value, value, and metadata fields in the schema, respectively. If a
103106
// given field is not in the schema, its value must be set to -1 to indicate that it is invalid.
104107
// The indices of valid fields should be contiguous and start from 0.

common/variant/src/main/java/org/apache/spark/types/variant/VariantShreddingWriter.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,11 @@ private static Object tryTypedShred(
283283
return v.getBinary();
284284
}
285285
break;
286+
case UUID:
287+
if (targetType instanceof VariantSchema.UuidType) {
288+
return v.getUuid();
289+
}
290+
break;
286291
}
287292
// The stored type does not match the requested shredding type. Return null, and the caller
288293
// will store the result in untyped_value.

common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323

2424
import java.math.BigDecimal;
2525
import java.math.BigInteger;
26+
import java.nio.ByteBuffer;
27+
import java.nio.ByteOrder;
2628
import java.util.Arrays;
29+
import java.util.UUID;
2730

2831
/**
2932
* This class defines constants related to the variant format and provides functions for
@@ -121,6 +124,9 @@ public class VariantUtil {
121124
// string size) + (size bytes of string content).
122125
public static final int LONG_STR = 16;
123126

127+
// UUID, 16-byte big-endian.
128+
public static final int UUID = 20;
129+
124130
public static final byte VERSION = 1;
125131
// The lower 4 bits of the first metadata byte contain the version.
126132
public static final byte VERSION_MASK = 0x0F;
@@ -239,6 +245,7 @@ public enum Type {
239245
TIMESTAMP_NTZ,
240246
FLOAT,
241247
BINARY,
248+
UUID,
242249
}
243250

244251
public static int getTypeInfo(byte[] value, int pos) {
@@ -291,6 +298,8 @@ public static Type getType(byte[] value, int pos) {
291298
return Type.BINARY;
292299
case LONG_STR:
293300
return Type.STRING;
301+
case UUID:
302+
return Type.UUID;
294303
default:
295304
throw unknownPrimitiveTypeInVariant(typeInfo);
296305
}
@@ -342,6 +351,8 @@ public static int valueSize(byte[] value, int pos) {
342351
case BINARY:
343352
case LONG_STR:
344353
return 1 + U32_SIZE + readUnsigned(value, pos + 1, U32_SIZE);
354+
case UUID:
355+
return 17;
345356
default:
346357
throw unknownPrimitiveTypeInVariant(typeInfo);
347358
}
@@ -497,6 +508,20 @@ public static String getString(byte[] value, int pos) {
497508
throw unexpectedType(Type.STRING);
498509
}
499510

511+
// Get a UUID value from variant value `value[pos...]`.
512+
// Throw `MALFORMED_VARIANT` if the variant is malformed.
513+
public static UUID getUuid(byte[] value, int pos) {
514+
checkIndex(pos, value.length);
515+
int basicType = value[pos] & BASIC_TYPE_MASK;
516+
int typeInfo = (value[pos] >> BASIC_TYPE_BITS) & TYPE_INFO_MASK;
517+
if (basicType != PRIMITIVE || typeInfo != UUID) throw unexpectedType(Type.UUID);
518+
int start = pos + 1;
519+
checkIndex(start + 15, value.length);
520+
// UUID values are big-endian, so we can't use VariantUtil.readLong().
521+
ByteBuffer bb = ByteBuffer.wrap(value, start, 16).order(ByteOrder.BIG_ENDIAN);
522+
return new UUID(bb.getLong(), bb.getLong());
523+
}
524+
500525
public interface ObjectHandler<T> {
501526
/**
502527
* @param size Number of object fields.

connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,9 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT
185185
"('amy', '2022-05-19', '2022-05-19 00:00:00')").executeUpdate()
186186
connection.prepareStatement("INSERT INTO datetime VALUES " +
187187
"('alex', '2022-05-18', '2022-05-18 00:00:00')").executeUpdate()
188+
// '2022-01-01' is Saturday and is in ISO year 2021.
189+
connection.prepareStatement("INSERT INTO datetime VALUES " +
190+
"('tom', '2022-01-01', '2022-01-01 00:00:00')").executeUpdate()
188191
}
189192

190193
override def testUpdateColumnType(tbl: String): Unit = {
@@ -279,17 +282,19 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT
279282
val df4 = sql(s"SELECT name FROM $tbl WHERE hour(time1) = 0 AND minute(time1) = 0")
280283
checkFilterPushed(df4)
281284
val rows4 = df4.collect()
282-
assert(rows4.length === 2)
285+
assert(rows4.length === 3)
283286
assert(rows4(0).getString(0) === "amy")
284287
assert(rows4(1).getString(0) === "alex")
288+
assert(rows4(2).getString(0) === "tom")
285289

286290
val df5 = sql(s"SELECT name FROM $tbl WHERE " +
287-
"extract(WEEk from date1) > 10 AND extract(YEAROFWEEK from date1) = 2022")
291+
"extract(WEEK from date1) > 10 AND extract(YEAR from date1) = 2022")
288292
checkFilterPushed(df5)
289293
val rows5 = df5.collect()
290-
assert(rows5.length === 2)
294+
assert(rows5.length === 3)
291295
assert(rows5(0).getString(0) === "amy")
292296
assert(rows5(1).getString(0) === "alex")
297+
assert(rows5(2).getString(0) === "tom")
293298

294299
val df6 = sql(s"SELECT name FROM $tbl WHERE date_add(date1, 1) = date'2022-05-20' " +
295300
"AND datediff(date1, '2022-05-10') > 0")
@@ -304,11 +309,25 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT
304309
assert(rows7.length === 1)
305310
assert(rows7(0).getString(0) === "alex")
306311

307-
val df8 = sql(s"SELECT name FROM $tbl WHERE dayofweek(date1) = 4")
308-
checkFilterPushed(df8)
309-
val rows8 = df8.collect()
310-
assert(rows8.length === 1)
311-
assert(rows8(0).getString(0) === "alex")
312+
withClue("dayofweek") {
313+
val dow = sql(s"SELECT dayofweek(date1) FROM $tbl WHERE name = 'alex'")
314+
.collect().head.getInt(0)
315+
val df = sql(s"SELECT name FROM $tbl WHERE dayofweek(date1) = $dow")
316+
checkFilterPushed(df)
317+
val rows = df.collect()
318+
assert(rows.length === 1)
319+
assert(rows(0).getString(0) === "alex")
320+
}
321+
322+
withClue("yearofweek") {
323+
val yow = sql(s"SELECT extract(YEAROFWEEK from date1) FROM $tbl WHERE name = 'tom'")
324+
.collect().head.getInt(0)
325+
val df = sql(s"SELECT name FROM $tbl WHERE extract(YEAROFWEEK from date1) = $yow")
326+
checkFilterPushed(df)
327+
val rows = df.collect()
328+
assert(rows.length === 1)
329+
assert(rows(0).getString(0) === "tom")
330+
}
312331

313332
val df9 = sql(s"SELECT name FROM $tbl WHERE " +
314333
"dayofyear(date1) > 100 order by dayofyear(date1) limit 1")

python/docs/source/development/logger.rst

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,14 @@ Example log entry:
5757
"exception": {
5858
"class": "Py4JJavaError",
5959
"msg": "An error occurred while calling o52.showString.\n: org.apache.spark.SparkArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error. SQLSTATE: 22012\n== DataFrame ==\n\"divide\" was called from\n/path/to/file.py:17 ...",
60-
"stacktrace": ["Traceback (most recent call last):", " File \".../spark/python/pyspark/errors/exceptions/captured.py\", line 247, in deco", " return f(*a, **kw)", " File \".../lib/python3.9/site-packages/py4j/protocol.py\", line 326, in get_return_value" ...]
60+
"stacktrace": [
61+
{
62+
"class": null,
63+
"method": "deco",
64+
"file": ".../spark/python/pyspark/errors/exceptions/captured.py",
65+
"line": "247"
66+
}
67+
]
6168
},
6269
}
6370

python/pyspark/sql/connect/column.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
Expression,
4545
UnresolvedFunction,
4646
UnresolvedExtractValue,
47-
LazyExpression,
4847
LiteralExpression,
4948
CaseWhen,
5049
SortOrder,
@@ -459,7 +458,7 @@ def over(self, window: "WindowSpec") -> ParentColumn: # type: ignore[override]
459458
return Column(WindowExpression(windowFunction=self._expr, windowSpec=window))
460459

461460
def outer(self) -> ParentColumn:
462-
return Column(LazyExpression(self._expr))
461+
return Column(self._expr)
463462

464463
def isin(self, *cols: Any) -> ParentColumn:
465464
if len(cols) == 1 and isinstance(cols[0], (list, set)):

python/pyspark/sql/connect/expressions.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,25 +1233,6 @@ def __repr__(self) -> str:
12331233
return f"{self._key} => {self._value}"
12341234

12351235

1236-
class LazyExpression(Expression):
1237-
def __init__(self, expr: Expression):
1238-
assert isinstance(expr, Expression)
1239-
super().__init__()
1240-
self._expr = expr
1241-
1242-
def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
1243-
expr = self._create_proto_expression()
1244-
expr.lazy_expression.child.CopyFrom(self._expr.to_plan(session))
1245-
return expr
1246-
1247-
@property
1248-
def children(self) -> Sequence["Expression"]:
1249-
return [self._expr]
1250-
1251-
def __repr__(self) -> str:
1252-
return f"lazy({self._expr})"
1253-
1254-
12551236
class SubqueryExpression(Expression):
12561237
def __init__(
12571238
self,

0 commit comments

Comments
 (0)