diff --git a/src/ByteSync.ServerCommon/Business/Sessions/CloudSessionData.cs b/src/ByteSync.ServerCommon/Business/Sessions/CloudSessionData.cs index 91659fc8..808229e8 100644 --- a/src/ByteSync.ServerCommon/Business/Sessions/CloudSessionData.cs +++ b/src/ByteSync.ServerCommon/Business/Sessions/CloudSessionData.cs @@ -29,6 +29,8 @@ public CloudSessionData(string? lobbyId, EncryptedSessionSettings sessionSetting public EncryptedSessionSettings SessionSettings { get; set; } = null!; + public int ProtocolVersion { get; set; } + public bool IsSessionActivated { get; set; } public List SessionMembers { get; set; } @@ -77,4 +79,4 @@ public CloudSession GetCloudSession() { return SessionMembers.Concat(PreSessionMembers).Distinct().SingleOrDefault(m => m.ClientInstanceId == clientInstanceId); } -} \ No newline at end of file +} diff --git a/src/ByteSync.ServerCommon/Commands/CloudSessions/CreateSessionCommandHandler.cs b/src/ByteSync.ServerCommon/Commands/CloudSessions/CreateSessionCommandHandler.cs index 5c18d81a..66215c72 100644 --- a/src/ByteSync.ServerCommon/Commands/CloudSessions/CreateSessionCommandHandler.cs +++ b/src/ByteSync.ServerCommon/Commands/CloudSessions/CreateSessionCommandHandler.cs @@ -39,6 +39,7 @@ public async Task Handle(CreateSessionRequest request, Cance SessionMemberData creatorData; cloudSessionData = new CloudSessionData(createCloudSessionParameters.LobbyId, createCloudSessionParameters.SessionSettings, client); + cloudSessionData.ProtocolVersion = createCloudSessionParameters.CreatorPublicKeyInfo.ProtocolVersion; creatorData = new SessionMemberData(client, createCloudSessionParameters.CreatorPublicKeyInfo, createCloudSessionParameters.CreatorProfileClientId, cloudSessionData, createCloudSessionParameters.CreatorPrivateData); @@ -67,4 +68,4 @@ private string GenerateRandomSessionId() return sessionId; } -} \ No newline at end of file +} diff --git a/src/ByteSync.ServerCommon/Commands/Trusts/StartTrustCheckCommandHandler.cs b/src/ByteSync.ServerCommon/Commands/Trusts/StartTrustCheckCommandHandler.cs index 04666d3d..36797eb0 100644 --- a/src/ByteSync.ServerCommon/Commands/Trusts/StartTrustCheckCommandHandler.cs +++ b/src/ByteSync.ServerCommon/Commands/Trusts/StartTrustCheckCommandHandler.cs @@ -1,5 +1,4 @@ using ByteSync.Common.Business.Sessions.Cloud.Connections; -using ByteSync.Common.Business.Versions; using ByteSync.ServerCommon.Interfaces.Repositories; using ByteSync.ServerCommon.Interfaces.Services.Clients; using MediatR; @@ -12,8 +11,8 @@ public class StartTrustCheckCommandHandler : IRequestHandler _logger; - - public StartTrustCheckCommandHandler(ICloudSessionsRepository cloudSessionsRepository, IInvokeClientsService invokeClientsService, + + public StartTrustCheckCommandHandler(ICloudSessionsRepository cloudSessionsRepository, IInvokeClientsService invokeClientsService, ILogger logger) { _cloudSessionsRepository = cloudSessionsRepository; @@ -31,37 +30,20 @@ public async Task Handle(StartTrustCheckRequest request, { return new StartTrustCheckResult { IsOK = false }; } - + var joinerProtocolVersion = trustCheckParameters.ProtocolVersion; + var sessionProtocolVersion = cloudSession.ProtocolVersion; - if (!ProtocolVersion.IsCompatible(joinerProtocolVersion)) + if (joinerProtocolVersion != sessionProtocolVersion) { _logger.LogWarning( - "StartTrustCheck: Joiner {JoinerId} has incompatible protocol version {JoinerVersion}", - joiner.ClientInstanceId, joinerProtocolVersion); + "StartTrustCheck: Joiner {JoinerId} has incompatible protocol version {JoinerVersion} for session {SessionId} (version {SessionVersion})", + joiner.ClientInstanceId, joinerProtocolVersion, trustCheckParameters.SessionId, sessionProtocolVersion); return new StartTrustCheckResult { IsOK = false, IsProtocolVersionIncompatible = true }; } - var membersToCheck = cloudSession.SessionMembers - .Where(m => trustCheckParameters.MembersInstanceIdsToCheck.Contains(m.ClientInstanceId)); - - foreach (var member in membersToCheck) - { - var memberProtocolVersion = member.PublicKeyInfo.ProtocolVersion; - - if (!ProtocolVersion.IsCompatible(memberProtocolVersion) || - memberProtocolVersion != joinerProtocolVersion) - { - _logger.LogWarning( - "StartTrustCheck: Protocol version mismatch between joiner {JoinerId} (version {JoinerVersion}) and member {MemberId} (version {MemberVersion})", - joiner.ClientInstanceId, joinerProtocolVersion, member.ClientInstanceId, memberProtocolVersion); - - return new StartTrustCheckResult { IsOK = false, IsProtocolVersionIncompatible = true }; - } - } - - _logger.LogInformation("StartTrustCheck: {Joiner} starts trust check for session {SessionId}. {Count} members to check", + _logger.LogInformation("StartTrustCheck: {Joiner} starts trust check for session {SessionId}. {Count} members to check", joiner.ClientInstanceId, trustCheckParameters.SessionId, trustCheckParameters.MembersInstanceIdsToCheck.Count); var validMemberIds = trustCheckParameters.MembersInstanceIdsToCheck @@ -70,10 +52,11 @@ public async Task Handle(StartTrustCheckRequest request, foreach (var clientInstanceId in validMemberIds) { - _logger.LogInformation("StartTrustCheck: {Member} must be trusted by {Joiner}", + _logger.LogInformation("StartTrustCheck: {Member} must be trusted by {Joiner}", clientInstanceId, joiner.ClientInstanceId); - await _invokeClientsService.Client(clientInstanceId).AskPublicKeyCheckData(trustCheckParameters.SessionId, joiner.ClientInstanceId, + await _invokeClientsService.Client(clientInstanceId).AskPublicKeyCheckData(trustCheckParameters.SessionId, + joiner.ClientInstanceId, trustCheckParameters.PublicKeyInfo).ConfigureAwait(false); } diff --git a/tests/ByteSync.ServerCommon.Tests/Commands/CloudSessions/CreateSessionCommandHandlerTests.cs b/tests/ByteSync.ServerCommon.Tests/Commands/CloudSessions/CreateSessionCommandHandlerTests.cs index 9f1f7b2e..e59008bf 100644 --- a/tests/ByteSync.ServerCommon.Tests/Commands/CloudSessions/CreateSessionCommandHandlerTests.cs +++ b/tests/ByteSync.ServerCommon.Tests/Commands/CloudSessions/CreateSessionCommandHandlerTests.cs @@ -54,7 +54,7 @@ public async Task Handle_ValidRequest_CreatesSession() var lobbyId = "lobbyId"; var sessionSettings = new EncryptedSessionSettings(); var client = new Client { ClientInstanceId = "clientInstance1" }; - var creatorPublicKeyInfo = new PublicKeyInfo(); + var creatorPublicKeyInfo = new PublicKeyInfo { ProtocolVersion = 2 }; var creatorProfileClientId = "creatorProfile"; var creatorPrivateData = new EncryptedSessionMemberPrivateData(); var sessionId = "123ABC456"; @@ -108,6 +108,7 @@ public async Task Handle_ValidRequest_CreatesSession() addedCloudSession.Should().NotBeNull(); addedCloudSession.LobbyId.Should().Be(lobbyId); addedCloudSession.SessionSettings.Should().BeSameAs(sessionSettings); + addedCloudSession.ProtocolVersion.Should().Be(creatorPublicKeyInfo.ProtocolVersion); addedCloudSession.SessionMembers.Should().HaveCount(1); // Verify session member creation @@ -136,4 +137,4 @@ public async Task Handle_ValidRequest_CreatesSession() A.That.Matches(m => m == creatorMemberData))) .MustHaveHappenedOnceExactly(); } -} \ No newline at end of file +} diff --git a/tests/ByteSync.ServerCommon.Tests/Commands/Trusts/StartTrustCheckCommandHandlerTests.cs b/tests/ByteSync.ServerCommon.Tests/Commands/Trusts/StartTrustCheckCommandHandlerTests.cs index 870293bb..629f4f5b 100644 --- a/tests/ByteSync.ServerCommon.Tests/Commands/Trusts/StartTrustCheckCommandHandlerTests.cs +++ b/tests/ByteSync.ServerCommon.Tests/Commands/Trusts/StartTrustCheckCommandHandlerTests.cs @@ -23,7 +23,7 @@ public class StartTrustCheckCommandHandlerTests private readonly IHubByteSyncPush _mockHubByteSyncPush; private readonly StartTrustCheckCommandHandler _startTrustCheckCommandHandler; - + public StartTrustCheckCommandHandlerTests() { _mockCloudSessionsRepository = A.Fake(); @@ -32,7 +32,7 @@ public StartTrustCheckCommandHandlerTests() _mockHubByteSyncPush = A.Fake(); _startTrustCheckCommandHandler = new StartTrustCheckCommandHandler( - _mockCloudSessionsRepository, + _mockCloudSessionsRepository, _mockInvokeClientsService, _mockLogger); } @@ -57,13 +57,14 @@ public async Task Handle_SessionExists_WithMembers_ReturnsSuccessResult() }; var cloudSession = new CloudSessionData(sessionId, new EncryptedSessionSettings(), new Client { ClientInstanceId = "member1" }); - cloudSession.SessionMembers.Add(new SessionMemberData - { + cloudSession.ProtocolVersion = ProtocolVersion.CURRENT; + cloudSession.SessionMembers.Add(new SessionMemberData + { ClientInstanceId = member1, PublicKeyInfo = new PublicKeyInfo { ProtocolVersion = ProtocolVersion.CURRENT } }); - cloudSession.SessionMembers.Add(new SessionMemberData - { + cloudSession.SessionMembers.Add(new SessionMemberData + { ClientInstanceId = member2, PublicKeyInfo = new PublicKeyInfo { ProtocolVersion = ProtocolVersion.CURRENT } }); @@ -75,7 +76,7 @@ public async Task Handle_SessionExists_WithMembers_ReturnsSuccessResult() .Returns(_mockHubByteSyncPush); A.CallTo(() => _mockInvokeClientsService.Client(member2)) .Returns(_mockHubByteSyncPush); - + A.CallTo(() => _mockHubByteSyncPush.AskPublicKeyCheckData(sessionId, joinerClient.ClientInstanceId, publicKeyInfo)) .Returns(Task.CompletedTask); @@ -149,8 +150,9 @@ public async Task Handle_SessionExistsButNoValidMembers_ReturnsEmptySuccessResul }; var cloudSession = new CloudSessionData(sessionId, new EncryptedSessionSettings(), new Client { ClientInstanceId = "creator" }); - cloudSession.SessionMembers.Add(new SessionMemberData - { + cloudSession.ProtocolVersion = ProtocolVersion.CURRENT; + cloudSession.SessionMembers.Add(new SessionMemberData + { ClientInstanceId = "otherMember", PublicKeyInfo = new PublicKeyInfo { ProtocolVersion = ProtocolVersion.CURRENT } }); @@ -173,7 +175,7 @@ public async Task Handle_SessionExistsButNoValidMembers_ReturnsEmptySuccessResul } [Test] - public async Task Handle_WhenMemberHasIncompatibleProtocolVersion_ReturnsProtocolVersionIncompatible() + public async Task Handle_WhenMemberHasDifferentProtocolVersion_ReturnsSuccess() { var sessionId = "testSession"; var joinerClient = new Client { ClientId = "joinerClient", ClientInstanceId = "joinerClientInstance" }; @@ -189,8 +191,9 @@ public async Task Handle_WhenMemberHasIncompatibleProtocolVersion_ReturnsProtoco }; var cloudSession = new CloudSessionData(sessionId, new EncryptedSessionSettings(), new Client { ClientInstanceId = "creator" }); - cloudSession.SessionMembers.Add(new SessionMemberData - { + cloudSession.ProtocolVersion = ProtocolVersion.CURRENT; + cloudSession.SessionMembers.Add(new SessionMemberData + { ClientInstanceId = member1, PublicKeyInfo = new PublicKeyInfo { ProtocolVersion = 0 } }); @@ -198,40 +201,48 @@ public async Task Handle_WhenMemberHasIncompatibleProtocolVersion_ReturnsProtoco A.CallTo(() => _mockCloudSessionsRepository.Get(sessionId)) .Returns(cloudSession); + A.CallTo(() => _mockInvokeClientsService.Client(member1)) + .Returns(_mockHubByteSyncPush); + A.CallTo(() => _mockHubByteSyncPush.AskPublicKeyCheckData(sessionId, joinerClient.ClientInstanceId, publicKeyInfo)) + .Returns(Task.CompletedTask); + var request = new StartTrustCheckRequest(parameters, joinerClient); var result = await _startTrustCheckCommandHandler.Handle(request, CancellationToken.None); result.Should().NotBeNull(); - result.IsOK.Should().BeFalse(); - result.IsProtocolVersionIncompatible.Should().BeTrue(); - result.MembersInstanceIds.Should().BeEmpty(); + result.IsOK.Should().BeTrue(); + result.IsProtocolVersionIncompatible.Should().BeFalse(); + result.MembersInstanceIds.Should().ContainSingle().Which.Should().Be(member1); A.CallTo(() => _mockCloudSessionsRepository.Get(sessionId)).MustHaveHappenedOnceExactly(); - A.CallTo(() => _mockInvokeClientsService.Client(A.Ignored)).MustNotHaveHappened(); + A.CallTo(() => _mockInvokeClientsService.Client(member1)).MustHaveHappenedOnceExactly(); + A.CallTo(() => _mockHubByteSyncPush.AskPublicKeyCheckData(sessionId, joinerClient.ClientInstanceId, publicKeyInfo)) + .MustHaveHappenedOnceExactly(); } [Test] - public async Task Handle_WhenJoinerHasIncompatibleProtocolVersion_ReturnsProtocolVersionIncompatible() + public async Task Handle_WhenJoinerProtocolVersionDoesNotMatchSessionProtocolVersion_ReturnsProtocolVersionIncompatible() { var sessionId = "testSession"; var joinerClient = new Client { ClientId = "joinerClient", ClientInstanceId = "joinerClientInstance" }; var member1 = "memberInstance1"; - var publicKeyInfo = new PublicKeyInfo { ProtocolVersion = 0 }; + var publicKeyInfo = new PublicKeyInfo { ProtocolVersion = ProtocolVersion.CURRENT }; var parameters = new TrustCheckParameters { SessionId = sessionId, MembersInstanceIdsToCheck = new List { member1 }, PublicKeyInfo = publicKeyInfo, - ProtocolVersion = 0 + ProtocolVersion = ProtocolVersion.CURRENT }; var cloudSession = new CloudSessionData(sessionId, new EncryptedSessionSettings(), new Client { ClientInstanceId = "creator" }); - cloudSession.SessionMembers.Add(new SessionMemberData - { + cloudSession.ProtocolVersion = 0; + cloudSession.SessionMembers.Add(new SessionMemberData + { ClientInstanceId = member1, - PublicKeyInfo = new PublicKeyInfo { ProtocolVersion = ProtocolVersion.CURRENT } + PublicKeyInfo = new PublicKeyInfo { ProtocolVersion = 0 } }); A.CallTo(() => _mockCloudSessionsRepository.Get(sessionId)) @@ -267,8 +278,9 @@ public async Task Handle_WhenJoinerAndMemberHaveCompatibleVersions_ReturnsSucces }; var cloudSession = new CloudSessionData(sessionId, new EncryptedSessionSettings(), new Client { ClientInstanceId = "creator" }); - cloudSession.SessionMembers.Add(new SessionMemberData - { + cloudSession.ProtocolVersion = ProtocolVersion.CURRENT; + cloudSession.SessionMembers.Add(new SessionMemberData + { ClientInstanceId = member1, PublicKeyInfo = new PublicKeyInfo { ProtocolVersion = ProtocolVersion.CURRENT } });