diff --git a/patcherex/cfg_utils.py b/patcherex/cfg_utils.py index 9d265bd..e04a3be 100644 --- a/patcherex/cfg_utils.py +++ b/patcherex/cfg_utils.py @@ -80,7 +80,7 @@ def check_first_instruction(instr): return None try: - succ = ff.startpoint.successors() + outedges = ff.transition_graph.out_edges(ff.startpoint, data=True) except networkx.NetworkXError: return None ends = ff.endpoints @@ -90,19 +90,21 @@ def check_first_instruction(instr): return None if syscall_number == 1: - if len(succ) == 1: - bb1 = succ[0] - if hasattr(bb1,"is_syscall") and bb1.is_syscall: + if len(outedges) == 1: + edge_data = list(outedges)[0][-1] + if edge_data.get("type") == "syscall": return syscall_number return None else: if not is_sane_function(ff): return None - if len(succ) == 2: - bb1,bb2 = succ - if hasattr(bb1,"is_syscall") and bb1.is_syscall: + if len(outedges) == 2: + outedges = list(outedges) + _, bb1, data1 = outedges[0] + _, bb2, data2 = outedges[1] + if data1.get("type") == "syscall": ebb = bb2 - elif hasattr(bb2,"is_syscall") and bb2.is_syscall: + elif data2.get("type") == "syscall": ebb = bb1 else: ebb= None diff --git a/tests/test_cfg.py b/tests/test_cfg.py index 8af86d0..fb1806b 100755 --- a/tests/test_cfg.py +++ b/tests/test_cfg.py @@ -261,7 +261,7 @@ def test_detect_syscall_wrapper(): syscall_wrappers = set([(ff.addr,cfg_utils.detect_syscall_wrapper(backend,ff)) \ for ff in cfg.functions.values() if cfg_utils.detect_syscall_wrapper(backend,ff)!=None]) print("syscall wrappers in CROMU_00071:") - print(map(lambda x:(hex(x[0]),x[1]),syscall_wrappers)) + print(list(map(lambda x:(hex(x[0]),x[1]),syscall_wrappers))) assert syscall_wrappers == legitimate_syscall_wrappers filepath = os.path.join(bin_location, "CROMU_00070") @@ -280,7 +280,7 @@ def test_detect_syscall_wrapper(): syscall_wrappers = set([(ff.addr,cfg_utils.detect_syscall_wrapper(backend,ff)) \ for ff in cfg.functions.values() if cfg_utils.detect_syscall_wrapper(backend,ff)!=None]) print("syscall wrappers in CROMU_00070:") - print(map(lambda x:(hex(x[0]),x[1]),syscall_wrappers)) + print(list(map(lambda x:(hex(x[0]),x[1]),syscall_wrappers))) assert syscall_wrappers == legitimate_syscall_wrappers