Skip to content

Commit 4dbfb34

Browse files
authored
Adding DaemonWorkerTests (#379)
1 parent 0db249b commit 4dbfb34

File tree

5 files changed

+164
-64
lines changed

5 files changed

+164
-64
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections;
7+
using System.Collections.Generic;
8+
using System.IO;
9+
using System.Net;
10+
using System.Threading.Tasks;
11+
using Microsoft.Spark.Interop.Ipc;
12+
using Microsoft.Spark.Network;
13+
using Razorvine.Pickle;
14+
using Xunit;
15+
16+
namespace Microsoft.Spark.Worker.UnitTest
{
    /// <summary>
    /// Tests for DaemonWorker: verifies that one task runner is created
    /// per client connection and that rows round-trip correctly.
    /// </summary>
    public class DaemonWorkerTests
    {
        [Fact]
        public void TestsDaemonWorkerTaskRunners()
        {
            const int expectedTaskRunners = 3;

            // Start the daemon worker on a fresh socket in the background.
            ISocketWrapper daemonSocket = SocketFactory.CreateSocket();
            var daemonWorker = new DaemonWorker(new Version(Versions.V2_4_0));
            Task.Run(() => daemonWorker.Run(daemonSocket));

            // Each client connection should cause the daemon to spawn one task runner.
            for (int connection = 0; connection < expectedTaskRunners; ++connection)
            {
                CreateAndVerifyConnection(daemonSocket);
            }

            Assert.Equal(expectedTaskRunners, daemonWorker.CurrentNumTaskRunners);
        }

        /// <summary>
        /// Connects a client to the daemon, writes default test payloads, and
        /// validates the rows the daemon sends back.
        /// </summary>
        /// <param name="daemonSocket">Socket the daemon is listening on.</param>
        private static void CreateAndVerifyConnection(ISocketWrapper daemonSocket)
        {
            var serverEndpoint = (IPEndPoint)daemonSocket.LocalEndPoint;
            ISocketWrapper clientSocket = SocketFactory.CreateSocket();
            clientSocket.Connect(serverEndpoint.Address, serverEndpoint.Port);

            // Now process the bytes flowing in from the client.
            PayloadWriter payloadWriter = new PayloadWriterFactory().Create();
            payloadWriter.WriteTestData(clientSocket.OutputStream);
            List<object[]> rowsReceived = PayloadReader.Read(clientSocket.InputStream);

            // Validate rows received.
            Assert.Equal(10, rowsReceived.Count);

            for (int i = 0; i < 10; ++i)
            {
                // Two UDFs registered, thus expecting two columns.
                // Refer to TestData.GetDefaultCommandPayload().
                object[] row = rowsReceived[i];
                Assert.Equal(2, row.Length);
                Assert.Equal($"udf2 udf1 {i}", row[0]);
                Assert.Equal(i + i, row[1]);
            }
        }
    }
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
using System.Collections;
2+
using System.Collections.Generic;
3+
using System.IO;
4+
using Microsoft.Spark.Interop.Ipc;
5+
using Razorvine.Pickle;
6+
using Xunit;
7+
8+
namespace Microsoft.Spark.Worker.UnitTest
{
    /// <summary>
    /// Payload reader that reads the output of the inputStream of the socket response,
    /// asserting that timing data was received and that no Python exception was reported.
    /// </summary>
    internal sealed class PayloadReader
    {
        /// <summary>
        /// Reads pickled row batches from the stream until an end-of-data-section
        /// marker (or a reported exception) terminates the protocol.
        /// </summary>
        /// <param name="inputStream">Stream carrying the worker's response.</param>
        /// <returns>The unpickled rows, in the order received.</returns>
        public static List<object[]> Read(Stream inputStream)
        {
            bool timingDataReceived = false;
            bool exceptionThrown = false;
            var rowsReceived = new List<object[]>();

            // Reuse (and dispose) a single unpickler rather than allocating a new
            // one for every data section read from the stream.
            using var unpickler = new Unpickler();

            while (true)
            {
                int length = SerDe.ReadInt32(inputStream);
                if (length > 0)
                {
                    byte[] pickledBytes = SerDe.ReadBytes(inputStream, length);
                    var rows = unpickler.loads(pickledBytes) as ArrayList;
                    foreach (object row in rows)
                    {
                        rowsReceived.Add((object[])row);
                    }
                }
                else if (length == (int)SpecialLengths.TIMING_DATA)
                {
                    // Consume the five timing/metric values; their contents are
                    // not validated by this reader.
                    SerDe.ReadInt64(inputStream); // boot time
                    SerDe.ReadInt64(inputStream); // init time
                    SerDe.ReadInt64(inputStream); // finish time
                    SerDe.ReadInt64(inputStream); // memory bytes spilled
                    SerDe.ReadInt64(inputStream); // disk bytes spilled
                    timingDataReceived = true;
                }
                else if (length == (int)SpecialLengths.PYTHON_EXCEPTION_THROWN)
                {
                    SerDe.ReadString(inputStream);
                    exceptionThrown = true;
                    break;
                }
                else if (length == (int)SpecialLengths.END_OF_DATA_SECTION)
                {
                    // Consume the accumulator-update count and the trailing marker.
                    SerDe.ReadInt32(inputStream);
                    SerDe.ReadInt32(inputStream);
                    break;
                }
            }

            Assert.True(timingDataReceived);
            Assert.False(exceptionThrown);

            return rowsReceived;
        }
    }
}

src/csharp/Microsoft.Spark.Worker.UnitTest/PayloadWriter.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using System.Linq;
1010
using Microsoft.Spark.Interop.Ipc;
1111
using Microsoft.Spark.Utils;
12+
using Razorvine.Pickle;
1213
using static Microsoft.Spark.Utils.UdfUtils;
1314

1415
namespace Microsoft.Spark.Worker.UnitTest
@@ -289,6 +290,29 @@ internal void Write(
289290
_commandWriter.Write(stream, commandPayload);
290291
}
291292

293+
/// <summary>
/// Writes the default payload and command payload followed by 10 pickled
/// test rows, terminated by the end-of-data and end-of-stream markers.
/// </summary>
/// <param name="stream">Destination stream to write the test data to.</param>
public void WriteTestData(Stream stream)
{
    Write(stream, TestData.GetDefaultPayload(), TestData.GetDefaultCommandPayload());

    // Write 10 rows to the output stream.
    var pickler = new Pickler();
    for (int rowIndex = 0; rowIndex < 10; ++rowIndex)
    {
        object[] row = { rowIndex.ToString(), rowIndex, rowIndex };
        byte[] pickled = pickler.dumps(new[] { row });
        SerDe.Write(stream, pickled.Length);
        SerDe.Write(stream, pickled);
    }

    // Signal the end of data and stream.
    SerDe.Write(stream, (int)SpecialLengths.END_OF_DATA_SECTION);
    SerDe.Write(stream, (int)SpecialLengths.END_OF_STREAM);
    stream.Flush();
}
315+
292316
private static void Write(Stream stream, IEnumerable<string> includeItems)
293317
{
294318
if (includeItems is null)

src/csharp/Microsoft.Spark.Worker.UnitTest/TaskRunnerTests.cs

Lines changed: 2 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -34,69 +34,9 @@ public void TestTaskRunner()
3434
System.IO.Stream inputStream = serverSocket.InputStream;
3535
System.IO.Stream outputStream = serverSocket.OutputStream;
3636

37-
Payload payload = TestData.GetDefaultPayload();
38-
CommandPayload commandPayload = TestData.GetDefaultCommandPayload();
39-
40-
payloadWriter.Write(outputStream, payload, commandPayload);
41-
42-
// Write 10 rows to the output stream.
43-
var pickler = new Pickler();
44-
for (int i = 0; i < 10; ++i)
45-
{
46-
byte[] pickled = pickler.dumps(
47-
new[] { new object[] { i.ToString(), i, i } });
48-
SerDe.Write(outputStream, pickled.Length);
49-
SerDe.Write(outputStream, pickled);
50-
}
51-
52-
// Signal the end of data and stream.
53-
SerDe.Write(outputStream, (int)SpecialLengths.END_OF_DATA_SECTION);
54-
SerDe.Write(outputStream, (int)SpecialLengths.END_OF_STREAM);
55-
outputStream.Flush();
56-
37+
payloadWriter.WriteTestData(outputStream);
5738
// Now process the bytes flowing in from the client.
58-
bool timingDataReceived = false;
59-
bool exceptionThrown = false;
60-
var rowsReceived = new List<object[]>();
61-
62-
while (true)
63-
{
64-
int length = SerDe.ReadInt32(inputStream);
65-
if (length > 0)
66-
{
67-
byte[] pickledBytes = SerDe.ReadBytes(inputStream, length);
68-
using var unpickler = new Unpickler();
69-
var rows = unpickler.loads(pickledBytes) as ArrayList;
70-
foreach (object row in rows)
71-
{
72-
rowsReceived.Add((object[])row);
73-
}
74-
}
75-
else if (length == (int)SpecialLengths.TIMING_DATA)
76-
{
77-
long bootTime = SerDe.ReadInt64(inputStream);
78-
long initTime = SerDe.ReadInt64(inputStream);
79-
long finishTime = SerDe.ReadInt64(inputStream);
80-
long memoryBytesSpilled = SerDe.ReadInt64(inputStream);
81-
long diskBytesSpilled = SerDe.ReadInt64(inputStream);
82-
timingDataReceived = true;
83-
}
84-
else if (length == (int)SpecialLengths.PYTHON_EXCEPTION_THROWN)
85-
{
86-
SerDe.ReadString(inputStream);
87-
exceptionThrown = true;
88-
break;
89-
}
90-
else if (length == (int)SpecialLengths.END_OF_DATA_SECTION)
91-
{
92-
int numAccumulatorUpdates = SerDe.ReadInt32(inputStream);
93-
SerDe.ReadInt32(inputStream);
94-
break;
95-
}
96-
}
97-
98-
Assert.True(timingDataReceived);
99-
Assert.False(exceptionThrown);
39+
var rowsReceived = PayloadReader.Read(inputStream);
10040

10141
// Validate rows received.
10242
Assert.Equal(10, rowsReceived.Count);

src/csharp/Microsoft.Spark.Worker/DaemonWorker.cs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ internal DaemonWorker(Version version)
5151
_version = version;
5252
}
5353

54+
internal int CurrentNumTaskRunners => _taskRunners.Count();
55+
5456
/// <summary>
5557
/// Runs the DaemonWorker server.
5658
/// </summary>
@@ -62,10 +64,14 @@ internal void Run()
6264
// AppDomain.CurrentDomain.ProcessExit += (s, e) => {};,
6365
// but the above handler is not invoked. This can be investigated if more
6466
// graceful exit is required.
67+
ISocketWrapper listener = SocketFactory.CreateSocket();
68+
Run(listener);
69+
}
6570

71+
internal void Run(ISocketWrapper listener)
72+
{
6673
try
6774
{
68-
ISocketWrapper listener = SocketFactory.CreateSocket();
6975
listener.Listen();
7076

7177
// Communicate the server port back to the Spark using standard output.
@@ -146,7 +152,8 @@ private void StartServer(ISocketWrapper listener)
146152
// When reuseWorker is set to true, numTaskRunners will be always one
147153
// greater than numWorkerThreads since TaskRunner.Run() does not return
148154
// so that the task runner object is not removed from _taskRunners.
149-
int numTaskRunners = _taskRunners.Count();
155+
int numTaskRunners = CurrentNumTaskRunners;
156+
150157
while (numWorkerThreads < numTaskRunners)
151158
{
152159
// Note that in the current implementation of RunWorkerThread() does

0 commit comments

Comments (0)