Skip to content

Commit 1cd9cca

Browse files
UDF bug fix caused by ThreadStatic BroadcastVariablesRegistry (#551)
1 parent 7bb3dd1 commit 1cd9cca

File tree

2 files changed

+31
-7
lines changed

2 files changed

+31
-7
lines changed

src/csharp/Microsoft.Spark.E2ETest/UdfTests/UdfSimpleTypesTests.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System;
66
using System.Collections.Generic;
77
using System.Linq;
8+
using System.Threading;
89
using Microsoft.Spark.Sql;
910
using Microsoft.Spark.Sql.Types;
1011
using Xunit;
@@ -166,5 +167,29 @@ public void TestUdfWithReturnAsTimestampType()
166167
}
167168
}
168169
}
170+
171+
/// <summary>
172+
/// Test to validate UDFs defined in separate threads work.
173+
/// </summary>
174+
[Fact]
175+
public void TestUdfWithMultipleThreads()
176+
{
177+
try
178+
{
179+
void DefineUdf() => Udf<string, string>(str => str);
180+
181+
// Define a UDF in the main thread.
182+
Udf<string, string>(str => str);
183+
184+
// Verify a UDF can be defined in a separate thread.
185+
Thread t = new Thread(DefineUdf);
186+
t.Start();
187+
t.Join();
188+
}
189+
catch (Exception)
190+
{
191+
Assert.True(false);
192+
}
193+
}
169194
}
170195
}

src/csharp/Microsoft.Spark/Broadcast.cs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.IO;
55
using System.Runtime.Serialization;
66
using System.Runtime.Serialization.Formatters.Binary;
7+
using System.Threading;
78
using Microsoft.Spark.Interop;
89
using Microsoft.Spark.Interop.Ipc;
910
using Microsoft.Spark.Services;
@@ -261,28 +262,26 @@ internal static void Remove(long bid)
261262
/// </summary>
262263
internal static class JvmBroadcastRegistry
263264
{
264-
[ThreadStatic]
265-
private static readonly List<JvmObjectReference> s_jvmBroadcastVariables =
266-
new List<JvmObjectReference>();
265+
private static ThreadLocal<List<JvmObjectReference>> s_jvmBroadcastVariables =
266+
new ThreadLocal<List<JvmObjectReference>>(() => new List<JvmObjectReference>());
267267

268268
/// <summary>
269269
/// Adds a JVMObjectReference object of type <see cref="Broadcast{T}"/> to the list.
270270
/// </summary>
271271
/// <param name="broadcastJvmObject">JVMObjectReference of the Broadcast variable</param>
272272
internal static void Add(JvmObjectReference broadcastJvmObject) =>
273-
s_jvmBroadcastVariables.Add(broadcastJvmObject);
273+
s_jvmBroadcastVariables.Value.Add(broadcastJvmObject);
274274

275275
/// <summary>
276276
/// Clears s_jvmBroadcastVariables of all the JVMObjectReference objects of type
277277
/// <see cref="Broadcast{T}"/>.
278278
/// </summary>
279-
internal static void Clear() => s_jvmBroadcastVariables.Clear();
279+
internal static void Clear() => s_jvmBroadcastVariables.Value.Clear();
280280

281281
/// <summary>
282282
/// Returns the static member s_jvmBroadcastVariables.
283283
/// </summary>
284284
/// <returns>A list of all broadcast objects of type <see cref="JvmObjectReference"/></returns>
285-
internal static List<JvmObjectReference> GetAll() => s_jvmBroadcastVariables;
285+
internal static List<JvmObjectReference> GetAll() => s_jvmBroadcastVariables.Value;
286286
}
287287
}
288-

0 commit comments

Comments
 (0)