|  | 
|  | 1 | +using System.Security.Claims; | 
|  | 2 | +using System.Web; | 
|  | 3 | +using EssentialCSharp.Web.Areas.Identity.Data; | 
|  | 4 | +using EssentialCSharp.Web.Services; | 
|  | 5 | +using Microsoft.AspNetCore.Http.Extensions; | 
|  | 6 | +using Microsoft.AspNetCore.Identity; | 
|  | 7 | + | 
|  | 8 | +namespace EssentialCSharp.Web.Middleware; | 
|  | 9 | + | 
|  | 10 | +public class ReferralMiddleware | 
|  | 11 | +{ | 
|  | 12 | +    private readonly RequestDelegate _Next; | 
|  | 13 | + | 
|  | 14 | +    public ReferralMiddleware(RequestDelegate next) | 
|  | 15 | +    { | 
|  | 16 | +        _Next = next; | 
|  | 17 | +    } | 
|  | 18 | + | 
|  | 19 | +    public async Task InvokeAsync(HttpContext context, IReferralService referralService, UserManager<EssentialCSharpWebUser> userManager) | 
|  | 20 | +    { | 
|  | 21 | +        // Retrieve current referral Id for processing | 
|  | 22 | +        System.Collections.Specialized.NameValueCollection query = HttpUtility.ParseQueryString(context.Request.QueryString.Value!); | 
|  | 23 | +        string? referralId = query["rid"]; | 
|  | 24 | +        string? userReferralId; | 
|  | 25 | + | 
|  | 26 | +        if (context.User is { } claimsUser && claimsUser.Identity is not null && claimsUser.Identity.IsAuthenticated) | 
|  | 27 | +        { | 
|  | 28 | +            if (!string.IsNullOrWhiteSpace(referralId)) | 
|  | 29 | +            { | 
|  | 30 | +                await TrackReferralAsync(referralService, referralId, claimsUser); | 
|  | 31 | +            } | 
|  | 32 | + | 
|  | 33 | +            // Add the referralId to the request context if it exists on a user | 
|  | 34 | +            EssentialCSharpWebUser? user = await userManager.GetUserAsync(claimsUser); | 
|  | 35 | +            if (user is not null) | 
|  | 36 | +            { | 
|  | 37 | +                userReferralId = await referralService.GetReferralIdAsync(user.Id); | 
|  | 38 | + | 
|  | 39 | +                if (!string.IsNullOrWhiteSpace(userReferralId) && (string.IsNullOrWhiteSpace(query["rid"]) || query["rid"] != userReferralId)) | 
|  | 40 | +                { | 
|  | 41 | +                    query.Remove("rid"); | 
|  | 42 | +                    query.Add("rid", userReferralId); | 
|  | 43 | +                    var builder = new UriBuilder(context.Request.GetEncodedUrl()) | 
|  | 44 | +                    { | 
|  | 45 | +                        Query = query.ToString() | 
|  | 46 | +                    }; | 
|  | 47 | +                    context.Response.Redirect(builder.ToString()); | 
|  | 48 | +                    return; | 
|  | 49 | +                } | 
|  | 50 | +            } | 
|  | 51 | +        } | 
|  | 52 | +        else | 
|  | 53 | +        { | 
|  | 54 | + | 
|  | 55 | +            if (!string.IsNullOrWhiteSpace(referralId)) | 
|  | 56 | +            { | 
|  | 57 | +                await TrackReferralAsync(referralService, referralId, null); | 
|  | 58 | +                query.Remove("rid"); | 
|  | 59 | +                var builder = new UriBuilder(context.Request.GetEncodedUrl()) | 
|  | 60 | +                { | 
|  | 61 | +                    Query = query.ToString() | 
|  | 62 | +                }; | 
|  | 63 | +                context.Response.Redirect(builder.ToString()); | 
|  | 64 | +                return; | 
|  | 65 | +            } | 
|  | 66 | +        } | 
|  | 67 | + | 
|  | 68 | +        await _Next(context); | 
|  | 69 | + | 
|  | 70 | +        static async Task TrackReferralAsync(IReferralService referralService, string? referralId, ClaimsPrincipal? claimsUser) | 
|  | 71 | +        { | 
|  | 72 | +            if (!string.IsNullOrWhiteSpace(referralId)) | 
|  | 73 | +            { | 
|  | 74 | +                _ = await referralService.TrackReferralAsync(referralId, claimsUser); | 
|  | 75 | +            } | 
|  | 76 | +        } | 
|  | 77 | +    } | 
|  | 78 | +} | 
0 commit comments