
Commit 33299cf

Bug Fix for Spark 3.x - Avoid converting converted Row values (#868)

1 parent b9283eb
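Why the fix matters: the unpickler can memoize duplicate values, so two logical rows may end up referencing the same unpickled object. The previous RowConstructor.GetRow() converted captured values in place, so a value reachable from more than one row could be handed to the conversion a second time. Below is a minimal, hypothetical sketch of that hazard — the timestamp string and the aliasing setup are assumptions for illustration, not code from this commit:

using System;

static class ConvertedTwiceSketch
{
    static void Main()
    {
        // Assumption for illustration: a memoizing unpickler hands back the
        // SAME args array for two rows that carried identical values.
        object[] shared = { "1970-01-02 00:00:00" };
        object[][] rows = { shared, shared };

        foreach (object[] args in rows)
        {
            try
            {
                // In-place conversion mutates the shared array. The second pass
                // finds an already-converted DateTime where it expects a string.
                args[0] = DateTime.Parse((string)args[0]);
            }
            catch (InvalidCastException)
            {
                Console.WriteLine("value was already converted; converting it again fails");
            }
        }
    }
}

The commit removes the deferred, mutating conversion entirely: rows are now built exactly once, at unpickling time (see the RowConstructor.cs changes below).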

File tree: 11 files changed, +196 −104 lines

azure-pipelines.yml

Lines changed: 6 additions & 2 deletions

@@ -33,9 +33,12 @@ variables:
   # Filter DataFrameTests.TestDataFrameGroupedMapUdf and DataFrameTests.TestGroupedMapUdf backwardCompatible
   # tests due to https://github.com/dotnet/spark/pull/711
+  # Filter UdfSimpleTypesTests.TestUdfWithDuplicateTimestamps 3.x backwardCompatible test due to bug related
+  # to duplicate types that NeedConversion. Bug fixed in https://github.com/dotnet/spark/pull/868
   backwardCompatibleTestOptions_Windows_3_0: "--filter \
   (FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.DataFrameTests.TestDataFrameGroupedMapUdf)&\
-  (FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.DataFrameTests.TestGroupedMapUdf)"
+  (FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.DataFrameTests.TestGroupedMapUdf)&\
+  (FullyQualifiedName!=Microsoft.Spark.E2ETest.UdfTests.UdfSimpleTypesTests.TestUdfWithDuplicateTimestamps)"
   forwardCompatibleTestOptions_Windows_3_0: ""
   backwardCompatibleTestOptions_Linux_3_0: $(backwardCompatibleTestOptions_Windows_3_0)
   forwardCompatibleTestOptions_Linux_3_0: $(forwardCompatibleTestOptions_Linux_2_4)
@@ -85,7 +88,8 @@ variables:
   (FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.DataFrameTests.TestUDF)&\
   (FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.SparkSessionExtensionsTests.TestVersion)&\
   (FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.DataStreamWriterTests.TestForeachBatch)&\
-  (FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.DataStreamWriterTests.TestForeach)"
+  (FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.DataStreamWriterTests.TestForeach)&\
+  (FullyQualifiedName!=Microsoft.Spark.E2ETest.UdfTests.UdfSimpleTypesTests.TestUdfWithDuplicateTimestamps)"
   # Skip all forwardCompatible tests since microsoft-spark-3-1 jar does not get built when
   # building forwardCompatible repo.
   forwardCompatibleTestOptions_Windows_3_1: "--filter FullyQualifiedName=NONE"
src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/RowTests.cs

Lines changed: 112 additions & 0 deletions

@@ -0,0 +1,112 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Linq;
+using Microsoft.Spark.Sql;
+using Microsoft.Spark.Sql.Types;
+using Xunit;
+using static Microsoft.Spark.Sql.Functions;
+
+namespace Microsoft.Spark.E2ETest.IpcTests
+{
+    [Collection("Spark E2E Tests")]
+    public class RowTests
+    {
+        private readonly SparkSession _spark;
+
+        public RowTests(SparkFixture fixture)
+        {
+            _spark = fixture.Spark;
+        }
+
+        [Fact]
+        public void TestWithDuplicatedRows()
+        {
+            var timestamp = new Timestamp(2020, 1, 1, 0, 0, 0, 0);
+            var schema = new StructType(new StructField[]
+            {
+                new StructField("ts", new TimestampType())
+            });
+            var data = new GenericRow[]
+            {
+                new GenericRow(new object[] { timestamp })
+            };
+
+            DataFrame df = _spark.CreateDataFrame(data, schema);
+            Row[] rows = df
+                .WithColumn("tsRow", Struct("ts"))
+                .WithColumn("tsRowRow", Struct("tsRow"))
+                .Collect()
+                .ToArray();
+
+            Assert.Single(rows);
+
+            Row row = rows[0];
+            Assert.Equal(3, row.Values.Length);
+            Assert.Equal(timestamp, row.Values[0]);
+
+            Row tsRow = row.Values[1] as Row;
+            Assert.Single(tsRow.Values);
+            Assert.Equal(timestamp, tsRow.Values[0]);
+
+            Row tsRowRow = row.Values[2] as Row;
+            Assert.Single(tsRowRow.Values);
+            Assert.Equal(tsRowRow.Values[0], tsRow);
+        }
+
+        [Fact]
+        public void TestWithDuplicateTimestamps()
+        {
+            var timestamp = new Timestamp(2020, 1, 1, 0, 0, 0, 0);
+            var schema = new StructType(new StructField[]
+            {
+                new StructField("ts", new TimestampType())
+            });
+            var data = new GenericRow[]
+            {
+                new GenericRow(new object[] { timestamp }),
+                new GenericRow(new object[] { timestamp }),
+                new GenericRow(new object[] { timestamp })
+            };
+
+            DataFrame df = _spark.CreateDataFrame(data, schema);
+            Row[] rows = df.Collect().ToArray();
+
+            Assert.Equal(3, rows.Length);
+            foreach (Row row in rows)
+            {
+                Assert.Single(row.Values);
+                Assert.Equal(timestamp, row.GetAs<Timestamp>(0));
+            }
+        }
+
+        [Fact]
+        public void TestWithDuplicateDates()
+        {
+            var date = new Date(2020, 1, 1);
+            var schema = new StructType(new StructField[]
+            {
+                new StructField("date", new DateType())
+            });
+            var data = new GenericRow[]
+            {
+                new GenericRow(new object[] { date }),
+                new GenericRow(new object[] { date }),
+                new GenericRow(new object[] { date })
+            };
+
+            DataFrame df = _spark.CreateDataFrame(data, schema);
+
+            Row[] rows = df.Collect().ToArray();
+
+            Assert.Equal(3, rows.Length);
+            foreach (Row row in rows)
+            {
+                Assert.Single(row.Values);
+                Assert.Equal(date, row.GetAs<Date>(0));
+            }
+        }
+    }
+}

src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/Streaming/DataStreamWriterTests.cs

Lines changed: 0 additions & 1 deletion

@@ -6,7 +6,6 @@
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
-using System.Threading;
 using Microsoft.Spark.E2ETest.Utils;
 using Microsoft.Spark.Sql;
 using Microsoft.Spark.Sql.Streaming;

src/csharp/Microsoft.Spark.E2ETest/UdfTests/UdfSimpleTypesTests.cs

Lines changed: 34 additions & 0 deletions

@@ -125,6 +125,40 @@ public void TestUdfWithTimestampType()
             Assert.Equal(expected, rowsToArray);
         }
 
+        /// <summary>
+        /// UDF that takes in a Timestamp and returns a fixed Timestamp, run over
+        /// rows containing duplicate timestamp values.
+        /// </summary>
+        [Fact]
+        public void TestUdfWithDuplicateTimestamps()
+        {
+            var timestamp = new Timestamp(2020, 1, 1, 0, 0, 0, 0);
+            var schema = new StructType(new StructField[]
+            {
+                new StructField("ts", new TimestampType())
+            });
+            var data = new GenericRow[]
+            {
+                new GenericRow(new object[] { timestamp }),
+                new GenericRow(new object[] { timestamp }),
+                new GenericRow(new object[] { timestamp })
+            };
+
+            var expectedTimestamp = new Timestamp(1970, 1, 2, 0, 0, 0, 0);
+            Func<Column, Column> udf = Udf<Timestamp, Timestamp>(
+                ts => new Timestamp(1970, 1, 2, 0, 0, 0, 0));
+
+            DataFrame df = _spark.CreateDataFrame(data, schema);
+
+            Row[] rows = df.Select(udf(df["ts"])).Collect().ToArray();
+
+            Assert.Equal(3, rows.Length);
+            foreach (Row row in rows)
+            {
+                Assert.Single(row.Values);
+                Assert.Equal(expectedTimestamp, row.Values[0]);
+            }
+        }
+
         /// <summary>
         /// UDF that returns Timestamp type.
         /// </summary>

src/csharp/Microsoft.Spark.UnitTest/Sql/RowTests.cs

Lines changed: 2 additions & 2 deletions

@@ -91,8 +91,8 @@ public void RowConstructorTest()
                 pickledBytes.Length);
 
             Assert.Equal(2, unpickledData.Length);
-            Assert.Equal(row1, (unpickledData[0] as RowConstructor).GetRow());
-            Assert.Equal(row2, (unpickledData[1] as RowConstructor).GetRow());
+            Assert.Equal(row1, unpickledData[0]);
+            Assert.Equal(row2, unpickledData[1]);
         }
 
         [Fact]

src/csharp/Microsoft.Spark.Worker/Command/SqlCommandExecutor.cs

Lines changed: 2 additions & 2 deletions

@@ -143,9 +143,9 @@ protected internal override CommandExecutorStat ExecuteCore(
                 // The following can happen if an UDF takes Row object(s).
                 // The JVM Spark side sends a Row object that wraps all the columns used
                 // in the UDF, thus, it is normalized below (the extra layer is removed).
-                if (row is RowConstructor rowConstructor)
+                if (row is Row r)
                 {
-                    row = rowConstructor.GetRow().Values;
+                    row = r.Values;
                 }
 
                 // Split id is not used for SQL UDFs, so 0 is passed.
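Since unpickled objects are now materialized as Row instances, the unwrap step above reduces to a simple type test. A self-contained sketch of the same normalization pattern — the Row type here is a local stand-in for illustration, not Microsoft.Spark.Sql.Row:

using System;

static class UnwrapSketch
{
    // Local stand-in for Microsoft.Spark.Sql.Row, for illustration only.
    sealed class Row
    {
        public Row(params object[] values) => Values = values;
        public object[] Values { get; }
    }

    // A UDF over Row arguments receives one wrapper Row per input row;
    // stripping the wrapper exposes the raw column values.
    static object Normalize(object input) => input is Row r ? r.Values : input;

    static void Main()
    {
        object wrapped = new Row(1, "a");
        var columns = (object[])Normalize(wrapped);
        Console.WriteLine(columns.Length); // prints 2
    }
}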

src/csharp/Microsoft.Spark/RDD/Collector.cs

Lines changed: 1 addition & 4 deletions

@@ -98,10 +98,7 @@ public object Deserialize(Stream stream, int length)
         {
             // Refer to the AutoBatchedPickler class in spark/core/src/main/scala/org/apache/
             // spark/api/python/SerDeUtil.scala regarding how the Rows may be batched.
-            return PythonSerDe.GetUnpickledObjects(stream, length)
-                .Cast<RowConstructor>()
-                .Select(rc => rc.GetRow())
-                .ToArray();
+            return PythonSerDe.GetUnpickledObjects(stream, length).Cast<Row>().ToArray();
         }
     }
 }

src/csharp/Microsoft.Spark/Sql/RowCollector.cs

Lines changed: 1 addition & 1 deletion

@@ -35,7 +35,7 @@ public IEnumerable<Row> Collect(ISocketWrapper socket)
 
             foreach (object unpickled in unpickledObjects)
            {
-                yield return (unpickled as RowConstructor).GetRow();
+                yield return unpickled as Row;
             }
         }
     }

src/csharp/Microsoft.Spark/Sql/RowConstructor.cs

Lines changed: 37 additions & 79 deletions

@@ -22,94 +22,39 @@ internal sealed class RowConstructor : IObjectConstructor
         /// sent per batch if there are nested rows contained in the row. Note that
         /// this is thread local variable because one RowConstructor object is
         /// registered to the Unpickler and there could be multiple threads unpickling
-        /// the data using the same object registered.
+        /// the data using the same registered object.
         /// </summary>
         [ThreadStatic]
         private static IDictionary<string, StructType> s_schemaCache;
 
         /// <summary>
-        /// The RowConstructor that created this instance.
+        /// Used by Unpickler to pass unpickled schema for handling. The Unpickler
+        /// will reuse the <see cref="RowConstructor"/> object when
+        /// it needs to start constructing a <see cref="Row"/>. The schema is passed
+        /// to <see cref="construct(object[])"/> and the returned
+        /// <see cref="IObjectConstructor"/> is used to build the rest of the <see cref="Row"/>.
         /// </summary>
-        private readonly RowConstructor _parent;
-
-        /// <summary>
-        /// Stores the args passed from construct().
-        /// </summary>
-        private readonly object[] _args;
-
-        public RowConstructor() : this(null, null)
-        {
-        }
-
-        public RowConstructor(RowConstructor parent, object[] args)
-        {
-            _parent = parent;
-            _args = args;
-        }
-
-        /// <summary>
-        /// Used by Unpickler to pass unpickled data for handling.
-        /// </summary>
-        /// <param name="args">Unpickled data</param>
-        /// <returns>New RowConstructor object capturing args data</returns>
+        /// <param name="args">Unpickled schema</param>
+        /// <returns>
+        /// New <see cref="RowWithSchemaConstructor"/>object capturing the schema.
+        /// </returns>
         public object construct(object[] args)
         {
-            // Every first call to construct() contains the schema data. When
-            // a new RowConstructor object is returned from this function,
-            // construct() is called on the returned object with the actual
-            // row data. The original RowConstructor object may be reused by the
-            // Unpickler and each subsequent construct() call can contain the
-            // schema data or a RowConstructor object that contains row data.
             if (s_schemaCache is null)
             {
                 s_schemaCache = new Dictionary<string, StructType>();
             }
 
-            // Return a new RowConstructor where the args either represent the
-            // schema or the row data. The parent becomes important when calling
-            // GetRow() on the RowConstructor containing the row data.
-            //
-            // - When args is the schema, return a new RowConstructor where the
-            //   parent is set to the calling RowConstructor object.
-            //
-            // - In the case where args is the row data, construct() is called on a
-            //   RowConstructor object that contains the schema for the row data. A
-            //   new RowConstructor is returned where the parent is set to the schema
-            //   containing RowConstructor.
-            return new RowConstructor(this, args);
-        }
-
-        /// <summary>
-        /// Construct a Row object from unpickled data. This is only to be called
-        /// on a RowConstructor that contains the row data.
-        /// </summary>
-        /// <returns>A row object with unpickled data</returns>
-        public Row GetRow()
-        {
-            Debug.Assert(_parent != null);
-
-            // It is possible that an entry of a Row (row1) may itself be a Row (row2).
-            // If the entry is a RowConstructor then it will be a RowConstructor
-            // which contains the data for row2. Therefore we will call GetRow()
-            // on the RowConstructor to materialize row2 and replace the RowConstructor
-            // entry in row1.
-            for (int i = 0; i < _args.Length; ++i)
-            {
-                if (_args[i] is RowConstructor rowConstructor)
-                {
-                    _args[i] = rowConstructor.GetRow();
-                }
-            }
-
-            return new Row(_args, _parent.GetSchema());
+            Debug.Assert((args != null) && (args.Length == 1) && (args[0] is string));
+            return new RowWithSchemaConstructor(GetSchema(s_schemaCache, (string)args[0]));
         }
 
         /// <summary>
         /// Clears the schema cache. Spark sends rows in batches and for each
         /// row there is an accompanying set of schemas and row entries. If the
         /// schema was not cached, then it would need to be parsed and converted
         /// to a StructType for every row in the batch. A new batch may contain
-        /// rows from a different table, so calling <c>Reset</c> after each
+        /// rows from a different table, so calling <see cref="Reset"/> after each
         /// batch would aid in preventing the cache from growing too large.
         /// Caching the schemas for each batch, ensures that each schema is
         /// only parsed and converted to a StructType once per batch.
@@ -119,23 +64,36 @@ internal void Reset()
             s_schemaCache?.Clear();
         }
 
-        /// <summary>
-        /// Get or cache the schema string contained in args. Calling this
-        /// is only valid if the child args contain the row values.
-        /// </summary>
-        /// <returns></returns>
-        private StructType GetSchema()
+        private static StructType GetSchema(IDictionary<string, StructType> schemaCache, string schemaString)
         {
-            Debug.Assert(s_schemaCache != null);
-            Debug.Assert((_args != null) && (_args.Length == 1) && (_args[0] is string));
-            var schemaString = (string)_args[0];
-            if (!s_schemaCache.TryGetValue(schemaString, out StructType schema))
+            if (!schemaCache.TryGetValue(schemaString, out StructType schema))
             {
                 schema = (StructType)DataType.ParseDataType(schemaString);
-                s_schemaCache.Add(schemaString, schema);
+                schemaCache.Add(schemaString, schema);
             }
 
             return schema;
         }
     }
+
+    /// <summary>
+    /// Created from <see cref="RowConstructor"/> and subsequently used
+    /// by the Unpickler to construct a <see cref="Row"/>.
+    /// </summary>
+    internal sealed class RowWithSchemaConstructor : IObjectConstructor
+    {
+        private readonly StructType _schema;
+
+        internal RowWithSchemaConstructor(StructType schema)
+        {
+            _schema = schema;
+        }
+
+        /// <summary>
+        /// Used by Unpickler to pass unpickled row values for handling.
+        /// </summary>
+        /// <param name="args">Unpickled row values.</param>
+        /// <returns>Row object.</returns>
+        public object construct(object[] args) => new Row(args, _schema);
+    }
 }
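The reworked flow splits construction into two single-purpose constructors: RowConstructor parses and caches the schema, and the RowWithSchemaConstructor it returns builds each Row directly from its values, with no deferred mutation. A self-contained analogue of that hand-off — the interface is redeclared locally and the schema "parsing" is a stand-in for DataType.ParseDataType; none of this is library code:

using System;
using System.Collections.Generic;

// Locally redeclared for the sketch; the real one comes from the Pickle library.
interface IObjectConstructor { object construct(object[] args); }

// Stage 1: receives the schema string, parses it at most once per batch,
// and returns a constructor bound to that schema.
sealed class SchemaConstructor : IObjectConstructor
{
    private readonly Dictionary<string, string[]> _schemaCache =
        new Dictionary<string, string[]>();

    public object construct(object[] args)
    {
        var schemaString = (string)args[0];
        if (!_schemaCache.TryGetValue(schemaString, out string[] fields))
        {
            fields = schemaString.Split(','); // stand-in for DataType.ParseDataType
            _schemaCache.Add(schemaString, fields);
        }
        return new BoundRowConstructor(fields);
    }
}

// Stage 2: builds a finished row from raw values; nothing is mutated later,
// so values aliased across rows can never be converted twice.
sealed class BoundRowConstructor : IObjectConstructor
{
    private readonly string[] _fields;

    public BoundRowConstructor(string[] fields) => _fields = fields;

    public object construct(object[] args) =>
        $"[{string.Join(", ", _fields)}] = [{string.Join(", ", args)}]";
}

static class TwoStageDemo
{
    static void Main()
    {
        IObjectConstructor root = new SchemaConstructor();
        var bound = (IObjectConstructor)root.construct(new object[] { "ts,name" });
        Console.WriteLine(bound.construct(new object[] { "2020-01-01", "a" }));
    }
}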
