Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Mediator.SourceGenerator.Extensions;
using Mediator.SourceGenerator.Extensions;

namespace Mediator.SourceGenerator;

Expand All @@ -19,8 +19,10 @@ public NotificationMessageHandlerModel(NotificationMessageHandler handler, Compi

if (!handler.Symbol.IsGenericType)
{
var concreteRegistration =
$"services.TryAdd(new {sd}(typeof({concreteSymbol}), typeof({concreteSymbol}), {analyzer.ServiceLifetime}));";
var concreteRegistration = $"""
if (!IsHandlerAlreadyRegistered(existingRegistrations, typeof({concreteSymbol}), typeof({concreteSymbol})))
services.TryAdd(new {sd}(typeof({concreteSymbol}), typeof({concreteSymbol}), {analyzer.ServiceLifetime}));
""";
builder.Add(concreteRegistration);
}

Expand All @@ -29,14 +31,19 @@ public NotificationMessageHandlerModel(NotificationMessageHandler handler, Compi
var requestType = message.Symbol.GetTypeSymbolFullName();
if (handler.Symbol.IsGenericType)
{
var concreteRegistration =
$"services.TryAdd(new {sd}(typeof({concreteSymbol}<{requestType}>), typeof({concreteSymbol}<{requestType}>), {analyzer.ServiceLifetime}));";
var concreteRegistration = $"""
if (!IsHandlerAlreadyRegistered(existingRegistrations, typeof({concreteSymbol}<{requestType}>), typeof({concreteSymbol}<{requestType}>)))
services.TryAdd(new {sd}(typeof({concreteSymbol}<{requestType}>), typeof({concreteSymbol}<{requestType}>), {analyzer.ServiceLifetime}));
""";
builder.Add(concreteRegistration);
}
var getExpression =
$"GetRequiredService<{concreteSymbol}{(handler.Symbol.IsGenericType ? $"<{requestType}>" : "")}>()";
var registration =
$"services.Add(new {sd}(typeof({interfaceSymbol}<{requestType}>), {getExpression}, {analyzer.ServiceLifetime}));";

var concreteImpl = $"{concreteSymbol}{(handler.Symbol.IsGenericType ? $"<{requestType}>" : "")}";
var getExpression = $"GetRequiredService<{concreteImpl}>()";
var registration = $"""
if (!IsHandlerAlreadyRegistered(existingRegistrations, typeof({interfaceSymbol}<{requestType}>), typeof({concreteImpl})))
services.Add(new {sd}(typeof({interfaceSymbol}<{requestType}>), {getExpression}, {analyzer.ServiceLifetime}));
""";
builder.Add(registration);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ namespace Microsoft.Extensions.DependencyInjection
throw new global::System.Exception(errMsg);
}

// Build cache of existing registrations for efficient lookup
var existingRegistrations = BuildRegistrationCache(services);

{{~ if ServiceLifetimeIsTransient || ServiceLifetimeIsScoped ~}}
services.Add(new {{ SD }}(typeof(global::{{ MediatorNamespace }}.Mediator), typeof(global::{{ MediatorNamespace }}.Mediator), {{ ServiceLifetime }}));
services.TryAdd(new {{ SD }}(typeof(global::Mediator.IMediator), typeof(global::{{ MediatorNamespace }}.Mediator), {{ ServiceLifetime }}));
Expand All @@ -61,7 +64,7 @@ namespace Microsoft.Extensions.DependencyInjection
services.TryAdd(new {{ SD }}(typeof(global::Mediator.IPublisher), sp => sp.GetRequiredService<global::{{ MediatorNamespace }}.Mediator>(), {{ ServiceLifetime }}));
{{~ end ~}}
{{~ if (object.size RequestMessages) > 0 ~}}

// Register handlers for request messages
{{~ for message in RequestMessages ~}}
{{ message.Handler.ServiceRegistration }}
Expand All @@ -70,14 +73,14 @@ namespace Microsoft.Extensions.DependencyInjection
{{~ end ~}}
{{~ end ~}}
{{~ if (object.size NotificationMessages) > 0 ~}}

// Register handlers and wrappers for notification messages
{{~ for message in NotificationMessages ~}}
services.Add(new {{ SD }}(typeof({{ message.HandlerWrapperTypeNameWithGenericTypeArguments }}), typeof({{ message.HandlerWrapperTypeNameWithGenericTypeArguments }}), {{ SingletonServiceLifetime }}));
{{~ end ~}}
{{~ end ~}}
{{~ if (object.size NotificationMessageHandlers) > 0 ~}}

// Register notification handlers
{{~ for handler in NotificationMessageHandlers ~}}
{{~ for registration in handler.ServiceRegistrations ~}}
Expand All @@ -86,15 +89,15 @@ namespace Microsoft.Extensions.DependencyInjection
{{~ end ~}}
{{~ end ~}}
{{~ if (object.size PipelineBehaviors) > 0 ~}}

// Register pipeline behaviors configured through options
{{~ for behavior in PipelineBehaviors ~}}
{{~ for registration in behavior.ServiceRegistrations ~}}
{{ registration }}
{{~ end ~}}
{{~ end ~}}
{{~ end ~}}

// Register the notification publisher that was configured
{{~ if ServiceLifetimeIsScoped || ServiceLifetimeIsTransient ~}}
services.Add(new {{ SD }}(typeof({{ NotificationPublisherType.FullName }}), typeof({{ NotificationPublisherType.FullName }}), {{ ServiceLifetime }}));
Expand All @@ -103,18 +106,68 @@ namespace Microsoft.Extensions.DependencyInjection
services.Add(new {{ SD }}(typeof({{ NotificationPublisherType.FullName }}), typeof({{ NotificationPublisherType.FullName }}), {{ SingletonServiceLifetime }}));
services.TryAdd(new {{ SD }}(typeof(global::Mediator.INotificationPublisher), sp => sp.GetRequiredService<{{ NotificationPublisherType.FullName }}>(), {{ SingletonServiceLifetime }}));
{{~ end ~}}

// Register internal components
services.Add(new {{ SD }}(typeof(global::{{ InternalsNamespace }}.IContainerProbe), typeof(global::{{ InternalsNamespace }}.ContainerProbe0), {{ ServiceLifetime }}));
services.Add(new {{ SD }}(typeof(global::{{ InternalsNamespace }}.IContainerProbe), typeof(global::{{ InternalsNamespace }}.ContainerProbe1), {{ ServiceLifetime }}));
services.Add(new {{ SD }}(typeof(global::{{ InternalsNamespace }}.ContainerMetadata), typeof(global::{{ InternalsNamespace }}.ContainerMetadata), {{ SingletonServiceLifetime }}));

return services;

{{~ if HasNotifications ~}}
{{~ if HasNotifications ~}}
[global::System.Runtime.CompilerServices.MethodImpl(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]
static global::System.Func<global::System.IServiceProvider, T> GetRequiredService<T>() where T : notnull => sp => sp.GetRequiredService<T>();
{{~ end ~}}
{{~ end ~}}
}

/// <summary>
/// Builds a cache of existing service registrations for efficient duplicate detection.
/// Maps service types to their registered implementation types.
/// </summary>
/// <param name="services">The service collection to analyze</param>
/// <returns>Dictionary mapping service types to sets of implementation types</returns>
private static global::System.Collections.Generic.Dictionary<global::System.Type, global::System.Collections.Generic.HashSet<global::System.Type>>
BuildRegistrationCache(IServiceCollection services)
{
var cache = new global::System.Collections.Generic.Dictionary<global::System.Type, global::System.Collections.Generic.HashSet<global::System.Type>>();

foreach (var service in services)
{
if (service.ServiceType == null) continue;

if (!cache.ContainsKey(service.ServiceType))
{
cache[service.ServiceType] = new global::System.Collections.Generic.HashSet<global::System.Type>();
}

// Handle different ServiceDescriptor registration patterns
if (service.ImplementationType != null)
{
cache[service.ServiceType].Add(service.ImplementationType);
}
else if (service.ImplementationInstance != null)
{
cache[service.ServiceType].Add(service.ImplementationInstance.GetType());
}
}

return cache;
}

/// <summary>
/// Checks if a handler registration already exists in the service collection.
/// </summary>
/// <param name="existingRegistrations">Cache of existing registrations</param>
/// <param name="serviceType">The service interface type</param>
/// <param name="implementationType">The concrete implementation type</param>
/// <returns>True if the handler is already registered</returns>
private static bool IsHandlerAlreadyRegistered(
global::System.Collections.Generic.Dictionary<global::System.Type, global::System.Collections.Generic.HashSet<global::System.Type>> existingRegistrations,
global::System.Type serviceType,
global::System.Type implementationType)
{
return existingRegistrations.ContainsKey(serviceType) &&
existingRegistrations[serviceType].Contains(implementationType);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,34 +50,90 @@ public static IServiceCollection AddMediator(this IServiceCollection services, g
throw new global::System.Exception(errMsg);
}

// Build cache of existing registrations for efficient lookup
var existingRegistrations = BuildRegistrationCache(services);

services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.Mediator), typeof(global::Mediator.Mediator), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
services.TryAdd(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.IMediator), sp => sp.GetRequiredService<global::Mediator.Mediator>(), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
services.TryAdd(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.ISender), sp => sp.GetRequiredService<global::Mediator.Mediator>(), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
services.TryAdd(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.IPublisher), sp => sp.GetRequiredService<global::Mediator.Mediator>(), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));

