Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 53 additions & 2 deletions SteamKit2/SteamKit2/Steam/SteamClient/CallbackMgr/CallbackMgr.cs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public void RunWaitCallbacks()
public async Task RunWaitCallbackAsync( CancellationToken cancellationToken = default )
{
var call = await client.WaitForCallbackAsync( cancellationToken ).ConfigureAwait( false );
Handle( call );
await HandleAsync( call ).ConfigureAwait( false );
}

/// <summary>
Expand Down Expand Up @@ -139,6 +139,36 @@ public IDisposable Subscribe<TCallback>( Action<TCallback> callbackFunc )
return Subscribe( JobID.Invalid, callbackFunc );
}

/// <summary>
/// Registers the provided <see cref="Func{T, Task}"/> to receive callbacks of type <typeparamref name="TCallback" />.
/// </summary>
/// <param name="jobID">The <see cref="JobID"/> of the callbacks that should be subscribed to.
/// If this is <see cref="JobID.Invalid"/>, all callbacks of type <typeparamref name="TCallback" /> will be received.</param>
/// <param name="callbackFunc">The function to invoke with the callback.</param>
/// <typeparam name="TCallback">The type of callback to subscribe to.</typeparam>
/// <remarks>When subscribing to asynchronous methods, <see cref="RunWaitCallbackAsync"/> should be used for awaiting callbacks.</remarks>
/// <returns>An <see cref="IDisposable"/>. Disposing of the return value will unsubscribe the <paramref name="callbackFunc"/>.</returns>
public IDisposable Subscribe<TCallback>( JobID jobID, Func<TCallback, Task> callbackFunc ) where TCallback : CallbackMsg
{
ArgumentNullException.ThrowIfNull( jobID );
ArgumentNullException.ThrowIfNull( callbackFunc );

var callback = new Internal.AsyncCallback<TCallback>( callbackFunc, this, jobID );
return callback;
}

/// <summary>
/// Registers the provided <see cref="Func{T, Task}"/> to receive callbacks of type <typeparam name="TCallback" />.
/// </summary>
/// <param name="callbackFunc">The function to invoke with the callback.</param>
/// <remarks>When subscribing to asynchronous methods, <see cref="RunWaitCallbackAsync"/> should be used for awaiting callbacks.</remarks>
/// <returns>An <see cref="IDisposable"/>. Disposing of the return value will unsubscribe the <paramref name="callbackFunc"/>.</returns>
public IDisposable Subscribe<TCallback>( Func<TCallback, Task> callbackFunc )
where TCallback : CallbackMsg
{
return Subscribe( JobID.Invalid, callbackFunc );
}

/// <summary>
/// Registers the provided <see cref="Action{T}"/> to receive callbacks for notifications from the service of type <typeparam name="TService" />
/// with the notification message of type <typeparam name="TNotification"></typeparam>.
Expand Down Expand Up @@ -191,7 +221,28 @@ void Handle( CallbackMsg call )
{
if ( callback.CallbackType.IsAssignableFrom( type ) )
{
callback.Run( call );
var task = callback.Run( call );
task?.Wait();
}
}
}

async Task HandleAsync( CallbackMsg call )
{
var callbacks = registeredCallbacks;
var type = call.GetType();

// find handlers interested in this callback
foreach ( var callback in callbacks )
{
if ( callback.CallbackType.IsAssignableFrom( type ) )
{
var task = callback.Run( call );

if ( task != null )
{
await task.ConfigureAwait( false );
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@


using System;
using System.Threading.Tasks;

namespace SteamKit2.Internal
{
Expand All @@ -16,22 +17,24 @@ namespace SteamKit2.Internal
abstract class CallbackBase
{
internal abstract Type CallbackType { get; }
internal abstract void Run( CallbackMsg callback );
internal abstract Task? Run( CallbackMsg callback );
}

sealed class Callback<TCall> : CallbackBase, IDisposable
where TCall : CallbackMsg
{
CallbackManager? mgr;

public JobID JobID { get; set; }
public JobID JobID { get; }

public Action<TCall> OnRun { get; set; }
public Action<TCall> OnRun { get; }

internal override Type CallbackType => typeof( TCall );

public Callback( Action<TCall> func, CallbackManager mgr, JobID jobID )
{
ArgumentNullException.ThrowIfNull( func );

this.JobID = jobID;
this.OnRun = func;
this.mgr = mgr;
Expand All @@ -52,13 +55,60 @@ public void Dispose()
System.GC.SuppressFinalize( this );
}

internal override void Run( CallbackMsg callback )
internal override Task? Run( CallbackMsg callback )
{
var cb = callback as TCall;
if ( cb != null && ( cb.JobID == JobID || JobID == JobID.Invalid ) && OnRun != null )
if ( cb != null && ( cb.JobID == JobID || JobID == JobID.Invalid ) )
{
OnRun( cb );
}
return null;
}
}

sealed class AsyncCallback<TCall> : CallbackBase, IDisposable
where TCall : CallbackMsg
{
CallbackManager? mgr;

public JobID JobID { get; }

public Func<TCall, Task> OnRun { get; }

internal override Type CallbackType => typeof( TCall );

public AsyncCallback( Func<TCall, Task> func, CallbackManager mgr, JobID jobID )
{
ArgumentNullException.ThrowIfNull( func );

this.JobID = jobID;
this.OnRun = func;
this.mgr = mgr;

mgr.Register( this );
}

~AsyncCallback()
{
Dispose();
}

public void Dispose()
{
mgr?.Unregister( this );
mgr = null;

System.GC.SuppressFinalize( this );
}

internal override Task? Run( CallbackMsg callback )
{
var cb = callback as TCall;
if ( cb != null && ( cb.JobID == JobID || JobID == JobID.Invalid ) )
{
return OnRun( cb );
}
return null;
}
}
}
34 changes: 33 additions & 1 deletion SteamKit2/Tests/CallbackManagerFacts.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using SteamKit2;
using Xunit;
Expand Down Expand Up @@ -257,7 +258,7 @@ void unsubscribe( CallbackForTest cb )
}

[Fact]
public void CorrectlysubscribesFromInsideOfCallback()
public void CorrectlySubscribesFromInsideOfCallback()
{
static void nothing( CallbackForTest cb )
{
Expand All @@ -275,6 +276,37 @@ void subscribe( CallbackForTest cb )
PostAndRunCallback( new CallbackForTest { UniqueID = Guid.NewGuid() } );
}

[Fact]
public async Task CorrectlyAwaitsForAsyncCallbacks()
{
var callback = new CallbackForTest { UniqueID = Guid.NewGuid() };

var numCallbacksRun = 0;
async Task action( CallbackForTest cb )
{
await Task.Delay( 100, TestContext.Current.CancellationToken );
Assert.Equal( callback.UniqueID, cb.UniqueID );
numCallbacksRun++;
}

using ( mgr.Subscribe<CallbackForTest>( action ) )
{
for ( var i = 0; i < 10; i++ )
{
client.PostCallback( callback );
}

for ( var i = 1; i <= 10; i++ )
{
await mgr.RunWaitCallbackAsync( TestContext.Current.CancellationToken );
Assert.Equal( i, numCallbacksRun );
}

mgr.RunWaitAllCallbacks( TimeSpan.Zero );
Assert.Equal( 10, numCallbacksRun );
}
}

void PostAndRunCallback(CallbackMsg callback)
{
client.PostCallback(callback);
Expand Down
Loading