Skip to content

Commit 13d5ab8

Browse files
borisahrensBoris Ahrens
andauthored
Return list as i async enumerable (#305)
* Use IAsyncEnumerable * dotnet format --------- Co-authored-by: Boris Ahrens <boris.ahrens@gdata.de>
1 parent 470ac09 commit 13d5ab8

File tree

8 files changed

+71
-71
lines changed

8 files changed

+71
-71
lines changed

src/MalwareSampleExchange.Console/Controllers/TokensApiController.cs

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ public sealed class TokensApiController : Controller
3131
/// </summary>
3232
/// <param name="logger"></param>
3333
/// <param name="listRequester"></param>
34-
/// <param name="partnerProvider"></param>
3534
public TokensApiController(ILogger<TokensApiController> logger, IListRequester listRequester)
3635
{
3736
_logger = logger;
@@ -43,50 +42,52 @@ public TokensApiController(ILogger<TokensApiController> logger, IListRequester l
4342
/// </summary>
4443
/// <param name="start">Start date for sample request range</param>
4544
/// <param name="end">(Default: today's date )</param>
45+
/// <param name="token"></param>
4646
[HttpGet]
4747
[Route("/v1/list")]
4848
[ValidateModelState]
4949
[SwaggerResponse(StatusCodes.Status400BadRequest, "Start Date has to be before end date")]
5050
[SwaggerResponse(StatusCodes.Status401Unauthorized, "Unauthorized!")]
5151
[SwaggerResponse(StatusCodes.Status402PaymentRequired, "Start date cannot be older than 7 days")]
5252
[SwaggerResponse(StatusCodes.Status500InternalServerError, "We encountered an error while processing the request")]
53-
public async Task<IActionResult> ListTokens([FromQuery][Required] DateTime start, [FromQuery] DateTime? end, CancellationToken token = default)
53+
public Task<IActionResult> ListTokens([FromQuery][Required] DateTime start, [FromQuery] DateTime? end, CancellationToken token = default)
5454
{
5555
try
5656
{
5757
_logger.LogDebug(LogEvents.HashListRequest, "Incoming ListRequest");
5858
var claimsIdentity = User.Identity as ClaimsIdentity;
59-
var username = claimsIdentity?.FindFirst(ClaimTypes.Name)?.Value ?? String.Empty;
59+
var username = claimsIdentity?.FindFirst(ClaimTypes.Name)?.Value ?? string.Empty;
6060
if (start >= end)
6161
{
62-
return StatusCode(400, new Error
62+
return Task.FromResult<IActionResult>(StatusCode(400, new Error
6363
{
6464
Code = 400,
6565
Message = "Start Date has to be before end date."
66-
});
66+
}));
6767
}
6868

6969
if (start < DateTime.Now.AddDays(-7))
7070
{
71-
return StatusCode(402, new Error
71+
return Task.FromResult<IActionResult>(StatusCode(402, new Error
7272
{
7373
Code = 402,
7474
Message = "Start date cannot be older than 7 days."
75-
});
75+
}));
7676
}
7777

78-
var tokens = await _listRequester.RequestListAsync(username, start, end, token);
79-
_listRequestHistogram.WithLabels(username).Set(tokens.Count);
80-
return Ok(tokens);
78+
var tokens = _listRequester
79+
.RequestList(username, start, end, token)
80+
.CountAndCallAsync(count => _listRequestHistogram.WithLabels(username).Set(count)); ;
81+
return Task.FromResult<IActionResult>(Ok(tokens));
8182
}
8283
catch (Exception e)
8384
{
8485
_logger.LogError(LogEvents.InternalServerError, e, "Something went wrong. The Customer got an 500.");
85-
return StatusCode(500, new Error
86+
return Task.FromResult<IActionResult>(StatusCode(500, new Error
8687
{
8788
Code = 500,
8889
Message = "We encountered an error while processing the request."
89-
});
90+
}));
9091
}
9192
}
9293
}

src/MalwareSampleExchange.Console/Database/ISampleMetadataHandler.cs

Lines changed: 0 additions & 24 deletions
This file was deleted.

src/MalwareSampleExchange.Console/Database/MongoMetadataHandler.cs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,21 @@
99

1010
namespace MalwareSampleExchange.Console.Database;
1111

12+
public interface ISampleMetadataHandler
13+
{
14+
/// <summary>
15+
///
16+
/// </summary>
17+
/// <param name="start"></param>
18+
/// <param name="end"></param>
19+
/// <param name="sampleSet"></param>
20+
/// <param name="token"></param>
21+
/// <returns></returns>
22+
Task<IEnumerable<ExportSample>> GetSamplesAsync(DateTime start, DateTime? end, string? sampleSet, CancellationToken token = default);
23+
24+
Task InsertSampleAsync(RequestExportSample sample, CancellationToken token = default);
25+
}
26+
1227
public class MongoMetadataHandler : ISampleMetadataHandler, IHostedService
1328
{
1429
private readonly MongoMetadataOptions _options;
@@ -20,16 +35,16 @@ public MongoMetadataHandler(IOptions<MongoMetadataOptions> options)
2035
_mongoClient = new MongoClient(options.Value.ConnectionString);
2136
}
2237