// Register handlers and wrappers for notification messages
services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.Internals.NotificationHandlerWrapper<global::TestCode.Notification0>), typeof(global::Mediator.Internals.NotificationHandlerWrapper<global::TestCode.Notification0>), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.Internals.NotificationHandlerWrapper<global::TestCode.Notification1>), typeof(global::Mediator.Internals.NotificationHandlerWrapper<global::TestCode.Notification1>), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));

// Register notification handlers
services.TryAdd(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::TestCode.RequestHandler), typeof(global::TestCode.RequestHandler), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.INotificationHandler<global::TestCode.Notification0>), GetRequiredService<global::TestCode.RequestHandler>(), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.INotificationHandler<global::TestCode.Notification1>), GetRequiredService<global::TestCode.RequestHandler>(), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));

if (!IsHandlerAlreadyRegistered(existingRegistrations, typeof(global::TestCode.RequestHandler), typeof(global::TestCode.RequestHandler)))
services.TryAdd(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::TestCode.RequestHandler), typeof(global::TestCode.RequestHandler), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
if (!IsHandlerAlreadyRegistered(existingRegistrations, typeof(global::Mediator.INotificationHandler<global::TestCode.Notification0>), typeof(global::TestCode.RequestHandler)))
services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.INotificationHandler<global::TestCode.Notification0>), GetRequiredService<global::TestCode.RequestHandler>(), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
if (!IsHandlerAlreadyRegistered(existingRegistrations, typeof(global::Mediator.INotificationHandler<global::TestCode.Notification1>), typeof(global::TestCode.RequestHandler)))
services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.INotificationHandler<global::TestCode.Notification1>), GetRequiredService<global::TestCode.RequestHandler>(), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));

