Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ public CloudSessionData(string? lobbyId, EncryptedSessionSettings sessionSetting

public EncryptedSessionSettings SessionSettings { get; set; } = null!;

public int ProtocolVersion { get; set; }

public bool IsSessionActivated { get; set; }

public List<SessionMemberData> SessionMembers { get; set; }
Expand Down Expand Up @@ -77,4 +79,4 @@ public CloudSession GetCloudSession()
{
return SessionMembers.Concat(PreSessionMembers).Distinct().SingleOrDefault(m => m.ClientInstanceId == clientInstanceId);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ public async Task<CloudSessionResult> Handle(CreateSessionRequest request, Cance
SessionMemberData creatorData;

cloudSessionData = new CloudSessionData(createCloudSessionParameters.LobbyId, createCloudSessionParameters.SessionSettings, client);
cloudSessionData.ProtocolVersion = createCloudSessionParameters.CreatorPublicKeyInfo.ProtocolVersion;
creatorData = new SessionMemberData(client, createCloudSessionParameters.CreatorPublicKeyInfo,
createCloudSessionParameters.CreatorProfileClientId, cloudSessionData,
createCloudSessionParameters.CreatorPrivateData);
Expand Down Expand Up @@ -67,4 +68,4 @@ private string GenerateRandomSessionId()

return sessionId;
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using ByteSync.Common.Business.Sessions.Cloud.Connections;
using ByteSync.Common.Business.Versions;
using ByteSync.ServerCommon.Interfaces.Repositories;
using ByteSync.ServerCommon.Interfaces.Services.Clients;
using MediatR;
Expand All @@ -12,8 +11,8 @@ public class StartTrustCheckCommandHandler : IRequestHandler<StartTrustCheckRequ
private readonly ICloudSessionsRepository _cloudSessionsRepository;
private readonly IInvokeClientsService _invokeClientsService;
private readonly ILogger<StartTrustCheckCommandHandler> _logger;

public StartTrustCheckCommandHandler(ICloudSessionsRepository cloudSessionsRepository, IInvokeClientsService invokeClientsService,
public StartTrustCheckCommandHandler(ICloudSessionsRepository cloudSessionsRepository, IInvokeClientsService invokeClientsService,
ILogger<StartTrustCheckCommandHandler> logger)
{
_cloudSessionsRepository = cloudSessionsRepository;
Expand All @@ -31,37 +30,20 @@ public async Task<StartTrustCheckResult> Handle(StartTrustCheckRequest request,
{
return new StartTrustCheckResult { IsOK = false };
}

var joinerProtocolVersion = trustCheckParameters.ProtocolVersion;
var sessionProtocolVersion = cloudSession.ProtocolVersion;

if (!ProtocolVersion.IsCompatible(joinerProtocolVersion))
if (joinerProtocolVersion != sessionProtocolVersion)
{
_logger.LogWarning(
"StartTrustCheck: Joiner {JoinerId} has incompatible protocol version {JoinerVersion}",
joiner.ClientInstanceId, joinerProtocolVersion);
"StartTrustCheck: Joiner {JoinerId} has incompatible protocol version {JoinerVersion} for session {SessionId} (version {SessionVersion})",
joiner.ClientInstanceId, joinerProtocolVersion, trustCheckParameters.SessionId, sessionProtocolVersion);

return new StartTrustCheckResult { IsOK = false, IsProtocolVersionIncompatible = true };
}

var membersToCheck = cloudSession.SessionMembers
.Where(m => trustCheckParameters.MembersInstanceIdsToCheck.Contains(m.ClientInstanceId));

foreach (var member in membersToCheck)
{
var memberProtocolVersion = member.PublicKeyInfo.ProtocolVersion;

if (!ProtocolVersion.IsCompatible(memberProtocolVersion) ||
memberProtocolVersion != joinerProtocolVersion)
{
_logger.LogWarning(
"StartTrustCheck: Protocol version mismatch between joiner {JoinerId} (version {JoinerVersion}) and member {MemberId} (version {MemberVersion})",
joiner.ClientInstanceId, joinerProtocolVersion, member.ClientInstanceId, memberProtocolVersion);

return new StartTrustCheckResult { IsOK = false, IsProtocolVersionIncompatible = true };
}
}

_logger.LogInformation("StartTrustCheck: {Joiner} starts trust check for session {SessionId}. {Count} members to check",
_logger.LogInformation("StartTrustCheck: {Joiner} starts trust check for session {SessionId}. {Count} members to check",
joiner.ClientInstanceId, trustCheckParameters.SessionId, trustCheckParameters.MembersInstanceIdsToCheck.Count);

var validMemberIds = trustCheckParameters.MembersInstanceIdsToCheck
Expand All @@ -70,10 +52,11 @@ public async Task<StartTrustCheckResult> Handle(StartTrustCheckRequest request,

foreach (var clientInstanceId in validMemberIds)
{
_logger.LogInformation("StartTrustCheck: {Member} must be trusted by {Joiner}",
_logger.LogInformation("StartTrustCheck: {Member} must be trusted by {Joiner}",
clientInstanceId, joiner.ClientInstanceId);

await _invokeClientsService.Client(clientInstanceId).AskPublicKeyCheckData(trustCheckParameters.SessionId, joiner.ClientInstanceId,
await _invokeClientsService.Client(clientInstanceId).AskPublicKeyCheckData(trustCheckParameters.SessionId,
joiner.ClientInstanceId,
trustCheckParameters.PublicKeyInfo).ConfigureAwait(false);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public async Task Handle_ValidRequest_CreatesSession()
var lobbyId = "lobbyId";
var sessionSettings = new EncryptedSessionSettings();
var client = new Client { ClientInstanceId = "clientInstance1" };
var creatorPublicKeyInfo = new PublicKeyInfo();
var creatorPublicKeyInfo = new PublicKeyInfo { ProtocolVersion = 2 };
var creatorProfileClientId = "creatorProfile";
var creatorPrivateData = new EncryptedSessionMemberPrivateData();
var sessionId = "123ABC456";
Expand Down Expand Up @@ -108,6 +108,7 @@ public async Task Handle_ValidRequest_CreatesSession()
addedCloudSession.Should().NotBeNull();
addedCloudSession.LobbyId.Should().Be(lobbyId);
addedCloudSession.SessionSettings.Should().BeSameAs(sessionSettings);
addedCloudSession.ProtocolVersion.Should().Be(creatorPublicKeyInfo.ProtocolVersion);
addedCloudSession.SessionMembers.Should().HaveCount(1);

// Verify session member creation
Expand Down Expand Up @@ -136,4 +137,4 @@ public async Task Handle_ValidRequest_CreatesSession()
A<SessionMemberData>.That.Matches(m => m == creatorMemberData)))
.MustHaveHappenedOnceExactly();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public class StartTrustCheckCommandHandlerTests
private readonly IHubByteSyncPush _mockHubByteSyncPush;

private readonly StartTrustCheckCommandHandler _startTrustCheckCommandHandler;

public StartTrustCheckCommandHandlerTests()
{
_mockCloudSessionsRepository = A.Fake<ICloudSessionsRepository>();
Expand All @@ -32,7 +32,7 @@ public StartTrustCheckCommandHandlerTests()
_mockHubByteSyncPush = A.Fake<IHubByteSyncPush>();

_startTrustCheckCommandHandler = new StartTrustCheckCommandHandler(
_mockCloudSessionsRepository,
_mockCloudSessionsRepository,
_mockInvokeClientsService,
_mockLogger);
}
Expand All @@ -57,13 +57,14 @@ public async Task Handle_SessionExists_WithMembers_ReturnsSuccessResult()
};

var cloudSession = new CloudSessionData(sessionId, new EncryptedSessionSettings(), new Client { ClientInstanceId = "member1" });
cloudSession.SessionMembers.Add(new SessionMemberData
{
cloudSession.ProtocolVersion = ProtocolVersion.CURRENT;
cloudSession.SessionMembers.Add(new SessionMemberData
{
ClientInstanceId = member1,
PublicKeyInfo = new PublicKeyInfo { ProtocolVersion = ProtocolVersion.CURRENT }
});
cloudSession.SessionMembers.Add(new SessionMemberData
{
cloudSession.SessionMembers.Add(new SessionMemberData
{
ClientInstanceId = member2,
PublicKeyInfo = new PublicKeyInfo { ProtocolVersion = ProtocolVersion.CURRENT }
});
Expand All @@ -75,7 +76,7 @@ public async Task Handle_SessionExists_WithMembers_ReturnsSuccessResult()
.Returns(_mockHubByteSyncPush);
A.CallTo(() => _mockInvokeClientsService.Client(member2))
.Returns(_mockHubByteSyncPush);

A.CallTo(() => _mockHubByteSyncPush.AskPublicKeyCheckData(sessionId, joinerClient.ClientInstanceId, publicKeyInfo))
.Returns(Task.CompletedTask);

Expand Down Expand Up @@ -149,8 +150,9 @@ public async Task Handle_SessionExistsButNoValidMembers_ReturnsEmptySuccessResul
};

var cloudSession = new CloudSessionData(sessionId, new EncryptedSessionSettings(), new Client { ClientInstanceId = "creator" });
cloudSession.SessionMembers.Add(new SessionMemberData
{
cloudSession.ProtocolVersion = ProtocolVersion.CURRENT;
cloudSession.SessionMembers.Add(new SessionMemberData
{
ClientInstanceId = "otherMember",
PublicKeyInfo = new PublicKeyInfo { ProtocolVersion = ProtocolVersion.CURRENT }
});
Expand All @@ -173,7 +175,7 @@ public async Task Handle_SessionExistsButNoValidMembers_ReturnsEmptySuccessResul
}

[Test]
public async Task Handle_WhenMemberHasIncompatibleProtocolVersion_ReturnsProtocolVersionIncompatible()
public async Task Handle_WhenMemberHasDifferentProtocolVersion_ReturnsSuccess()
{
var sessionId = "testSession";
var joinerClient = new Client { ClientId = "joinerClient", ClientInstanceId = "joinerClientInstance" };
Expand All @@ -189,49 +191,58 @@ public async Task Handle_WhenMemberHasIncompatibleProtocolVersion_ReturnsProtoco
};

var cloudSession = new CloudSessionData(sessionId, new EncryptedSessionSettings(), new Client { ClientInstanceId = "creator" });
cloudSession.SessionMembers.Add(new SessionMemberData
{
cloudSession.ProtocolVersion = ProtocolVersion.CURRENT;
cloudSession.SessionMembers.Add(new SessionMemberData
{
ClientInstanceId = member1,
PublicKeyInfo = new PublicKeyInfo { ProtocolVersion = 0 }
});

A.CallTo(() => _mockCloudSessionsRepository.Get(sessionId))
.Returns(cloudSession);

A.CallTo(() => _mockInvokeClientsService.Client(member1))
.Returns(_mockHubByteSyncPush);
A.CallTo(() => _mockHubByteSyncPush.AskPublicKeyCheckData(sessionId, joinerClient.ClientInstanceId, publicKeyInfo))
.Returns(Task.CompletedTask);

var request = new StartTrustCheckRequest(parameters, joinerClient);

var result = await _startTrustCheckCommandHandler.Handle(request, CancellationToken.None);

result.Should().NotBeNull();
result.IsOK.Should().BeFalse();
result.IsProtocolVersionIncompatible.Should().BeTrue();
result.MembersInstanceIds.Should().BeEmpty();
result.IsOK.Should().BeTrue();
result.IsProtocolVersionIncompatible.Should().BeFalse();
result.MembersInstanceIds.Should().ContainSingle().Which.Should().Be(member1);

A.CallTo(() => _mockCloudSessionsRepository.Get(sessionId)).MustHaveHappenedOnceExactly();
A.CallTo(() => _mockInvokeClientsService.Client(A<string>.Ignored)).MustNotHaveHappened();
A.CallTo(() => _mockInvokeClientsService.Client(member1)).MustHaveHappenedOnceExactly();
A.CallTo(() => _mockHubByteSyncPush.AskPublicKeyCheckData(sessionId, joinerClient.ClientInstanceId, publicKeyInfo))
.MustHaveHappenedOnceExactly();
}

[Test]
public async Task Handle_WhenJoinerHasIncompatibleProtocolVersion_ReturnsProtocolVersionIncompatible()
public async Task Handle_WhenJoinerProtocolVersionDoesNotMatchSessionProtocolVersion_ReturnsProtocolVersionIncompatible()
{
var sessionId = "testSession";
var joinerClient = new Client { ClientId = "joinerClient", ClientInstanceId = "joinerClientInstance" };
var member1 = "memberInstance1";

var publicKeyInfo = new PublicKeyInfo { ProtocolVersion = 0 };
var publicKeyInfo = new PublicKeyInfo { ProtocolVersion = ProtocolVersion.CURRENT };
var parameters = new TrustCheckParameters
{
SessionId = sessionId,
MembersInstanceIdsToCheck = new List<string> { member1 },
PublicKeyInfo = publicKeyInfo,
ProtocolVersion = 0
ProtocolVersion = ProtocolVersion.CURRENT
};

var cloudSession = new CloudSessionData(sessionId, new EncryptedSessionSettings(), new Client { ClientInstanceId = "creator" });
cloudSession.SessionMembers.Add(new SessionMemberData
{
cloudSession.ProtocolVersion = 0;
cloudSession.SessionMembers.Add(new SessionMemberData
{
ClientInstanceId = member1,
PublicKeyInfo = new PublicKeyInfo { ProtocolVersion = ProtocolVersion.CURRENT }
PublicKeyInfo = new PublicKeyInfo { ProtocolVersion = 0 }
});

A.CallTo(() => _mockCloudSessionsRepository.Get(sessionId))
Expand Down Expand Up @@ -267,8 +278,9 @@ public async Task Handle_WhenJoinerAndMemberHaveCompatibleVersions_ReturnsSucces
};

var cloudSession = new CloudSessionData(sessionId, new EncryptedSessionSettings(), new Client { ClientInstanceId = "creator" });
cloudSession.SessionMembers.Add(new SessionMemberData
{
cloudSession.ProtocolVersion = ProtocolVersion.CURRENT;
cloudSession.SessionMembers.Add(new SessionMemberData
{
ClientInstanceId = member1,
PublicKeyInfo = new PublicKeyInfo { ProtocolVersion = ProtocolVersion.CURRENT }
});
Expand Down
Loading