@@ -73,6 +73,7 @@ internal class GrpcWorkerChannel : IRpcWorkerChannel, IDisposable
73
73
private TaskCompletionSource < List < RawFunctionMetadata > > _functionsIndexingTask = new TaskCompletionSource < List < RawFunctionMetadata > > ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
74
74
private TimeSpan _functionLoadTimeout = TimeSpan . FromMinutes ( 1 ) ;
75
75
private bool _isSharedMemoryDataTransferEnabled ;
76
+ private bool _cancelCapabilityEnabled ;
76
77
77
78
private object _syncLock = new object ( ) ;
78
79
private System . Timers . Timer _timer ;
@@ -275,6 +276,7 @@ internal void WorkerInitResponse(GrpcEvent initEvent)
275
276
_state = _state | RpcWorkerChannelState . Initialized ;
276
277
_workerCapabilities . UpdateCapabilities ( _initMessage . Capabilities ) ;
277
278
_isSharedMemoryDataTransferEnabled = IsSharedMemoryDataTransferEnabled ( ) ;
279
+ _cancelCapabilityEnabled = ! string . IsNullOrEmpty ( _workerCapabilities . GetCapabilityState ( RpcWorkerConstants . HandlesInvocationCancelMessage ) ) ;
278
280
279
281
if ( ! _isSharedMemoryDataTransferEnabled )
280
282
{
@@ -501,36 +503,58 @@ internal async Task SendInvocationRequest(ScriptInvocationContext context)
501
503
_workerChannelLogger . LogDebug ( $ "Function { context . FunctionMetadata . Name } failed to load") ;
502
504
context . ResultSource . TrySetException ( _functionLoadErrors [ context . FunctionMetadata . GetFunctionId ( ) ] ) ;
503
505
_executingInvocations . TryRemove ( context . ExecutionContext . InvocationId . ToString ( ) , out ScriptInvocationContext _ ) ;
506
+ return ;
504
507
}
505
508
else if ( _metadataRequestErrors . ContainsKey ( context . FunctionMetadata . GetFunctionId ( ) ) )
506
509
{
507
510
_workerChannelLogger . LogDebug ( $ "Worker failed to load metadata for { context . FunctionMetadata . Name } ") ;
508
511
context . ResultSource . TrySetException ( _metadataRequestErrors [ context . FunctionMetadata . GetFunctionId ( ) ] ) ;
509
512
_executingInvocations . TryRemove ( context . ExecutionContext . InvocationId . ToString ( ) , out ScriptInvocationContext _ ) ;
513
+ return ;
510
514
}
511
- else
515
+
516
+ if ( context . CancellationToken . IsCancellationRequested )
512
517
{
513
- if ( context . CancellationToken . IsCancellationRequested )
514
- {
515
- context . ResultSource . SetCanceled ( ) ;
516
- return ;
517
- }
518
- var invocationRequest = await context . ToRpcInvocationRequest ( _workerChannelLogger , _workerCapabilities , _isSharedMemoryDataTransferEnabled , _sharedMemoryManager ) ;
519
- AddAdditionalTraceContext ( invocationRequest . TraceContext . Attributes , context ) ;
520
- _executingInvocations . TryAdd ( invocationRequest . InvocationId , context ) ;
518
+ _workerChannelLogger . LogDebug ( "Cancellation has been requested, cancelling invocation request" ) ;
519
+ context . ResultSource . SetCanceled ( ) ;
520
+ return ;
521
+ }
521
522
522
- SendStreamingMessage ( new StreamingMessage
523
- {
524
- InvocationRequest = invocationRequest
525
- } ) ;
523
+ var invocationRequest = await context . ToRpcInvocationRequest ( _workerChannelLogger , _workerCapabilities , _isSharedMemoryDataTransferEnabled , _sharedMemoryManager ) ;
524
+ AddAdditionalTraceContext ( invocationRequest . TraceContext . Attributes , context ) ;
525
+ _executingInvocations . TryAdd ( invocationRequest . InvocationId , context ) ;
526
+
527
+ if ( _cancelCapabilityEnabled )
528
+ {
529
+ context . CancellationToken . Register ( ( ) => SendInvocationCancel ( invocationRequest . InvocationId ) ) ;
526
530
}
531
+
532
+ SendStreamingMessage ( new StreamingMessage
533
+ {
534
+ InvocationRequest = invocationRequest
535
+ } ) ;
527
536
}
528
537
catch ( Exception invokeEx )
529
538
{
530
539
context . ResultSource . TrySetException ( invokeEx ) ;
531
540
}
532
541
}
533
542
543
+ internal void SendInvocationCancel ( string invocationId )
544
+ {
545
+ _workerChannelLogger . LogDebug ( $ "Sending invocation cancel request for InvocationId { invocationId } ") ;
546
+
547
+ var invocationCancel = new InvocationCancel
548
+ {
549
+ InvocationId = invocationId
550
+ } ;
551
+
552
+ SendStreamingMessage ( new StreamingMessage
553
+ {
554
+ InvocationCancel = invocationCancel
555
+ } ) ;
556
+ }
557
+
534
558
// gets metadata from worker
535
559
public Task < List < RawFunctionMetadata > > GetFunctionMetadata ( )
536
560
{
0 commit comments