23-
public async Task<IEnumerable<ExportSample>> GetSamplesAsync(DateTime start, DateTime? end, string sampleSet, CancellationToken token = default)
38+
public async Task<IEnumerable<ExportSample>> GetSamplesAsync(DateTime start, DateTime? end, string? sampleSet, CancellationToken token = default)
2439
{
2540
var mongoDatabase = _mongoClient.GetDatabase(_options.DatabaseName);
2641
var sampleCollection = mongoDatabase.GetCollection<ExportSample>(_options.CollectionName);
2742
var list = end == null
2843
? await sampleCollection
29-
.FindAsync(_ => _.SampleSet == sampleSet && _.Imported >= start, cancellationToken: token)
44+
.FindAsync(sample => sample.SampleSet == sampleSet && sample.Imported >= start, cancellationToken: token)
3045
: await sampleCollection
31-
.FindAsync(_ => _.SampleSet == sampleSet && _.Imported >= start && _.Imported <= end, cancellationToken: token);
32-
return list.ToList();
46+
.FindAsync(sample => sample.SampleSet == sampleSet && sample.Imported >= start && sample.Imported <= end, cancellationToken: token);
47+
return list.ToList(cancellationToken: token);
3348
}
3449

3550
public async Task InsertSampleAsync(RequestExportSample sample, CancellationToken token = default)

src/MalwareSampleExchange.Console/ListRequester/IListRequester.cs

Lines changed: 0 additions & 12 deletions
This file was deleted.

src/MalwareSampleExchange.Console/ListRequester/ListRequester.cs

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Linq;
4+
using System.Runtime.CompilerServices;
45
using System.Threading;
56
using System.Threading.Tasks;
67
using JWT.Algorithms;
@@ -13,6 +14,11 @@
1314

1415
namespace MalwareSampleExchange.Console.ListRequester;
1516

17+
public interface IListRequester
18+
{
19+
IAsyncEnumerable<Token> RequestList(string username, DateTime start, DateTime? end, CancellationToken token = default);
20+
}
21+
1622
public class ListRequester : IListRequester
1723
{
1824
private readonly ILogger _logger;
@@ -32,15 +38,14 @@ public ListRequester(ILogger<ListRequester> logger, IOptions<ListRequesterOption
3238
_sampleStorageHandler = sampleStorageHandler;
3339
}
3440

35-
public async Task<List<Token>> RequestListAsync(string username, DateTime start, DateTime? end,
36-
CancellationToken token = default)
41+
public async IAsyncEnumerable<Token> RequestList(string username, DateTime start, DateTime? end,
42+
[EnumeratorCancellation] CancellationToken token = default)
3743
{
3844
var includeFamilyName = _partnerProvider.Partners.Single(_ => _.Name == username).IncludeFamilyName;
3945
var sampleSet = _partnerProvider.Partners.SingleOrDefault(_ => _.Name == username)?.Sampleset;
4046

4147
var samples = await _sampleMetadataHandler.GetSamplesAsync(start, end, sampleSet, token);
4248

43-
var tokens = new List<Token>();
4449
foreach (var sample in samples.Where(sample => sample.DoNotUseBefore <= DateTime.Now))
4550
{
4651
var fileSize = sample.FileSize == 0
@@ -66,13 +71,28 @@ public async Task<List<Token>> RequestListAsync(string username, DateTime start,
6671
{
6772
builder.AddClaim("familyname", sample.FamilyName);
6873
}
69-
tokens.Add(new Token
74+
75+
yield return new Token
7076
{
7177
_Token = builder.Encode()
72-
});
78+
};
7379
};
80+
}
81+
}
82+
83+
public static class EnumerableExtensions
84+
{
85+
public static async IAsyncEnumerable<T> CountAndCallAsync<T>(this IAsyncEnumerable<T> enumerable, Action<int> callback)
86+
{
87+
var count = 0;
88+
89+
await using var enumerator = enumerable.GetAsyncEnumerator();
90+
while (await enumerator.MoveNextAsync())
91+
{
92+
yield return enumerator.Current;
93+
count++;
94+
}
7495

75-
_logger.LogInformation(LogEvents.HashListResponse, $"Customer {username} receives a list with {tokens.Count} hashes.");
76-
return tokens;
96+
callback(count);
7797
}
7898
}

