Skip to content

Commit 9d8c72c

Browse files
committed
Fix: Fix reachability analysis vulnerability in server-side logic pruner
1 parent 5a794ec commit 9d8c72c

File tree

5 files changed

+88
-50
lines changed

5 files changed

+88
-50
lines changed

src/OTAPI.UnifiedServerProcess/Core/NetworkLogicPruner.cs

Lines changed: 60 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -19,28 +19,33 @@ public class NetworkLogicPruner(ModuleDefinition module)
1919
// TODO: support more cases
2020
// readonly FieldDefinition netMode = module.GetType("Terraria.Main").Field("netMode");
2121

22-
public void Prune() {
23-
22+
public void Prune(params string[] skippedTypeFullNames) {
23+
var skipTypes = skippedTypeFullNames.ToHashSet();
2424
foreach (var type in module.GetAllTypes()) {
25-
foreach (var method in type.Methods) {
25+
26+
if (skipTypes.Contains(type.FullName)) {
27+
continue;
28+
}
29+
30+
foreach (var method in type.Methods.ToArray()) {
2631
if (!method.HasBody) {
2732
continue;
2833
}
2934

3035
Dictionary<Instruction, Instruction> switchBlockEnd = [];
3136

32-
bool goingOn = false;
37+
bool anyTargetField = false;
3338
foreach (var inst in method.Body.Instructions) {
3439
if (inst.Operand is not FieldReference fieldReference) {
3540
continue;
3641
}
3742
if (fieldReference.FullName == dedServ.FullName || fieldReference.FullName == skipMenu.FullName) {
38-
goingOn = true;
43+
anyTargetField = true;
3944
break;
4045
}
4146
}
4247

43-
if (!goingOn) {
48+
if (!anyTargetField) {
4449
continue;
4550
}
4651

@@ -118,12 +123,31 @@ public void Prune() {
118123
indexMap[method.Body.Instructions[i]] = i;
119124
}
120125

121-
var data = (-1, -1);
126+
HashSet<int> visited = [];
127+
Stack<(int current, int end1, int end2)> paths = [];
128+
paths.Push((0, -1, -1));
129+
130+
while (paths.TryPop(out var pathDetail)) {
131+
132+
while (pathDetail.current < method.Body.Instructions.Count) {
133+
if (!visited.Add(pathDetail.current)) {
134+
break;
135+
}
136+
137+
var reachableInst = method.Body.Instructions[pathDetail.current];
138+
CanReachCurrentInstruction(method.Body, jumpSites, switchBlockEnd, reachableInst, removes, indexMap, ref pathDetail);
139+
140+
if (pathDetail.end1 == -1 && pathDetail.end2 == -1
141+
&& reachableInst.Operand is Instruction jumpTo
142+
&& pathDetail.current < method.Body.Instructions.Count
143+
&& method.Body.Instructions[pathDetail.current] != jumpTo
144+
&& !visited.Contains(indexMap[jumpTo])) {
122145

123-
for (int i = 0; i < method.Body.Instructions.Count;) {
124-
var reachableInst = method.Body.Instructions[i];
125-
CanReachCurrentInstruction(method.Body, jumpSites, switchBlockEnd, reachableInst, removes, indexMap, ref i, ref data);
126-
reachableInstructions.Add(reachableInst);
146+
paths.Push((indexMap[jumpTo], pathDetail.end1, pathDetail.end2));
147+
}
148+
149+
reachableInstructions.Add(reachableInst);
150+
}
127151
}
128152

129153
foreach (var rm in removes) {
@@ -136,23 +160,25 @@ public void Prune() {
136160
}
137161
}
138162

163+
164+
139165
void CanReachCurrentInstruction(
140166

141167
MethodBody body,
142168
Dictionary<Instruction, List<Instruction>> jumpSites,
143169
Dictionary<Instruction, Instruction> switchBlockToEnd,
144170
Instruction instruction,
145171
List<Instruction> rm,
146-
Dictionary<Instruction, int> indexMap,
172+
Dictionary<Instruction, int> inst2Index,
147173

148-
ref int index, ref (int end_dedServIsTrueBlock, int end_skipMenuIsFalseBlock) data) {
174+
ref (int index, int end_dedServIsTrueBlock, int end_skipMenuIsFalseBlock) data) {
149175

150176

151-
if (index > data.end_dedServIsTrueBlock) {
177+
if (data.index > data.end_dedServIsTrueBlock) {
152178
data.end_dedServIsTrueBlock = -1;
153179
}
154180

155-
if (index > data.end_skipMenuIsFalseBlock) {
181+
if (data.index > data.end_skipMenuIsFalseBlock) {
156182
data.end_skipMenuIsFalseBlock = -1;
157183
}
158184

@@ -162,30 +188,30 @@ void CanReachCurrentInstruction(
162188
var nextInst = instruction.Next;
163189

164190
if (nextInst.OpCode == OpCodes.Brtrue || nextInst.OpCode == OpCodes.Brtrue_S) {
165-
var jumpTarget = (Instruction)nextInst.Operand;
191+
var jumpTo = (Instruction)nextInst.Operand;
166192

167-
var jumpIndex = indexMap[jumpTarget];
168-
index = jumpIndex;
193+
var jumpToIndex = inst2Index[jumpTo];
194+
data.index = jumpToIndex;
169195

170196
if (switchBlockToEnd.TryGetValue(instruction, out var blockEnd)) {
171197
if (blockEnd.OpCode != OpCodes.Ret
172198
&& blockEnd.OpCode != OpCodes.Br
173199
&& blockEnd.OpCode != OpCodes.Br_S) {
174200
blockEnd = blockEnd.Next;
175201
}
176-
var switchEndIndex = indexMap[blockEnd];
177-
if (switchEndIndex < jumpIndex) {
178-
index = switchEndIndex;
202+
var switchEndIndex = inst2Index[blockEnd];
203+
if (switchEndIndex < jumpToIndex) {
204+
data.index = switchEndIndex;
179205
}
180206
rm.Add(instruction);
181207
return;
182208
}
183209

184210
// else block
185211
if (instruction.Previous is not null && (instruction.Previous.OpCode == OpCodes.Br || instruction.Previous.OpCode == OpCodes.Br_S)) {
186-
jumpTarget = (Instruction)instruction.Previous.Operand;
187-
jumpIndex = indexMap[jumpTarget];
188-
data.end_dedServIsTrueBlock = jumpIndex;
212+
jumpTo = (Instruction)instruction.Previous.Operand;
213+
jumpToIndex = inst2Index[jumpTo];
214+
data.end_dedServIsTrueBlock = jumpToIndex;
189215
}
190216

191217
rm.Add(instruction);
@@ -195,7 +221,7 @@ void CanReachCurrentInstruction(
195221
if (nextInst.OpCode == OpCodes.Brfalse || nextInst.OpCode == OpCodes.Brfalse_S) {
196222
var jumpTarget = (Instruction)nextInst.Operand;
197223
if (jumpSites[jumpTarget].Count == 1) {
198-
data.end_dedServIsTrueBlock = indexMap[jumpTarget];
224+
data.end_dedServIsTrueBlock = inst2Index[jumpTarget];
199225
rm.Add(instruction);
200226
rm.Add(nextInst);
201227
}
@@ -207,18 +233,18 @@ void CanReachCurrentInstruction(
207233
if (nextInst.OpCode == OpCodes.Brfalse || nextInst.OpCode == OpCodes.Brfalse_S) {
208234
var jumpTarget = (Instruction)nextInst.Operand;
209235

210-
var jumpIndex = indexMap[jumpTarget];
211-
index = jumpIndex;
236+
var jumpIndex = inst2Index[jumpTarget];
237+
data.index = jumpIndex;
212238

213239
if (switchBlockToEnd.TryGetValue(instruction, out var blockEnd)) {
214240
if (blockEnd.OpCode != OpCodes.Ret
215241
&& blockEnd.OpCode != OpCodes.Br
216242
&& blockEnd.OpCode != OpCodes.Br_S) {
217243
blockEnd = blockEnd.Next;
218244
}
219-
var switchEndIndex = indexMap[blockEnd];
245+
var switchEndIndex = inst2Index[blockEnd];
220246
if (switchEndIndex < jumpIndex) {
221-
index = switchEndIndex;
247+
data.index = switchEndIndex;
222248
}
223249
rm.Add(instruction);
224250
return;
@@ -227,7 +253,7 @@ void CanReachCurrentInstruction(
227253
// else block
228254
if (instruction.Previous is not null && (instruction.Previous.OpCode == OpCodes.Br || instruction.Previous.OpCode == OpCodes.Br_S)) {
229255
jumpTarget = (Instruction)instruction.Previous.Operand;
230-
jumpIndex = indexMap[jumpTarget];
256+
jumpIndex = inst2Index[jumpTarget];
231257
data.end_skipMenuIsFalseBlock = jumpIndex;
232258
}
233259

@@ -238,23 +264,23 @@ void CanReachCurrentInstruction(
238264
if (nextInst.OpCode == OpCodes.Brtrue || nextInst.OpCode == OpCodes.Brtrue_S) {
239265
var jumpTarget = (Instruction)nextInst.Operand;
240266
if (jumpSites[jumpTarget].Count == 1) {
241-
data.end_skipMenuIsFalseBlock = indexMap[jumpTarget];
267+
data.end_skipMenuIsFalseBlock = inst2Index[jumpTarget];
242268
rm.Add(instruction);
243269
rm.Add(nextInst);
244270
}
245271
}
246272
}
247273
}
248274

249-
if (CheckIsJumpOutOfBlock(body, instruction, switchBlockToEnd, rm, indexMap, ref index, data.end_dedServIsTrueBlock)) {
275+
if (CheckIsJumpOutOfBlock(body, instruction, switchBlockToEnd, rm, inst2Index, ref data.index, data.end_dedServIsTrueBlock)) {
250276
return;
251277
}
252278

253-
if (CheckIsJumpOutOfBlock(body, instruction, switchBlockToEnd, rm, indexMap, ref index, data.end_skipMenuIsFalseBlock)) {
279+
if (CheckIsJumpOutOfBlock(body, instruction, switchBlockToEnd, rm, inst2Index, ref data.index, data.end_skipMenuIsFalseBlock)) {
254280
return;
255281
}
256282

257-
index += 1;
283+
data.index += 1;
258284

259285
return;
260286

src/OTAPI.UnifiedServerProcess/Core/PatchCollision.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -179,11 +179,11 @@ instruction.Operand is FieldReference field &&
179179
}
180180

181181
HashSet<string> predefined = [
182-
MonoModCommon.Reference.ImportMethod(module, () => Collision.SlopeCollision(default,default,default,default,default,default)).GetIdentifier(),
183-
MonoModCommon.Reference.ImportMethod(module, () => Collision.noSlopeCollision(default,default,default,default,default,default)).GetIdentifier(),
184-
MonoModCommon.Reference.ImportMethod(module, () => Collision.TileCollision(default,default,default,default,default,default,default)).GetIdentifier(),
185-
MonoModCommon.Reference.ImportMethod(module, () => Collision.AdvancedTileCollision(default, default, default, default, default, default, default, default)).GetIdentifier(),
186-
MonoModCommon.Reference.ImportMethod(module, () => default(Player)!.SlopingCollision(default,default)).GetIdentifier(),
182+
MonoModCommon.Reference.Method(() => Collision.SlopeCollision(default,default,default,default,default,default)).GetSimpleIdentifier(),
183+
MonoModCommon.Reference.Method(() => Collision.noSlopeCollision(default,default,default,default,default,default)).GetSimpleIdentifier(),
184+
MonoModCommon.Reference.Method(() => Collision.TileCollision(default, default, default, default, default, default, default)).GetSimpleIdentifier(),
185+
MonoModCommon.Reference.Method(() => Collision.AdvancedTileCollision(default, default, default, default, default, default, default, default)).GetSimpleIdentifier(),
186+
MonoModCommon.Reference.Method(() => default(Player)!.SlopingCollision(default,default)).GetSimpleIdentifier(),
187187
];
188188

189189
foreach (var m in methodsWithVariables.Values.ToArray()) {

src/OTAPI.UnifiedServerProcess/Core/Patching/FieldFilterPatching/InitialFieldModificationProcessor.cs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -839,7 +839,9 @@ inst.Operand is FieldReference f
839839
method.Body.ExceptionHandlers.Clear();
840840
}
841841
else {
842-
Dictionary<Instruction, Instruction> instMap = [];
842+
Dictionary<Instruction, Instruction> instOrig2GenMap = [];
843+
844+
instOrig2GenMap[method.Body.Instructions.Last()] = returnInst;
843845

844846
Dictionary<Instruction, List<LoopBlockData>> instToLoopBlocks = [];
845847
foreach (var loopBlock in loopBlocks.Values) {
@@ -935,7 +937,7 @@ static Instruction CloneAndUpdateMap(Dictionary<Instruction, Instruction> instMa
935937

936938
foreach (var init in loopBlock.InitLoopVariable) {
937939
if (addedInsts.Add(init)) {
938-
var cloneInit = CloneAndUpdateMap(instMap, init);
940+
var cloneInit = CloneAndUpdateMap(instOrig2GenMap, init);
939941
MapLocal(method, localMap, init, cloneInit, loopBlock);
940942
ilProcessor.InsertBefore(returnInst, cloneInit);
941943
}
@@ -945,7 +947,7 @@ static Instruction CloneAndUpdateMap(Dictionary<Instruction, Instruction> instMa
945947
if (restInsts.Remove(inst)) {
946948
if (addedInsts.Add(inst)) {
947949
removedInsts.Add(inst);
948-
Instruction clone = CloneAndUpdateMap(instMap, inst);
950+
Instruction clone = CloneAndUpdateMap(instOrig2GenMap, inst);
949951
MapLocal(method, localMap, inst, clone, loopBlock);
950952
ilProcessor.InsertBefore(returnInst, clone);
951953
}
@@ -954,20 +956,20 @@ static Instruction CloneAndUpdateMap(Dictionary<Instruction, Instruction> instMa
954956
if (restInsts.Count == 0) {
955957
foreach (var post in loopBlock.PostLoop) {
956958
if (addedInsts.Add(post)) {
957-
var clonePost = CloneAndUpdateMap(instMap, post);
959+
var clonePost = CloneAndUpdateMap(instOrig2GenMap, post);
958960
MapLocal(method, localMap, post, clonePost, loopBlock);
959961
ilProcessor.InsertBefore(returnInst, clonePost);
960962
}
961963
}
962964
foreach (var cond in loopBlock.LoopCond) {
963965
if (addedInsts.Add(cond)) {
964-
var cloneCond = CloneAndUpdateMap(instMap, cond);
966+
var cloneCond = CloneAndUpdateMap(instOrig2GenMap, cond);
965967
MapLocal(method, localMap, cond, cloneCond, loopBlock);
966968
ilProcessor.InsertBefore(returnInst, cloneCond);
967969
}
968970
}
969-
var mapped = (Instruction)(instMap[loopBlock.LoopCond.Last()].Operand = instMap[loopBlock.JumpToLoopHead].Next);
970-
instMap[mapped] = mapped;
971+
var mapped = (Instruction)(instOrig2GenMap[loopBlock.LoopCond.Last()].Operand = instOrig2GenMap[loopBlock.JumpToLoopHead].Next);
972+
instOrig2GenMap[mapped] = mapped;
971973

972974
checkingLoops.Remove(loopBlock);
973975
processedLoops.Add(loopBlock);
@@ -976,19 +978,19 @@ static Instruction CloneAndUpdateMap(Dictionary<Instruction, Instruction> instMa
976978
}
977979
else if (addedInsts.Add(inst) && extractedStaticInsts.Contains(inst)) {
978980
removedInsts.Add(inst);
979-
Instruction clone = CloneAndUpdateMap(instMap, inst);
981+
Instruction clone = CloneAndUpdateMap(instOrig2GenMap, inst);
980982
MapLocal(method, localMap, inst, clone, null);
981983
ilProcessor.InsertBefore(returnInst, clone);
982984
}
983985
}
984986

985987
foreach (var inst in generated.Body.Instructions) {
986988
if (inst.Operand is Instruction jumpTarget) {
987-
inst.Operand = instMap[jumpTarget];
989+
inst.Operand = instOrig2GenMap[jumpTarget];
988990
}
989991
else if (inst.Operand is Instruction[] jumpTargets) {
990992
for (int i = 0; i < jumpTargets.Length; i++) {
991-
jumpTargets[i] = instMap[jumpTargets[i]];
993+
jumpTargets[i] = instOrig2GenMap[jumpTargets[i]];
992994
}
993995
}
994996
}

src/OTAPI.UnifiedServerProcess/Core/PatchingLogic.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public static class PatchingLogic
1515
public static void Patch(ILogger logger, ModuleDefinition module) {
1616

1717
PatchLogic.PatchCollision(module);
18-
new NetworkLogicPruner(module).Prune();
18+
new NetworkLogicPruner(module).Prune("Terraria.Player");
1919

2020
var analyzers = new AnalyzerGroups(logger, module);
2121
// var cacheHelper = new CacheManager(logger);

src/OTAPI.UnifiedServerProcess/Extensions/MonoModExtensions.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,16 @@ public static string GetIdentifier(this MethodReference method, bool withTypeNam
278278
"(" + string.Join(",", paramStrs) + ")";
279279
}
280280

281+
public static string GetSimpleIdentifier(this System.Reflection.MethodBase method, bool withTypeName = true) {
282+
var type = method.DeclaringType;
283+
if (type is null && withTypeName) {
284+
throw new ArgumentException("DeclaringType is null", nameof(method));
285+
}
286+
var typeName = withTypeName ? method.DeclaringType!.FullName + "." : "";
287+
288+
return typeName + method.Name + "(" + string.Join(",", method.GetParameters().Select(p => p.ParameterType.FullName)) + ")";
289+
}
290+
281291
public static string GetIdentifier(this MethodReference method, bool withTypeName = true, params TypeDefinition[] ignoreParams) {
282292
var originalType = method.DeclaringType;
283293
if (method.DeclaringType is null && withTypeName) {

0 commit comments

Comments
 (0)