Skip to content

Commit 542e99d

Browse files
committed
Implement filtering for persistent services
1 parent 665edf1 commit 542e99d

File tree

3 files changed

+49
-25
lines changed

3 files changed

+49
-25
lines changed

src/Components/Components/src/PersistentComponentState.cs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ public class PersistentComponentState
1818

1919
private readonly List<PersistComponentStateRegistration> _registeredCallbacks;
2020
private readonly List<RestoreComponentStateRegistration> _registeredRestoringCallbacks;
21-
private RestoreContext? _currentContext;
2221

2322
internal PersistentComponentState(
2423
IDictionary<string, byte[]> currentState,
@@ -32,14 +31,16 @@ internal PersistentComponentState(
3231

3332
internal bool PersistingState { get; set; }
3433

34+
internal RestoreContext CurrentContext { get; private set; }
35+
3536
internal void InitializeExistingState(IDictionary<string, byte[]> existingState, RestoreContext context)
3637
{
3738
if (_existingState != null)
3839
{
3940
throw new InvalidOperationException("PersistentComponentState already initialized.");
4041
}
4142
_existingState = existingState ?? throw new ArgumentNullException(nameof(existingState));
42-
_currentContext = context;
43+
CurrentContext = context;
4344
}
4445

4546
/// <summary>
@@ -82,8 +83,8 @@ public PersistingComponentStateSubscription RegisterOnPersisting(Func<Task> call
8283
/// <returns>A subscription that can be used to unregister the callback when disposed.</returns>
8384
public RestoringComponentStateSubscription RegisterOnRestoring(Action callback, RestoreOptions options)
8485
{
85-
Debug.Assert(_currentContext != null);
86-
if (_currentContext.ShouldRestore(options))
86+
Debug.Assert(CurrentContext != null);
87+
if (CurrentContext.ShouldRestore(options))
8788
{
8889
callback();
8990
}
@@ -255,6 +256,6 @@ internal void UpdateExistingState(IDictionary<string, byte[]> state, RestoreCont
255256
}
256257

257258
_existingState = state;
258-
_currentContext = context;
259+
CurrentContext = context;
259260
}
260261
}

src/Components/Components/src/PersistentState/ComponentStatePersistenceManager.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ public async Task RestoreStateAsync(IPersistentComponentStateStore store, Restor
7878
else
7979
{
8080
State.InitializeExistingState(data, context);
81-
_servicesRegistry?.Restore(State);
81+
_servicesRegistry?.RegisterForPersistence(State);
8282
_stateIsInitialized = true;
8383
}
8484

src/Components/Components/src/PersistentState/PersistentServicesRegistry.cs

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@ namespace Microsoft.AspNetCore.Components.Infrastructure;
1717
internal sealed class PersistentServicesRegistry
1818
{
1919
private static readonly string _registryKey = typeof(PersistentServicesRegistry).FullName!;
20-
private static readonly RootTypeCache _persistentServiceTypeCache = new RootTypeCache();
20+
private static readonly RootTypeCache _persistentServiceTypeCache = new();
2121

2222
private readonly IServiceProvider _serviceProvider;
2323
private IPersistentServiceRegistration[] _registrations;
24-
private List<PersistingComponentStateSubscription> _subscriptions = [];
24+
private List<(PersistingComponentStateSubscription, RestoringComponentStateSubscription)> _subscriptions = [];
2525
private static readonly ConcurrentDictionary<Type, PropertiesAccessor> _cachedAccessorsByType = new();
2626

2727
public PersistentServicesRegistry(IServiceProvider serviceProvider)
@@ -45,7 +45,8 @@ internal void RegisterForPersistence(PersistentComponentState state)
4545
return;
4646
}
4747

48-
var subscriptions = new List<PersistingComponentStateSubscription>(_registrations.Length + 1);
48+
var subscriptions = new List<(PersistingComponentStateSubscription, RestoringComponentStateSubscription)>(
49+
_registrations.Length + 1);
4950
for (var i = 0; i < _registrations.Length; i++)
5051
{
5152
var registration = _registrations[i];
@@ -58,20 +59,32 @@ internal void RegisterForPersistence(PersistentComponentState state)
5859
var renderMode = registration.GetRenderModeOrDefault();
5960

6061
var instance = _serviceProvider.GetRequiredService(type);
61-
subscriptions.Add(state.RegisterOnPersisting(() =>
62-
{
63-
PersistInstanceState(instance, type, state);
64-
return Task.CompletedTask;
65-
}, renderMode));
62+
subscriptions.Add((
63+
state.RegisterOnPersisting(() =>
64+
{
65+
PersistInstanceState(instance, type, state);
66+
return Task.CompletedTask;
67+
}, renderMode),
68+
// In order to avoid registering one callback per property, we register a single callback with the most
69+
// permissive options and perform the filtering inside of it.
70+
state.RegisterOnRestoring(() =>
71+
{
72+
RestoreInstanceState(instance, type, state);
73+
}, new RestoreOptions { AllowUpdates = true })));
6674
}
6775

6876
if (RenderMode != null)
6977
{
70-
subscriptions.Add(state.RegisterOnPersisting(() =>
71-
{
72-
state.PersistAsJson(_registryKey, _registrations);
73-
return Task.CompletedTask;
74-
}, RenderMode));
78+
subscriptions.Add((
79+
state.RegisterOnPersisting(() =>
80+
{
81+
state.PersistAsJson(_registryKey, _registrations);
82+
return Task.CompletedTask;
83+
}, RenderMode),
84+
state.RegisterOnRestoring(() =>
85+
{
86+
Restore(state);
87+
}, new RestoreOptions { RestoreBehavior = RestoreBehavior.SkipLastSnapshot })));
7588
}
7689

7790
_subscriptions = subscriptions;
@@ -83,7 +96,7 @@ private static void PersistInstanceState(object instance, Type type, PersistentC
8396
var accessors = _cachedAccessorsByType.GetOrAdd(instance.GetType(), static (runtimeType, declaredType) => new PropertiesAccessor(runtimeType, declaredType), type);
8497
foreach (var (key, propertyType) in accessors.KeyTypePairs)
8598
{
86-
var (setter, getter) = accessors.GetAccessor(key);
99+
var (setter, getter, options) = accessors.GetAccessor(key);
87100
var value = getter.GetValue(instance);
88101
if (value != null)
89102
{
@@ -131,9 +144,13 @@ private static void RestoreInstanceState(object instance, Type type, PersistentC
131144
var accessors = _cachedAccessorsByType.GetOrAdd(instance.GetType(), static (runtimeType, declaredType) => new PropertiesAccessor(runtimeType, declaredType), type);
132145
foreach (var (key, propertyType) in accessors.KeyTypePairs)
133146
{
147+
var (setter, getter, options) = accessors.GetAccessor(key);
148+
if (!state.CurrentContext.ShouldRestore(options))
149+
{
150+
continue;
151+
}
134152
if (state.TryTakeFromJson(key, propertyType, out var result))
135153
{
136-
var (setter, getter) = accessors.GetAccessor(key);
137154
setter.SetValue(instance, result!);
138155
}
139156
}
@@ -156,12 +173,12 @@ private sealed class PropertiesAccessor
156173
{
157174
internal const BindingFlags BindablePropertyFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.IgnoreCase;
158175

159-
private readonly Dictionary<string, (PropertySetter, PropertyGetter)> _underlyingAccessors;
176+
private readonly Dictionary<string, (PropertySetter, PropertyGetter, RestoreOptions)> _underlyingAccessors;
160177
private readonly (string, Type)[] _cachedKeysForService;
161178

162179
public PropertiesAccessor([DynamicallyAccessedMembers(LinkerFlags.Component)] Type targetType, Type keyType)
163180
{
164-
_underlyingAccessors = new Dictionary<string, (PropertySetter, PropertyGetter)>(StringComparer.OrdinalIgnoreCase);
181+
_underlyingAccessors = new Dictionary<string, (PropertySetter, PropertyGetter, RestoreOptions)>(StringComparer.OrdinalIgnoreCase);
165182

166183
var keys = new List<(string, Type)>();
167184
foreach (var propertyInfo in GetCandidateBindableProperties(targetType))
@@ -195,10 +212,16 @@ public PropertiesAccessor([DynamicallyAccessedMembers(LinkerFlags.Component)] Ty
195212
$"The type '{targetType.FullName}' declares a property matching the name '{propertyName}' that is not public. Persistent service properties must be public.");
196213
}
197214

215+
var restoreOptions = new RestoreOptions
216+
{
217+
RestoreBehavior = parameterAttribute.RestoreBehavior,
218+
AllowUpdates = parameterAttribute.AllowUpdates,
219+
};
220+
198221
var propertySetter = new PropertySetter(targetType, propertyInfo);
199222
var propertyGetter = new PropertyGetter(targetType, propertyInfo);
200223

201-
_underlyingAccessors.Add(key, (propertySetter, propertyGetter));
224+
_underlyingAccessors.Add(key, (propertySetter, propertyGetter, restoreOptions));
202225
}
203226

204227
_cachedKeysForService = [.. keys];
@@ -227,7 +250,7 @@ internal static IEnumerable<PropertyInfo> GetCandidateBindableProperties(
227250
[DynamicallyAccessedMembers(LinkerFlags.Component)] Type targetType)
228251
=> MemberAssignment.GetPropertiesIncludingInherited(targetType, BindablePropertyFlags);
229252

230-
internal (PropertySetter setter, PropertyGetter getter) GetAccessor(string key) =>
253+
internal (PropertySetter setter, PropertyGetter getter, RestoreOptions options) GetAccessor(string key) =>
231254
_underlyingAccessors.TryGetValue(key, out var result) ? result : default;
232255
}
233256

0 commit comments

Comments
 (0)