src/MalwareSampleExchange.Console/MalwareSampleExchange.Console.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
<Project Sdk="Microsoft.NET.Sdk.Web">
22
<PropertyGroup>
33
<TargetFramework>net8.0</TargetFramework>
4-
<Version>0.8.0</Version>
4+
<Version>0.8.1</Version>
55
<Nullable>enable</Nullable>
66
</PropertyGroup>
77
<ItemGroup>

test/MalwareSampleExchange.Console_Test/MalwareSampleExchange.Console_Test.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
<ItemGroup>
77
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.1.0" />
88
<PackageReference Include="Moq" Version="4.17.2" />
9+
<PackageReference Include="System.Linq.Async" Version="6.0.1" />
910
<PackageReference Include="Testcontainers" Version="3.0.0" />
1011
<PackageReference Include="xunit" Version="2.4.1" />
1112
<PackageReference Include="xunit.runner.visualstudio" Version="2.4.3">

test/MalwareSampleExchange.Console_Test/SampleExchangeTest.cs

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Collections.Generic;
33
using System.IO;
4+
using System.Linq;
45
using System.Security.Cryptography;
56
using System.Text;
67
using System.Net.Http;
@@ -40,11 +41,11 @@ private static PartnerProvider CreatePartnerProvider()
4041
return new PartnerProvider(Mock.Of<ILogger<PartnerProvider>>(),
4142
new OptionsWrapper<PartnerProviderOptions>(new PartnerProviderOptions
4243
{
43-
FilePath = Configuration["Config:FilePath"]
44+
FilePath = Configuration["Config:FilePath"]!
4445
}), Mock.Of<IHostApplicationLifetime>(), new HttpClient());
4546
}
4647

47-
private ISampleStorageHandler CreateSampleGetter()
48+
private static ISampleStorageHandler CreateSampleGetter()
4849
{
4950
var options = new FileStorageOptions();
5051
Configuration.GetSection("Storage").Bind(options);
@@ -129,8 +130,7 @@ public async void BusinessLogicCallback_GetSampleToken_NoFamilyName()
129130
var listRequester = CreateListRequester(reader);
130131

131132
var tokens = await listRequester
132-
.RequestListAsync("partner2", DateTime.Now.AddDays(-7),
133-
null);
133+
.RequestList("partner2", DateTime.Now.AddDays(-7), null).ToListAsync();
134134

135135
var deserializedToken = new JwtBuilder()
136136
.WithAlgorithm(new HMACSHA512Algorithm())
@@ -139,13 +139,12 @@ public async void BusinessLogicCallback_GetSampleToken_NoFamilyName()
139139
.Decode<IDictionary<string, object>>(tokens[0]._Token);
140140

141141
var sha256FromToken = deserializedToken["sha256"].ToString();
142-
var partnerFromToken = deserializedToken["partner"].ToString();
143-
var filesizeFromToken = long.Parse(deserializedToken["filesize"].ToString());
142+
var filesizeFromToken = long.Parse(deserializedToken["filesize"].ToString()!);
144143

145144
using (var sha256 = SHA256.Create())
146145
{
147146
sha256String = HexStringFromBytes(await sha256
148-
.ComputeHashAsync((await sampleGetter.GetAsync(sha256FromToken)).FileStream));
147+
.ComputeHashAsync((await sampleGetter.GetAsync(sha256FromToken!)).FileStream));
149148
}
150149

151150
Assert.Single(tokens);
@@ -167,20 +166,20 @@ public async void BusinessLogicCallback_GetSampleToken_HasFamilyName()
167166
var listRequester = CreateListRequester(reader);
168167

169168
var tokens = await listRequester
170-
.RequestListAsync("partnerWithFamilyName", DateTime.Now.AddDays(-7),
171-
null);
169+
.RequestList("partnerWithFamilyName", DateTime.Now.AddDays(-7),
170+
null).ToListAsync();
172171

173172
var deserializedToken = jwtBuilder.Decode<IDictionary<string, object>>(tokens[0]._Token);
174173

175174
var sha256FromToken = deserializedToken["sha256"].ToString();
176175
var partnerFromToken = deserializedToken["partner"].ToString();
177176

178-
var filesizeFromToken = long.Parse(deserializedToken["filesize"].ToString());
177+
var filesizeFromToken = long.Parse(deserializedToken["filesize"].ToString()!);
179178

180179
using (var sha256 = SHA256.Create())
181180
{
182181
sha256String = HexStringFromBytes(await sha256
183-
.ComputeHashAsync((await sampleGetter.GetAsync(sha256FromToken)).FileStream));
182+
.ComputeHashAsync((await sampleGetter.GetAsync(sha256FromToken!)).FileStream));
184183
}
185184

186185
Assert.Single(tokens);

0 commit comments

Comments
 (0)