33using System . Reflection ;
44using System . Runtime . CompilerServices ;
55using System . Runtime . ExceptionServices ;
6+ using AngleSharp . Dom ;
67using Microsoft . Extensions . Logging ;
78
89namespace Bunit . Rendering ;
@@ -26,6 +27,7 @@ public sealed class BunitRenderer : Renderer
2627
2728 private readonly HashSet < int > returnedRenderedComponentIds = new ( ) ;
2829 private readonly List < BunitRootComponent > rootComponents = new ( ) ;
30+ private readonly Dictionary < string , int > elementReferenceToComponentId = new ( ) ;
2931 private readonly ILogger < BunitRenderer > logger ;
3032 private bool disposed ;
3133 private TaskCompletionSource < Exception > unhandledExceptionTsc = new ( TaskCreationOptions . RunContinuationsAsynchronously ) ;
@@ -453,6 +455,7 @@ protected override Task UpdateDisplayAsync(in RenderBatch renderBatch)
453455 var id = renderBatch . DisposedComponentIDs . Array [ i ] ;
454456 disposedComponentIds . Add ( id ) ;
455457 returnedRenderedComponentIds . Remove ( id ) ;
458+ RemoveElementReferencesForComponent ( id ) ;
456459 }
457460
458461 for ( var i = 0 ; i < renderBatch . UpdatedComponents . Count ; i ++ )
@@ -467,6 +470,8 @@ protected override Task UpdateDisplayAsync(in RenderBatch renderBatch)
467470 var componentState = GetComponentState ( diff . ComponentId ) ;
468471 var renderedComponent = ( IRenderedComponent ) componentState ;
469472
473+ TrackElementReferencesForComponent ( diff . ComponentId ) ;
474+
470475 if ( returnedRenderedComponentIds . Contains ( diff . ComponentId ) )
471476 {
472477 renderedComponent . UpdateState ( hasRendered : true , isMarkupGenerationRequired : diff . Edits . Count > 0 ) ;
@@ -519,6 +524,101 @@ static bool IsParentComponentAlreadyUpdated(int componentId, in RenderBatch rend
519524 }
520525 }
521526
527+ private void TrackElementReferencesForComponent ( int componentId )
528+ {
529+ var frames = GetCurrentRenderTreeFrames ( componentId ) ;
530+ TrackElementReferencesInFrames ( frames , componentId ) ;
531+ }
532+
533+ private void TrackElementReferencesInFrames ( ArrayRange < RenderTreeFrame > frames , int componentId )
534+ {
535+ for ( var i = 0 ; i < frames . Count ; i ++ )
536+ {
537+ ref var frame = ref frames . Array [ i ] ;
538+
539+ if ( frame . FrameType == RenderTreeFrameType . ElementReferenceCapture )
540+ {
541+ var elementReferenceId = frame . ElementReferenceCaptureId ;
542+ if ( elementReferenceId != null )
543+ {
544+ elementReferenceToComponentId [ elementReferenceId ] = componentId ;
545+ }
546+ }
547+ else if ( frame . FrameType == RenderTreeFrameType . Component )
548+ {
549+ TrackElementReferencesForComponent ( frame . ComponentId ) ;
550+ }
551+ }
552+ }
553+
554+ private void RemoveElementReferencesForComponent ( int componentId )
555+ {
556+ var keysToRemove = elementReferenceToComponentId
557+ . Where ( kvp => kvp . Value == componentId )
558+ . Select ( kvp => kvp . Key )
559+ . ToList ( ) ;
560+
561+ foreach ( var key in keysToRemove )
562+ {
563+ elementReferenceToComponentId . Remove ( key ) ;
564+ }
565+ }
566+
567+ internal IRenderedComponent < TComponent > ? FindComponentForElement < TComponent > ( IElement element )
568+ where TComponent : IComponent
569+ {
570+ var elementReferenceId = element . GetAttribute ( "blazor:elementReference" ) ;
571+ if ( elementReferenceId is not null && elementReferenceToComponentId . TryGetValue ( elementReferenceId , out var componentId ) )
572+ {
573+ return GetRenderedComponent < TComponent > ( componentId ) ;
574+ }
575+
576+ return FindComponentByElementContainment < TComponent > ( element ) ;
577+ }
578+
579+ private IRenderedComponent < TComponent > ? FindComponentByElementContainment < TComponent > ( IElement element )
580+ where TComponent : IComponent
581+ {
582+ List < int > renderedComponentIdsWhenStarted = [ ..returnedRenderedComponentIds ] ;
583+ var components = new List < IRenderedComponent < TComponent > > ( returnedRenderedComponentIds . Count ) ;
584+
585+ foreach ( var parentComponent in renderedComponentIdsWhenStarted . Select ( GetRenderedComponent < IComponent > ) )
586+ {
587+ components . AddRange ( FindComponents < TComponent > ( parentComponent ) ) ;
588+ }
589+
590+ return components . FirstOrDefault ( component => ComponentContainsElement ( component , element ) ) ;
591+ }
592+
593+ private static bool ComponentContainsElement < TComponent > ( IRenderedComponent < TComponent > component , IElement element )
594+ where TComponent : IComponent
595+ {
596+ foreach ( var node in component . Nodes )
597+ {
598+ if ( node is IElement nodeElement && nodeElement . Equals ( element ) )
599+ {
600+ return true ;
601+ }
602+ if ( IsDescendantOf ( element , node ) )
603+ {
604+ return true ;
605+ }
606+ }
607+ return false ;
608+ }
609+
610+ private static bool IsDescendantOf ( IElement element , INode potentialAncestor )
611+ {
612+ var current = element . Parent ;
613+ while ( current is not null )
614+ {
615+ if ( current == potentialAncestor )
616+ return true ;
617+ current = current . Parent ;
618+ }
619+ return false ;
620+ }
621+
522622 /// <inheritdoc/>
523623 internal new ArrayRange < RenderTreeFrame > GetCurrentRenderTreeFrames ( int componentId )
524624 => base . GetCurrentRenderTreeFrames ( componentId ) ;
0 commit comments