Skip to content

Commit f2a335c

Browse files
authored
Reduce the number of calls made to hash the RuleFile (#612)
* Reduce number of calls to hash the analysis file. * Add Comments to Rule File * Clean up test. * Try explicitly calling SHA512Managed for #602 * Add Managed Crypto Tests
1 parent 9d3236c commit f2a335c

File tree

6 files changed

+167
-78
lines changed

6 files changed

+167
-78
lines changed

Benchmarks/CryptoTests.cs

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,28 +13,31 @@ namespace Microsoft.CST.AttackSurfaceAnalyzer.Benchmarks
1313
public class CryptoTests : AsaDatabaseBenchmark
1414
{
1515
public CryptoTests()
16-
#nullable restore
1716
{
1817
}
1918

2019
// The number of iterations per run
2120
[Params(100000)]
2221
public int N { get; set; }
2322

23+
// The number of iterations per run
24+
[Params(1000)]
25+
public int NumObjects { get; set; }
26+
2427
// The amount of padding to add to the object in bytes Default size is approx 530 bytes serialized
2528
// Does not include SQL overhead
26-
[Params(0)]
29+
[Params(1000)]
2730
public int ObjectPadding { get; set; }
2831

2932
[Benchmark]
3033
public void Generate_N_Murmur_Hashes()
3134
{
3235
for (int i = 0; i < N; i++)
3336
{
34-
hashObjects.TryDequeue(out string? result);
35-
if (result is string)
37+
hashObjects.TryDequeue(out byte[]? result);
38+
if (result is byte[])
3639
{
37-
_ = murmur128.ComputeHash(Encoding.UTF8.GetBytes(result));
40+
_ = murmur128.ComputeHash(result);
3841
hashObjects.Enqueue(result);
3942
}
4043
else
@@ -49,10 +52,28 @@ public void Generate_N_SHA256_Hashes()
4952
{
5053
for (int i = 0; i < N; i++)
5154
{
52-
hashObjects.TryDequeue(out string? result);
53-
if (result is string)
55+
hashObjects.TryDequeue(out byte[]? result);
56+
if (result is byte[])
57+
{
58+
_ = sha256.ComputeHash(result);
59+
hashObjects.Enqueue(result);
60+
}
61+
else
62+
{
63+
Log.Information("The queue is polluted with nulls");
64+
}
65+
}
66+
}
67+
68+
[Benchmark]
69+
public void Generate_N_SHA256Managed_Hashes()
70+
{
71+
for (int i = 0; i < N; i++)
72+
{
73+
hashObjects.TryDequeue(out byte[]? result);
74+
if (result is byte[])
5475
{
55-
_ = sha256.ComputeHash(Encoding.UTF8.GetBytes(result));
76+
_ = sha256managed.ComputeHash(result);
5677
hashObjects.Enqueue(result);
5778
}
5879
else
@@ -67,10 +88,28 @@ public void Generate_N_SHA512_Hashes()
6788
{
6889
for (int i = 0; i < N; i++)
6990
{
70-
hashObjects.TryDequeue(out string? result);
71-
if (result is string)
91+
hashObjects.TryDequeue(out byte[]? result);
92+
if (result is byte[])
7293
{
73-
_ = sha512.ComputeHash(Encoding.UTF8.GetBytes(result));
94+
_ = sha512.ComputeHash(result);
95+
hashObjects.Enqueue(result);
96+
}
97+
else
98+
{
99+
Log.Information("The queue is polluted with nulls");
100+
}
101+
}
102+
}
103+
104+
[Benchmark]
105+
public void Generate_N_SHA512_Managed_Hashes()
106+
{
107+
for (int i = 0; i < N; i++)
108+
{
109+
hashObjects.TryDequeue(out byte[]? result);
110+
if (result is byte[])
111+
{
112+
_ = sha512managed.ComputeHash(result);
74113
hashObjects.Enqueue(result);
75114
}
76115
else
@@ -83,19 +122,22 @@ public void Generate_N_SHA512_Hashes()
83122
[GlobalSetup]
84123
public void GlobalSetup()
85124
{
86-
while (hashObjects.Count < N)
125+
while (hashObjects.Count < NumObjects)
87126
{
88-
hashObjects.Enqueue(JsonConvert.SerializeObject(GetRandomObject(ObjectPadding)));
127+
hashObjects.Enqueue(Encoding.UTF8.GetBytes(JsonConvert.SerializeObject(GetRandomObject(ObjectPadding))));
89128
}
90129
}
91130

92131
private static readonly HashAlgorithm murmur128 = MurmurHash.Create128();
93132

94133
private static readonly HashAlgorithm sha256 = SHA256.Create();
95134

135+
private static readonly HashAlgorithm sha256managed = SHA256Managed.Create();
136+
96137
private static readonly HashAlgorithm sha512 = SHA512.Create();
97138

98-
private readonly ConcurrentQueue<string> hashObjects = new ConcurrentQueue<string>();
99-
#nullable disable
139+
private static readonly HashAlgorithm sha512managed = SHA512Managed.Create();
140+
141+
private readonly ConcurrentQueue<byte[]> hashObjects = new ConcurrentQueue<byte[]>();
100142
}
101143
}

Benchmarks/Program.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ public class Program
66
{
77
public static void Main(string[] args)
88
{
9-
var summary = BenchmarkRunner.Run<InsertTestsWithoutTransactions>();
9+
var summary = BenchmarkRunner.Run<CryptoTests>();
1010
}
1111
}
1212
}

Cli/AttackSurfaceAnalyzerClient.cs

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,13 @@ private static void Main(string[] args)
126126
Environment.Exit((int)argsResult);
127127
}
128128

129+
/// <summary>
130+
/// Loads the rules from the provided file, if it is not null or empty. Or falls back to the embedded rules if it is.
131+
/// </summary>
132+
/// <param name="analysisFile"></param>
133+
/// <returns>The loaded RuleFile</returns>
134+
private static RuleFile LoadRulesFromFileOrEmbedded(string? analysisFile) => string.IsNullOrEmpty(analysisFile) ? RuleFile.LoadEmbeddedFilters() : RuleFile.FromFile(analysisFile);
135+
129136
private static ASA_ERROR RunGuidedModeCommand(GuidedModeCommandOptions opts)
130137
{
131138
opts.RunId = opts.RunId?.Trim() ?? DateTime.Now.ToString("o", CultureInfo.InvariantCulture);
@@ -156,7 +163,13 @@ private static ASA_ERROR RunGuidedModeCommand(GuidedModeCommandOptions opts)
156163

157164
RunCollectCommand(collectorOpts);
158165

159-
var analysisFile = string.IsNullOrEmpty(opts.AnalysesFile) ? RuleFile.LoadEmbeddedFilters() : RuleFile.FromFile(opts.AnalysesFile);
166+
var analysisFile = LoadRulesFromFileOrEmbedded(opts.AnalysesFile);
167+
168+
if (!analysisFile.Rules.Any())
169+
{
170+
Log.Warning(Strings.Get("Err_NoRules"));
171+
return ASA_ERROR.INVALID_RULES;
172+
}
160173

161174
var compareOpts = new CompareCommandOptions(firstCollectRunId, secondCollectRunId)
162175
{
@@ -211,6 +224,7 @@ private static ASA_ERROR RunGuidedModeCommand(GuidedModeCommandOptions opts)
211224
{
212225
if (opts is null) { return new ConcurrentDictionary<(RESULT_TYPE, CHANGE_TYPE), List<CompareResult>>(); }
213226
var results = new ConcurrentDictionary<(RESULT_TYPE, CHANGE_TYPE), List<CompareResult>>();
227+
var analysesHash = ruleFile.GetHash();
214228
Parallel.ForEach(collectObjects, monitorResult =>
215229
{
216230
var shellResult = new CompareResult()
@@ -220,7 +234,7 @@ private static ASA_ERROR RunGuidedModeCommand(GuidedModeCommandOptions opts)
220234
};
221235

222236
shellResult.Rules = analyzer.Analyze(ruleFile.Rules, shellResult).ToList();
223-
shellResult.AnalysesHash = ruleFile.GetHash();
237+
shellResult.AnalysesHash = analysesHash;
224238

225239
if (opts.ApplySubObjectRulesToMonitor)
226240
{
@@ -247,8 +261,13 @@ private static ASA_ERROR RunGuidedModeCommand(GuidedModeCommandOptions opts)
247261
private static ASA_ERROR RunVerifyRulesCommand(VerifyOptions opts)
248262
{
249263
var analyzer = new AsaAnalyzer(new AnalyzerOptions(opts.RunScripts));
250-
var ruleFile = string.IsNullOrEmpty(opts.AnalysisFile) ? RuleFile.LoadEmbeddedFilters() : RuleFile.FromFile(opts.AnalysisFile);
251-
var violations = analyzer.EnumerateRuleIssues(ruleFile.GetRules());
264+
var ruleFile = LoadRulesFromFileOrEmbedded(opts.AnalysisFile);
265+
if (!ruleFile.Rules.Any())
266+
{
267+
Log.Warning(Strings.Get("Err_NoRules"));
268+
return ASA_ERROR.INVALID_RULES;
269+
}
270+
var violations = analyzer.EnumerateRuleIssues(ruleFile.Rules);
252271
OAT.Utils.Strings.Setup();
253272
OAT.Utils.Helpers.PrintViolations(violations);
254273
if (violations.Any())
@@ -469,12 +488,19 @@ private static ASA_ERROR RunExportCollectCommand(ExportCollectCommandOptions opt
469488
}
470489
}
471490

491+
var ruleFile = LoadRulesFromFileOrEmbedded(opts.AnalysesFile);
492+
if (!ruleFile.Rules.Any())
493+
{
494+
Log.Warning(Strings.Get("Err_NoRules"));
495+
return ASA_ERROR.INVALID_RULES;
496+
}
497+
472498
Log.Information(Strings.Get("Comparing"), opts.FirstRunId, opts.SecondRunId);
473499

474500
CompareCommandOptions options = new CompareCommandOptions(opts.FirstRunId, opts.SecondRunId)
475501
{
476502
DatabaseFilename = opts.DatabaseFilename,
477-
AnalysesFile = string.IsNullOrEmpty(opts.AnalysesFile) ? RuleFile.LoadEmbeddedFilters() : RuleFile.FromFile(opts.AnalysesFile),
503+
AnalysesFile = ruleFile,
478504
DisableAnalysis = opts.DisableAnalysis,
479505
SaveToDatabase = opts.SaveToDatabase,
480506
RunScripts = opts.RunScripts
@@ -659,10 +685,17 @@ private static ASA_ERROR RunExportMonitorCommand(ExportMonitorCommandOptions opt
659685
return ASA_ERROR.INVALID_ID;
660686
}
661687
}
688+
689+
var ruleFile = LoadRulesFromFileOrEmbedded(opts.AnalysesFile);
690+
if (!ruleFile.Rules.Any())
691+
{
692+
Log.Warning(Strings.Get("Err_NoRules"));
693+
return ASA_ERROR.INVALID_RULES;
694+
}
662695
var monitorCompareOpts = new CompareCommandOptions(null, opts.RunId)
663696
{
664697
DisableAnalysis = opts.DisableAnalysis,
665-
AnalysesFile = string.IsNullOrEmpty(opts.AnalysesFile) ? RuleFile.LoadEmbeddedFilters() : RuleFile.FromFile(opts.AnalysesFile),
698+
AnalysesFile = ruleFile,
666699
ApplySubObjectRulesToMonitor = opts.ApplySubObjectRulesToMonitor,
667700
RunScripts = opts.RunScripts
668701
};
@@ -883,7 +916,8 @@ public static List<BaseCompare> GetComparators()
883916
watch = Stopwatch.StartNew();
884917
var analyzer = new AsaAnalyzer(new AnalyzerOptions(opts.RunScripts));
885918
var platform = DatabaseManager.RunIdToPlatform(opts.SecondRunId);
886-
var violations = analyzer.EnumerateRuleIssues(opts.AnalysesFile.GetRules());
919+
var violations = analyzer.EnumerateRuleIssues(opts.AnalysesFile.Rules);
920+
var analysesHash = opts.AnalysesFile.GetHash();
887921
OAT.Utils.Strings.Setup();
888922
OAT.Utils.Helpers.PrintViolations(violations);
889923
if (violations.Any())
@@ -910,7 +944,7 @@ public static List<BaseCompare> GetComparators()
910944
res.Rules = analyzer.Analyze(selectedRules, res.Base, res.Compare).ToList();
911945
res.Analysis = res.Rules.Count
912946
> 0 ? res.Rules.Max(x => ((AsaRule)x).Flag) : opts.AnalysesFile.DefaultLevels[res.ResultType];
913-
res.AnalysesHash = opts.AnalysesFile.GetHash();
947+
res.AnalysesHash = analysesHash;
914948
});
915949
}
916950
}

0 commit comments

Comments
 (0)