// Register the notification publisher that was configured
services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.ForeachAwaitPublisher), typeof(global::Mediator.ForeachAwaitPublisher), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
services.TryAdd(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.INotificationPublisher), sp => sp.GetRequiredService<global::Mediator.ForeachAwaitPublisher>(), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));

// Register internal components
services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.Internals.IContainerProbe), typeof(global::Mediator.Internals.ContainerProbe0), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.Internals.IContainerProbe), typeof(global::Mediator.Internals.ContainerProbe1), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));
services.Add(new global::Microsoft.Extensions.DependencyInjection.ServiceDescriptor(typeof(global::Mediator.Internals.ContainerMetadata), typeof(global::Mediator.Internals.ContainerMetadata), global::Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton));

return services;

[global::System.Runtime.CompilerServices.MethodImpl(global::System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]
static global::System.Func<global::System.IServiceProvider, T> GetRequiredService<T>() where T : notnull => sp => sp.GetRequiredService<T>();
}

/// <summary>
/// Builds a cache of existing service registrations for efficient duplicate detection.
/// Maps service types to their registered implementation types.
/// </summary>
/// <param name="services">The service collection to analyze</param>
/// <returns>Dictionary mapping service types to sets of implementation types</returns>
private static global::System.Collections.Generic.Dictionary<global::System.Type, global::System.Collections.Generic.HashSet<global::System.Type>>
BuildRegistrationCache(IServiceCollection services)
{
var cache = new global::System.Collections.Generic.Dictionary<global::System.Type, global::System.Collections.Generic.HashSet<global::System.Type>>();

foreach (var service in services)
{
if (service.ServiceType == null) continue;

if (!cache.ContainsKey(service.ServiceType))
{
cache[service.ServiceType] = new global::System.Collections.Generic.HashSet<global::System.Type>();
}

// Handle different ServiceDescriptor registration patterns
if (service.ImplementationType != null)
{
cache[service.ServiceType].Add(service.ImplementationType);
}
else if (service.ImplementationInstance != null)
{
cache[service.ServiceType].Add(service.ImplementationInstance.GetType());
}
}

return cache;
}

/// <summary>
/// Checks if a handler registration already exists in the service collection.
/// </summary>
/// <param name="existingRegistrations">Cache of existing registrations</param>
/// <param name="serviceType">The service interface type</param>
/// <param name="implementationType">The concrete implementation type</param>
/// <returns>True if the handler is already registered</returns>
private static bool IsHandlerAlreadyRegistered(
global::System.Collections.Generic.Dictionary<global::System.Type, global::System.Collections.Generic.HashSet<global::System.Type>> existingRegistrations,
global::System.Type serviceType,
global::System.Type implementationType)
{
return existingRegistrations.ContainsKey(serviceType) &&
existingRegistrations[serviceType].Contains(implementationType);
}
}
}

Expand Down
Loading