diff --git a/plugins/xrefer/backend/__init__.py b/plugins/xrefer/backend/__init__.py index 0aa8132..615a9a1 100644 --- a/plugins/xrefer/backend/__init__.py +++ b/plugins/xrefer/backend/__init__.py @@ -1,6 +1,6 @@ """backend abstraction with pythonic factory pattern.""" -from .base import Address, BackEnd, BackendError, Function, FunctionType, Section, String, StringEncType, Xref, XrefType +from .base import Address, BackEnd, BackendError, Function, FunctionType, Section, SectionType, String, StringEncType, Xref, XrefType, OperandType, Operand, Instruction from .factory import backend_manager, get_backend, list_available_backends from .utils import sample_path @@ -50,6 +50,11 @@ def _ensure_backend_initialized(): def get_current_backend(): """Get the current backend instance, initializing if needed.""" + # First check if backend manager has an active backend + active_backend = backend_manager.get_active_backend() + if active_backend is not None: + return active_backend + # Fall back to legacy initialization _ensure_backend_initialized() return Backend @@ -71,8 +76,13 @@ def get_current_backend(): "String", "Xref", "Section", + "SectionType", "sample_path", "get_indirect_calls", + # operand + "Instruction", + "Operand", + "OperandType", # Base classes "BackEnd", ] diff --git a/plugins/xrefer/backend/base.py b/plugins/xrefer/backend/base.py index 152fc51..8fceed7 100644 --- a/plugins/xrefer/backend/base.py +++ b/plugins/xrefer/backend/base.py @@ -105,6 +105,45 @@ def is_valid(self) -> bool: return self != self.invalid() +class OperandType(Enum): + """Canonical operand categories across backends.""" + IMMEDIATE = auto() + REGISTER = auto() + MEMORY = auto() + RELATIVE = auto() # branch/call rel targets + OTHER = auto() + +@dataclass(frozen=True) +class MemoryOperand: + """Structured memory operand (best-effort, tolerant across tools).""" + base: Optional[str] = None + index: Optional[str] = None + scale: Optional[int] = None + disp: Optional[int] = None + seg: 
Optional[str] = None + addr_size: Optional[int] = None # in bits if known + +@dataclass(frozen=True) +class Operand: + """Unified operand.""" + type: OperandType + text: str + value: Optional[Address]=None + # reg: Optional[str] = None + # imm: Optional[int] = None + # mem: Optional[MemoryOperand] = None + +@dataclass +class Instruction: + address: Address + # prefixes: Tuple[str, ...] # e.g., ("lock",) or (). TODO: forget for now + mnemonic: str # canonical, lowercased, NO prefixes + operands: Tuple[Operand, ...] + text: str # full display text as shown in tool + + + + @dataclass class BasicBlock: """ @@ -444,9 +483,6 @@ def strings(self, min_length: int = String.MIN_LENGTH) -> Iterator[String]: String objects for each identified string """ ... - pass - - # Symbol Resolution @abstractmethod def get_name_at(self, address: Address) -> str: @@ -734,6 +770,12 @@ def set_function_comment(self, address: Address, comment: str) -> bool: except Exception: return False + def disassemble(self, address: Address) -> "Instruction": + """ + Disassemble a single instruction at `address`. + """ + return self._get_disassembly_impl(address) + # # Backend-Specific Implementation Methods # @@ -757,3 +799,8 @@ def _set_function_comment_impl(self, address: Address, comment: str) -> None: def _path_impl(self) -> str: """Backend-specific implementation for getting binary path.""" ... + + @abstractmethod + def _get_disassembly_impl(self, address: Address) -> Instruction: + """Backend-specific implementation for getting disassembly at a specific address.""" + ... 
diff --git a/plugins/xrefer/backend/binaryninja/backend.py b/plugins/xrefer/backend/binaryninja/backend.py index 085916e..5e127f0 100644 --- a/plugins/xrefer/backend/binaryninja/backend.py +++ b/plugins/xrefer/backend/binaryninja/backend.py @@ -6,8 +6,9 @@ from typing import Iterator, Optional, Tuple import binaryninja as bn +from binaryninja.enums import LowLevelILOperation -from ..base import Address, BackEnd, BasicBlock, Function, FunctionType, Section, SectionType, String, StringEncType, Xref, XrefType +from ..base import Address, BackEnd, BasicBlock, Function, FunctionType, Operand, OperandType, Section, SectionType, String, StringEncType, Xref, XrefType, Instruction class BinaryNinjaFunction(Function): @@ -94,9 +95,10 @@ def encoding(self) -> StringEncType: class BinaryNinjaXref(Xref): """Simple xref representation.""" - def __init__(self, source: int, target: int): + def __init__(self, source: int, target: int, kind: XrefType = XrefType.UNKNOWN): self._src = source self._dst = target + self._kind = kind @property def source(self) -> Address: @@ -108,7 +110,7 @@ def target(self) -> Address: @property def type(self) -> XrefType: # pragma: no cover - heuristic mapping not critical - return XrefType.UNKNOWN + return self._kind class BinaryNinjaSection(Section): @@ -216,6 +218,21 @@ def __init__(self, bv): super().__init__() self._bv: "bn.BinaryView" = bv + def __getstate__(self): + """Make backend pickle-safe by removing BinaryView/ctypes state. + + Binary Ninja's objects are ctypes-backed and cannot be pickled. We + drop references here so the analyzer state can be serialized. + """ + state = self.__dict__.copy() + # Remove BinaryView (and any other transient analysis handles) from state + state["_bv"] = None + return state + + def __setstate__(self, state): + """Restore state without BinaryView. 
Caller must re-set if needed.""" + self.__dict__.update(state) + @property def name(self) -> str: """Backend name for language module lookup.""" @@ -233,7 +250,7 @@ def get_function_at(self, address: Address) -> Optional[BinaryNinjaFunction]: funcs = self._bv.get_functions_containing(int(address)) return BinaryNinjaFunction(funcs[0]) if funcs else None - def strings(self, min_length: int = 3) -> Iterator[BinaryNinjaString]: + def strings(self, min_length: int = 5) -> Iterator[BinaryNinjaString]: """ Get all strings in the Binary Ninja file. @@ -243,21 +260,85 @@ def strings(self, min_length: int = 3) -> Iterator[BinaryNinjaString]: Yields: BinaryNinjaString: Strings found in the binary """ + # Only return strings from non-executable memory to align with Ghidra/IDA behavior for s in self._bv.get_strings(length=min_length): - if len(s.value) >= min_length: - yield BinaryNinjaString(s) + if len(s.value) < min_length: + continue + seg = self._bv.get_segment_at(s.start) + if seg and seg.executable: + continue + yield BinaryNinjaString(s) def get_xrefs_to(self, address: Address) -> Iterator[BinaryNinjaXref]: - for ref in self._bv.get_code_refs(int(address)): - yield BinaryNinjaXref(ref.address, int(address)) - for ref in self._bv.get_data_refs(int(address)): - yield BinaryNinjaXref(ref, int(address)) + addr = int(address) + code_refs = list(self._bv.get_code_refs(addr)) + data_refs = list(self._bv.get_data_refs(addr)) + + sym = self._bv.get_symbol_at(addr) + # sym_type = sym.type + # sym_name = sym.full_name if sym else "" + + # If nothing found, try common import normalization variants (IAT/GOT cell) + if not code_refs and not data_refs and sym is not None: + # BN often records refs to the IAT/GOT cell (data ref from import symbol) + iat_cells = list(self._bv.get_data_refs(addr)) + for cell in iat_cells: # limit debug noise + cr2 = list(self._bv.get_code_refs(cell)) + dr2 = list(self._bv.get_data_refs(cell)) + if cr2 or dr2: + code_refs = cr2 + data_refs = dr2 + break + 
+ for ref in code_refs: + yield BinaryNinjaXref(ref.address, addr, self._classify_code_xref(ref.address)) + for ref in data_refs: + yield BinaryNinjaXref(ref, addr, XrefType.DATA_READ) def get_xrefs_from(self, address: Address) -> Iterator[BinaryNinjaXref]: - for dst in self._bv.get_code_refs_from(int(address)): - yield BinaryNinjaXref(int(address), dst) - for dst in self._bv.get_data_refs_from(int(address)): - yield BinaryNinjaXref(int(address), dst) + addr = int(address) + code_dsts = list(self._bv.get_code_refs_from(addr)) + data_dsts = list(self._bv.get_data_refs_from(addr)) + for dst in code_dsts: + yield BinaryNinjaXref(addr, dst, self._classify_code_xref(addr)) + for dst in data_dsts: + yield BinaryNinjaXref(addr, dst, XrefType.DATA_WRITE) + + def _classify_code_xref(self, source_addr: int) -> XrefType: + """Best-effort classification of a code reference originating at `source_addr`.""" + func = self._bv.get_function_at(source_addr) + if func is None: + return XrefType.UNKNOWN + + try: + llil = func.get_low_level_il_at(source_addr) + except Exception: + llil = None + + if llil is None: + return XrefType.UNKNOWN + + op = llil.operation + + if op in ( + LowLevelILOperation.LLIL_CALL, + LowLevelILOperation.LLIL_TAILCALL, + LowLevelILOperation.LLIL_CALL_STACK_ADJUST, + LowLevelILOperation.LLIL_CALL_PARAM, + LowLevelILOperation.LLIL_SYSCALL, + ): + return XrefType.CALL + + if op in ( + LowLevelILOperation.LLIL_JUMP, + LowLevelILOperation.LLIL_JUMP_TO, + LowLevelILOperation.LLIL_GOTO, + LowLevelILOperation.LLIL_RET, + LowLevelILOperation.LLIL_IF, + ): + return XrefType.JUMP + + return XrefType.UNKNOWN def get_name_at(self, address: Address) -> str: sym = self._bv.get_symbol_at(int(address)) @@ -285,24 +366,30 @@ def _get_raw_imports(self) -> Iterator[Tuple[Address, str, str]]: """Get raw import data from Binary Ninja.""" processed_addresses = set() + # Heuristic to detect ELF: presence of a .plt section + is_elf = any(name in self._bv.sections for name in (".plt", 
".plt.got", ".rela.plt")) + for ext_loc in self._bv.get_external_locations(): source_symbol = ext_loc.source_symbol symbol_address = source_symbol.address data_refs = list(self._bv.get_data_refs(symbol_address)) if len(data_refs) != 1: - print(f"Warning: Symbol at {hex(symbol_address)} has {len(data_refs)} data references, expected 1") - if len(data_refs) == 0: - continue - # Use the data reference address (IAT/GOT entry) - address = data_refs[0] if data_refs else symbol_address - if address in processed_addresses: + print(f"rand0m: {{'bn.import': 'warn', 'sym': '{symbol_address:#x}', 'name': '{source_symbol.raw_name}', 'data_refs': {len(data_refs)}}}") + # Prefer the callable stub if Binary Ninja lifted one; otherwise fall back to IAT/GOT cell. + target_addr = symbol_address + if not self._bv.get_function_at(symbol_address) and data_refs: + target_addr = data_refs[0] + iat_addr = data_refs[0] if data_refs else None + if target_addr in processed_addresses: continue - processed_addresses.add(address) + processed_addresses.add(target_addr) target_name = ext_loc.target_symbol if ext_loc.has_target_symbol else source_symbol.raw_name - module_name = ext_loc.library.name.split("/")[-1] if ext_loc.library else "unknown" + module_name = ext_loc.library.name.split("/")[-1] if ext_loc.library else ("GLIBC" if is_elf else "unknown") + + print(f"rand0m: {{'bn.import': 'map', 'source_sym': '{source_symbol.raw_name}', 'target_name': '{target_name}', 'module': '{module_name}', 'sym_addr': '{symbol_address:#x}', 'iat_addr': '{iat_addr:#x}'}}") # Yield raw import data: (address, function_name, module_name) - yield (Address(address), target_name, module_name) + yield (Address(target_addr), target_name, module_name) def get_exports(self) -> Iterator[Tuple[str, Address]]: """Get all exports from the Binary Ninja binary.""" @@ -365,3 +452,123 @@ def _set_function_comment_impl(self, address: Address, comment: str) -> None: # string_ref = MockStringReference(int(address), content, 
len(content), encoding) # return BinaryNinjaString(string_ref) + def _get_disassembly_impl(self, address: Address) -> Instruction: + """Backend-specific implementation for getting disassembly at a specific address.""" + ea = int(address) + + # Full disassembly text for this instruction (Binary Ninja formatted) + text = self._bv.get_disassembly(ea) + + # Get tokens for ONLY this instruction at `ea` using the architecture + # Returns (List[InstructionTextToken], length) + inst_len = self._bv.get_instruction_length(ea) + # Read a safe number of bytes for decoding this instruction + data = self._bv.read(ea, inst_len if inst_len and inst_len > 0 else 16) or b"" + tokens, _ = self._bv.arch.get_instruction_text(data, ea) + + # Extract mnemonic from tokens + mnemonic = "" + for tok in tokens: + if tok.type == bn.InstructionTextTokenType.InstructionToken: + mnemonic = tok.text.strip().lower() + break + + # Split operand tokens by OperandSeparatorToken to mimic IDA operand indexing + operands_list: list[Operand] = [] + collecting = False + current: list = [] + for tok in tokens: + if tok.type == bn.InstructionTextTokenType.InstructionToken: + # Start collecting after mnemonic + collecting = True + continue + if not collecting: + continue + if tok.type == bn.InstructionTextTokenType.OperandSeparatorToken: + if current: + # finalize current operand + ttypes = {t.type for t in current} + g_text = "".join(t.text for t in current).strip() + kind: OperandType + if ( + bn.InstructionTextTokenType.BeginMemoryOperandToken in ttypes + or ( + bn.InstructionTextTokenType.CodeRelativeAddressToken in ttypes + and any(t.type == bn.InstructionTextTokenType.BraceToken and t.text == '[' for t in current) + ) + ): + kind = OperandType.MEMORY + elif ( + bn.InstructionTextTokenType.IntegerToken in ttypes + or bn.InstructionTextTokenType.PossibleAddressToken in ttypes + or bn.InstructionTextTokenType.CodeRelativeAddressToken in ttypes + ): + kind = OperandType.IMMEDIATE + elif 
bn.InstructionTextTokenType.RegisterToken in ttypes: + kind = OperandType.REGISTER + else: + kind = OperandType.OTHER + + val = None + for t in current: + if t.type in ( + bn.InstructionTextTokenType.IntegerToken, + bn.InstructionTextTokenType.PossibleAddressToken, + bn.InstructionTextTokenType.CodeRelativeAddressToken, + ): + try: + val = Address(int(t.value)) + except Exception: + val = None + break + operands_list.append(Operand(type=kind, text=g_text, value=val)) + current = [] + continue + current.append(tok) + + # Flush the last operand if any + if current: + ttypes = {t.type for t in current} + g_text = "".join(t.text for t in current).strip() + if ( + bn.InstructionTextTokenType.BeginMemoryOperandToken in ttypes + or ( + bn.InstructionTextTokenType.CodeRelativeAddressToken in ttypes + and any(t.type == bn.InstructionTextTokenType.BraceToken and t.text == '[' for t in current) + ) + ): + kind = OperandType.MEMORY + elif ( + bn.InstructionTextTokenType.IntegerToken in ttypes + or bn.InstructionTextTokenType.PossibleAddressToken in ttypes + or bn.InstructionTextTokenType.CodeRelativeAddressToken in ttypes + ): + kind = OperandType.IMMEDIATE + elif bn.InstructionTextTokenType.RegisterToken in ttypes: + kind = OperandType.REGISTER + else: + kind = OperandType.OTHER + + val = None + for t in current: + if t.type in ( + bn.InstructionTextTokenType.IntegerToken, + bn.InstructionTextTokenType.PossibleAddressToken, + bn.InstructionTextTokenType.CodeRelativeAddressToken, + ): + try: + val = Address(int(t.value)) + except Exception: + val = None + break + operands_list.append(Operand(type=kind, text=g_text, value=val)) + if not mnemonic: + mnemonic = (text.split()[0].lower() if text else "") + + ins = Instruction( + address=Address(ea), + mnemonic=mnemonic, + operands=tuple(operands_list), + text=text + ) + return ins diff --git a/plugins/xrefer/backend/factory.py b/plugins/xrefer/backend/factory.py index 0643656..76d23e2 100644 --- a/plugins/xrefer/backend/factory.py 
+++ b/plugins/xrefer/backend/factory.py @@ -77,11 +77,10 @@ def is_available(self) -> bool: """Check if Ghidra is available.""" return importlib.util.find_spec("pyghidra") is not None - def create_backend(self, **kwargs) -> BackEnd: + def create_backend(self, program=None, **kwargs) -> BackEnd: """Create Ghidra backend instance.""" from .ghidra.backend import GhidraBackend - - return GhidraBackend() + return GhidraBackend(program=program) class BackendManager: diff --git a/plugins/xrefer/backend/ghidra/backend.py b/plugins/xrefer/backend/ghidra/backend.py index e5907a4..599e15c 100644 --- a/plugins/xrefer/backend/ghidra/backend.py +++ b/plugins/xrefer/backend/ghidra/backend.py @@ -1,21 +1,25 @@ import logging from collections.abc import Iterator -from ..base import Address, BackEnd, BackendError, BasicBlock, Function, FunctionType, InvalidAddressError, Section, SectionType, String, StringEncType, Xref, XrefType - -# Global reference to getCurrentProgram function - will be set by use_backend.py -getCurrentProgram = None +from ..base import Address, BackEnd, BackendError, BasicBlock, Function, FunctionType, Instruction, InvalidAddressError, Operand, OperandType, Section, SectionType, String, StringEncType, Xref, XrefType class GhidraFunction(Function): - def __init__(self, ghidra_func) -> None: + def __init__(self, ghidra_func, backend) -> None: """Initialize with Ghidra function object.""" if ghidra_func is None: raise ValueError("Ghidra function cannot be None") + if backend is None: + raise ValueError("Backend cannot be None") self._func = ghidra_func + self._backend = backend self._name: str | None = None self._function_type: FunctionType | None = None + def _get_program(self): + """Get program object from backend.""" + return self._backend._get_actual_program() + @property def start(self) -> Address: """Get function start address.""" @@ -37,7 +41,8 @@ def name(self, value: str) -> None: if not value: raise ValueError("Function name cannot be empty") try: - 
self._func.setName(value, None) + from ghidra.program.model.symbol import SourceType + self._func.setName(value, SourceType.USER_DEFINED) self._name = value except Exception as e: raise BackendError(f"Failed to set function name: {e}") from e @@ -66,7 +71,7 @@ def is_thunk(self) -> bool: def contains(self, address: Address) -> bool: """Check if the address is within the function.""" # Use current program's address factory - program = getCurrentProgram() + program = self._get_program() addr_factory = program.getAddressFactory() addr_value = address.value if isinstance(address, Address) else int(address) ghidra_addr = addr_factory.getAddress(f"{addr_value:x}") @@ -78,18 +83,17 @@ def basic_blocks(self) -> Iterator[BasicBlock]: # Get basic blocks using program model from ghidra.program.model.block import BasicBlockModel - program = getCurrentProgram() + program = self._get_program() block_model = BasicBlockModel(program) blocks = block_model.getCodeBlocksContaining(self._func.getBody(), None) while blocks.hasNext(): block = blocks.next() - yield BasicBlock(Address(block.getMinAddress().getOffset()), - Address(block.getMaxAddress().getOffset() + 1)) + yield BasicBlock(Address(block.getMinAddress().getOffset()), Address(block.getMaxAddress().getOffset() + 1)) def _is_export(self) -> bool: """Return True if the function is exported from the binary.""" - program = getCurrentProgram() + program = self._get_program() symbol_table = program.getSymbolTable() ghidra_addr = self._func.getEntryPoint() symbols = symbol_table.getSymbols(ghidra_addr) @@ -223,21 +227,70 @@ def perm(self) -> str: class GhidraBackend(BackEnd): """Ghidra backend implementation.""" - def __init__(self) -> None: - """Initialize Ghidra backend.""" + def __init__(self, program=None) -> None: + """Initialize Ghidra backend. + + Args: + program: Optional Ghidra program object. 
+ """ super().__init__() - self._program = None + self._program = program self._addr_factory = None + # Set up address factory if program is provided + if program is not None: + self._addr_factory = program.getAddressFactory() + + @property + def program(self): + """Get the current Ghidra program.""" + if self._program is None: + self._ensure_program_loaded() + return self._program + + def __getstate__(self): + """Custom pickle state - exclude unpicklable program object.""" + state = self.__dict__.copy() + # Remove the unpicklable program reference + state["_program"] = None + state["_addr_factory"] = None + return state + + def __setstate__(self, state): + """Custom unpickle state - restore without program object.""" + self.__dict__.update(state) + # Program will need to be re-set after unpickling + + @program.setter + def program(self, program) -> None: + """Set the current Ghidra program.""" + if program is None: + raise ValueError("Program cannot be None") + self._program = program + self._addr_factory = program.getAddressFactory() + def _ensure_program_loaded(self): """Ensure program is loaded and cached.""" if self._program is None: - self._program = getCurrentProgram() - if self._program is None: - raise BackendError("No program is currently loaded in Ghidra") - # Cache address factory for performance + raise BackendError("No program is currently loaded in Ghidra") + + def _get_actual_program(self): + """Get the actual program object.""" + self._ensure_program_loaded() + + # Set up address factory if not already done + if self._addr_factory is None and self._program: self._addr_factory = self._program.getAddressFactory() + return self._program + + def _is_executable_address(self, ghidra_addr) -> bool: + """Return True if address is in an executable memory block.""" + program = self._get_actual_program() + mem = program.getMemory() + block = mem.getBlock(ghidra_addr) + return bool(block and block.isExecute()) + @property def name(self) -> str: """Backend name 
for language module lookup.""" @@ -247,98 +300,142 @@ def name(self) -> str: def image_base(self) -> Address: """Get image base address where binary is loaded.""" try: - self._ensure_program_loaded() - return Address(self._program.getImageBase().getOffset()) + program = self._get_actual_program() + return Address(program.getImageBase().getOffset()) except Exception as e: raise BackendError(f"Failed to get image base: {e}") def functions(self) -> Iterator[Function]: """Iterate over all functions in the binary.""" - self._ensure_program_loaded() - function_manager = self._program.getFunctionManager() + program = self._get_actual_program() + function_manager = program.getFunctionManager() for func in function_manager.getFunctions(True): - if func is not None: - yield GhidraFunction(func) + if func is None: + continue + # Match IDA semantics more closely: include all non-external + # functions discovered by Ghidra (including thunks and small + # helper stubs). External functions map to imports and are not + # yielded as code functions in IDA either. 
+ if func.isExternal(): + continue + yield GhidraFunction(func, self) def get_function_at(self, address: Address) -> Function | None: """Get function containing the specified address.""" - self._ensure_program_loaded() + program = self._get_actual_program() # Use cached address factory for performance addr_value = address.value if isinstance(address, Address) else int(address) ghidra_addr = self._addr_factory.getAddress(f"{addr_value:x}") - function_manager = self._program.getFunctionManager() + function_manager = program.getFunctionManager() ghidra_func = function_manager.getFunctionContaining(ghidra_addr) if ghidra_func: - return GhidraFunction(ghidra_func) + # Normalize: ignore EXTERNAL functions only + if ghidra_func.isExternal(): + return None + return GhidraFunction(ghidra_func, self) return None def strings(self, min_length: int = String.MIN_LENGTH) -> Iterator[String]: """Iterate over all strings in the binary.""" - self._ensure_program_loaded() - listing = self._program.getListing() + program = self._get_actual_program() + listing = program.getListing() data_iter = listing.getDefinedData(True) while data_iter.hasNext(): data = data_iter.next() - data_type = data.getDataType() - - # Check if it's a string type (call hasStringValue on data, not data_type) - if data.hasStringValue(): - value = data.getValue() - if value and isinstance(value, str) and len(value) >= min_length: - addr = Address(data.getAddress().getOffset()) - # Determine encoding - if "unicode" in data_type.getName().lower() or "wide" in data_type.getName().lower(): - encoding = StringEncType.UTF16 - else: - encoding = StringEncType.ASCII - - yield GhidraString(addr, value, len(value), encoding) - - # Search for additional undefined strings - memory = self._program.getMemory() + if not data.hasStringValue(): + continue + value = data.getValue() + if not (value and isinstance(value, str) and len(value) >= min_length): + continue + addr = Address(data.getAddress().getOffset()) + dt_name = 
data.getDataType().getName().lower() + enc = StringEncType.UTF16 if ("unicode" in dt_name or "wide" in dt_name) else StringEncType.ASCII + yield GhidraString(addr, value, len(value), enc) + + # 2) Scan readable, initialized, non-executable blocks for raw ASCII/UTF-16LE strings + memory = program.getMemory() for block in memory.getBlocks(): - if block.isInitialized() and not block.isVolatile(): - yield from self._search_strings_in_block(block, min_length) - - def _search_strings_in_block(self, block, min_length: int) -> Iterator[String]: - """Search for strings in a memory block.""" - - start = block.getStart() - end = block.getEnd() - - current = start - while current.compareTo(end) < 0: - try: - listing = self._program.getListing() - data = listing.getDataAt(current) - if data and data.hasStringValue(): - str_data = data.getValue() - else: - str_data = None - if str_data and len(str_data) >= min_length: - yield GhidraString(Address(current.getOffset()), str(str_data), len(str_data), StringEncType.ASCII) - current = current.add(len(str_data)) - else: - current = current.add(1) - except Exception: - # Skip on any error and continue scanning - current = current.add(1) + if not block.isInitialized() or block.isVolatile() or block.isExecute() or not block.isRead(): + continue + yield from self._scan_block_for_strings(block, min_length) + + def _scan_block_for_strings(self, block, min_length: int) -> Iterator[String]: + """Detect ASCII and UTF-16LE strings directly from block bytes.""" + start_off = int(block.getStart().getOffset()) + end_off = int(block.getEnd().getOffset()) + size = end_off - start_off + 1 + + try: + from ghidra.program.flatapi import FlatProgramAPI + flat = FlatProgramAPI(self._get_actual_program()) + raw = flat.getBytes(block.getStart(), size) + buf = bytes(raw) + except Exception: + # Fallback to small-chunk reads via read_bytes + chunk = self.read_bytes(Address(start_off), size) + if not chunk: + return + buf = chunk + + def is_printable(b: int) 
-> bool: + return 32 <= b <= 126 + + visited = set() + + # ASCII scan + i = 0 + while i < len(buf): + if is_printable(buf[i]): + j = i + while j < len(buf) and is_printable(buf[j]): + j += 1 + if j - i >= min_length: + ea = start_off + i + if ea not in visited: + s = buf[i:j].decode('ascii', errors='ignore') + yield GhidraString(Address(ea), s, len(s), StringEncType.ASCII) + visited.add(ea) + i = j + 1 + else: + i += 1 + + # UTF-16LE scan (simple pattern: printable ASCII with null bytes) + i = 0 + while i+1 < len(buf): + # require little-endian wide char: printable, then 0x00 + if is_printable(buf[i]) and buf[i+1] == 0: + j = i + run = 0 + chars = [] + while j+1 < len(buf) and is_printable(buf[j]) and buf[j+1] == 0: + chars.append(chr(buf[j])) + run += 1 + j += 2 + if run >= min_length: + ea = start_off + i + if ea not in visited: + s = ''.join(chars) + yield GhidraString(Address(ea), s, len(s), StringEncType.UTF16) + visited.add(ea) + i = j + 2 + else: + i += 2 def get_name_at(self, address: Address) -> str: """Get symbol name at the specified address.""" - self._ensure_program_loaded() + program = self._get_actual_program() # Use program's address factory instead of gl.resolve() addr_value = address.value if isinstance(address, Address) else int(address) ghidra_addr = self._addr_factory.getAddress(f"{addr_value:x}") - symbol_table = self._program.getSymbolTable() + symbol_table = program.getSymbolTable() symbol = symbol_table.getPrimarySymbol(ghidra_addr) return symbol.getName() if symbol else "" def get_address_for_name(self, name: str) -> Address | None: """Get address for the specified symbol name.""" - self._ensure_program_loaded() - symbol_table = self._program.getSymbolTable() + program = self._get_actual_program() + symbol_table = program.getSymbolTable() symbols = symbol_table.getSymbols(name) if symbols and symbols.hasNext(): symbol = symbols.next() @@ -347,9 +444,9 @@ def get_address_for_name(self, name: str) -> Address | None: def get_xrefs_to(self, 
address: Address) -> Iterator[Xref]: """Get all references TO the specified address.""" - self._ensure_program_loaded() + program = self._get_actual_program() # Use program's reference manager - ref_manager = self._program.getReferenceManager() + ref_manager = program.getReferenceManager() addr_value = address.value if isinstance(address, Address) else int(address) ghidra_addr = self._addr_factory.getAddress(f"{addr_value:x}") @@ -360,9 +457,9 @@ def get_xrefs_to(self, address: Address) -> Iterator[Xref]: def get_xrefs_from(self, address: Address) -> Iterator[Xref]: """Get all references FROM the specified address.""" - self._ensure_program_loaded() + program = self._get_actual_program() # Use program's reference manager - ref_manager = self._program.getReferenceManager() + ref_manager = program.getReferenceManager() addr_value = address.value if isinstance(address, Address) else int(address) ghidra_addr = self._addr_factory.getAddress(f"{addr_value:x}") @@ -406,22 +503,52 @@ def _convert_ref_type(self, ghidra_ref_type) -> XrefType: def read_bytes(self, address: Address, size: int) -> bytes | None: """Read raw bytes from the specified address.""" - self._ensure_program_loaded() + program = self._get_actual_program() # Use program's address factory instead of gl.resolve() addr_value = address.value if isinstance(address, Address) else int(address) ghidra_addr = self._addr_factory.getAddress(f"{addr_value:x}") - # Use FlatProgramAPI for simplified byte reading - from ghidra.program.flatapi import FlatProgramAPI + # # Use FlatProgramAPI for simplified byte reading + # from ghidra.program.flatapi import FlatProgramAPI + + # flat_api = FlatProgramAPI(program) + # buffer = flat_api.getBytes(ghidra_addr, size) + # return bytes(buffer) + """ v2 """ + # Guard against invalid or non-readable regions and partial reads. 
+ try: + memory = program.getMemory() + block = memory.getBlock(ghidra_addr) + if block is None or not block.isRead(): + return None + + # Compute how many bytes remain in this block from the start address. + remaining = int(block.getEnd().getOffset() - ghidra_addr.getOffset() + 1) + if remaining <= 0: + return None + + # For consistent semantics with IDA/Binary Ninja backends, only + # return data if the full requested size is available. + if remaining < size: + return None + + from ghidra.program.flatapi import FlatProgramAPI + flat_api = FlatProgramAPI(program) + buffer = flat_api.getBytes(ghidra_addr, size) + return bytes(buffer) if buffer is not None else None + except Exception as e: + import logging + logger = logging.getLogger(__name__) + logger.error(f"Failed to read bytes from {address}: {e}") + # Match other backends: on any memory access issue, return None + # instead of raising, so higher-level scanners can skip gracefully. + return None - flat_api = FlatProgramAPI(self._program) - buffer = flat_api.getBytes(ghidra_addr, size) - return bytes(buffer) def instructions(self, start: Address, end: Address) -> Iterator[Address]: """Iterate over instruction addresses in the specified range.""" - self._ensure_program_loaded() - listing = self._program.getListing() + program = self._get_actual_program() + listing = program.getListing() start_value = start.value if isinstance(start, Address) else int(start) end_value = end.value if isinstance(end, Address) else int(end) start_addr = self._addr_factory.getAddress(f"{start_value:x}") @@ -437,46 +564,96 @@ def instructions(self, start: Address, end: Address) -> Iterator[Address]: def _get_sections_impl(self) -> Iterator[Section]: """Iterate over all memory sections.""" - self._ensure_program_loaded() - memory = self._program.getMemory() + program = self._get_actual_program() + memory = program.getMemory() for block in memory.getBlocks(): yield GhidraSection(block) def get_section_by_name(self, name: str) -> Section 
| None: """Get section by name.""" - self._ensure_program_loaded() - memory = self._program.getMemory() + program = self._get_actual_program() + memory = program.getMemory() block = memory.getBlock(name) if block: return GhidraSection(block) return None def _get_raw_imports(self) -> Iterator[tuple[Address, str, str]]: - """Get raw import data from backend.""" - self._ensure_program_loaded() - import ghidra.program.model.symbol + """Get raw import data from backend. + + For ELF binaries, external references often appear as THUNK references + (e.g., PLT stubs). To better match IDA/Binary Ninja behavior, we accept + both data-like references (DATA/READ/DATA_IND) and THUNK references, + preferring data-like when available. + """ + program = self._get_actual_program() + import ghidra.program.model.symbol as symbol_module + + symbol_table: symbol_module.SymbolTable = program.getSymbolTable() + em: symbol_module.ExternalManager = program.getExternalManager() - symbol_table: ghidra.program.model.symbol.SymbolTable = self._program.getSymbolTable() - em: ghidra.program.model.symbol.ExternalManager = self._program.getExternalManager() + seen_addrs: set[int] = set() + + # Detect executable format for module normalization + try: + fmt = (program.getExecutableFormat() or "").upper() + is_elf = "ELF" in fmt + except Exception: + is_elf = False for symbol in symbol_table.getExternalSymbols(): - symbol: ghidra.program.model.symbol.Symbol - # Get the external location using the external manager - external_loc: ghidra.program.model.symbol.ExternalLocation = em.getExternalLocation(symbol) - function_name = external_loc.getLabel() or "" - library_name = external_loc.getLibraryName() or "" - # Find references to this external symbol + # Resolve external location for name/library if available + external_loc: symbol_module.ExternalLocation = em.getExternalLocation(symbol) + + function_name = "" + library_name = "" + if external_loc is not None: + function_name = external_loc.getLabel() or 
"" + library_name = external_loc.getLibraryName() or "" + # Normalize EXTERNAL library placeholder to empty so we can apply ELF default + lib_str = library_name.strip() + if lib_str.upper().startswith("") or lib_str.lower() == "unknown": + library_name = "" + + if not function_name: + function_name = symbol.getName() or "" + + # Collect references and classify refs = symbol.getReferences() - if refs: - for ref in refs: - if ref.referenceType.toString() == "DATA": - addr = Address(ref.getFromAddress().getOffset()) - yield (addr, function_name, library_name) + if not refs: + continue + + RefType = symbol_module.RefType + data_like: list = [] + thunk_like: list = [] + for ref in refs: + rtype = ref.getReferenceType() + # Prefer data-like references that point to IAT/GOT entries + if rtype in (RefType.DATA, RefType.READ, RefType.DATA_IND): + data_like.append(ref) + elif rtype == RefType.THUNK or str(rtype) == "THUNK": + thunk_like.append(ref) + + # Prefer data-like references (IAT/GOT) for stability; fallback to thunk/code + chosen = data_like[0] if data_like else (thunk_like[0] if thunk_like else None) + if not chosen: + continue + + from_addr = chosen.getFromAddress() + addr_int = from_addr.getOffset() + + if addr_int in seen_addrs: + continue + seen_addrs.add(addr_int) + + # Normalize module for ELF when missing + module_out = library_name or ("GLIBC" if is_elf else "unknown") + yield (Address(addr_int), function_name, module_out) def get_exports(self) -> Iterator[tuple[str, Address]]: """Get all exported symbols from the binary.""" - self._ensure_program_loaded() - symbol_table = self._program.getSymbolTable() + program = self._get_actual_program() + symbol_table = program.getSymbolTable() # Iterate through all symbols and find exports for symbol in symbol_table.getAllSymbols(True): @@ -487,14 +664,14 @@ def get_exports(self) -> Iterator[tuple[str, Address]]: def _add_user_xref_impl(self, source: Address, target: Address) -> None: """Backend-specific implementation 
for adding user cross-references.""" - self._ensure_program_loaded() + program = self._get_actual_program() # Import locally to avoid module-level dependency issues import ghidra.program.model.symbol as symbol_module RefType = symbol_module.RefType SourceType = symbol_module.SourceType - ref_manager = self._program.getReferenceManager() + ref_manager = program.getReferenceManager() source_addr = self._addr_factory.getAddress(f"{source.value:x}") target_addr = self._addr_factory.getAddress(f"{target.value:x}") @@ -502,8 +679,8 @@ def _add_user_xref_impl(self, source: Address, target: Address) -> None: def _set_comment_impl(self, address: Address, comment: str) -> None: """Backend-specific implementation for setting comments.""" - self._ensure_program_loaded() - listing = self._program.getListing() + program = self._get_actual_program() + listing = program.getListing() ghidra_addr = self._addr_factory.getAddress(f"{address.value:x}") # Use numeric constant instead of importing the class EOL_COMMENT = 0 @@ -511,13 +688,13 @@ def _set_comment_impl(self, address: Address, comment: str) -> None: def _set_function_comment_impl(self, address: Address, comment: str) -> None: """Backend-specific implementation for setting function comments.""" - self._ensure_program_loaded() + program = self._get_actual_program() # Use cached address factory ghidra_addr = self._addr_factory.getAddress(f"{address.value:x}") - function_manager = self._program.getFunctionManager() + function_manager = program.getFunctionManager() func = function_manager.getFunctionContaining(ghidra_addr) if func: - listing = self._program.getListing() + listing = program.getListing() # Use numeric constant instead of importing the class PLATE_COMMENT = 1 listing.setComment(func.getEntryPoint(), PLATE_COMMENT, comment) @@ -526,20 +703,106 @@ def _set_function_comment_impl(self, address: Address, comment: str) -> None: def _path_impl(self) -> str: """Backend-specific implementation for getting binary path.""" - 
self._ensure_program_loaded() try: - executable_path = self._program.getExecutablePath() + program = self._get_actual_program() + executable_path = program.getExecutablePath() if executable_path: return executable_path # Fallback to program name - return self._program.getName() + return program.getName() except Exception as e: raise BackendError(f"Failed to get binary path: {e}") def _binary_hash_impl(self) -> str: """Compute SHA256 hash of the binary file.""" - self._ensure_program_loaded() - sha256 = self._program.getExecutableSHA256() + program = self._get_actual_program() + sha256 = program.getExecutableSHA256() if sha256: return sha256 raise BackendError("Failed to compute binary hash") + + def _get_disassembly_impl(self, address: Address) -> Instruction: + """Disassemble a single instruction at `address` using Ghidra APIs.""" + program = self._get_actual_program() + listing = program.getListing() + + ea = int(address) + gh_addr = self._addr_factory.getAddress(f"{ea:x}") + + inst = listing.getInstructionAt(gh_addr) + if inst is None: + raise BackendError(f"No instruction at address 0x{ea:x}") + + text = inst.toString() + mnem = inst.getMnemonicString().lower() + + # Import Java classes for operand inspection + from ghidra.program.model.scalar import Scalar as GScalar + from ghidra.program.model.address import Address as GAddress + from ghidra.program.model.lang import Register as GRegister + + operands: list[Operand] = [] + num_ops = inst.getNumOperands() + # Pre-fetch references for value recovery (e.g., RIP-relative immediates) + inst_refs = list(inst.getReferencesFrom()) + for i in range(num_ops): + op_text = inst.getDefaultOperandRepresentation(i) + objs = inst.getOpObjects(i) + + is_mem = ("[" in op_text and "]" in op_text) + has_reg = any(isinstance(o, GRegister) for o in objs) + has_addr = any(isinstance(o, GAddress) for o in objs) + has_scalar = any(isinstance(o, GScalar) for o in objs) + + if is_mem: + op_kind = OperandType.MEMORY + elif has_reg and 
not (has_addr or has_scalar): + op_kind = OperandType.REGISTER + elif has_addr or has_scalar: + op_kind = OperandType.IMMEDIATE + else: + op_kind = OperandType.OTHER + + val = None + if op_kind == OperandType.MEMORY: + # Only treat embedded absolute addresses as values for memory operands + if has_addr: + for o in objs: + if isinstance(o, GAddress): + val = Address(int(o.getOffset())) + break + elif op_kind == OperandType.IMMEDIATE: + if has_addr: + for o in objs: + if isinstance(o, GAddress): + val = Address(int(o.getOffset())) + break + elif has_scalar: + for o in objs: + if isinstance(o, GScalar): + # Keep only non-negative immediates as Address values + uv = int(o.getValue()) + if uv is not None and uv >= 0: + val = Address(uv) + break + + # Use instruction references to recover operand target addresses + # when operand objects do not expose a GAddress (e.g., LEA with + # RIP-relative immediate shown as Scalar). + if val is None and inst_refs: + for ref in inst_refs: + if ref.getOperandIndex() == i: + to_addr = ref.getToAddress() + if to_addr is not None and not to_addr.isStackAddress(): + val = Address(int(to_addr.getOffset())) + break + + operands.append(Operand(type=op_kind, text=op_text, value=val)) + + ins = Instruction( + address=Address(ea), + mnemonic=mnem, + operands=tuple(operands), + text=text + ) + return ins diff --git a/plugins/xrefer/backend/ida/backend.py b/plugins/xrefer/backend/ida/backend.py index 7792b7e..3281f4e 100644 --- a/plugins/xrefer/backend/ida/backend.py +++ b/plugins/xrefer/backend/ida/backend.py @@ -12,7 +12,7 @@ import idautils import idc -from ..base import Address, BackEnd, BackendError, BasicBlock, Function, FunctionType, Section, SectionType, String, StringEncType, Xref, XrefType +from ..base import Address,BackEnd,BackendError,BasicBlock,Function,FunctionType,Section,SectionType,String,StringEncType,Xref,XrefType,Instruction,Operand,OperandType class IDAFunction(Function): @@ -392,3 +392,53 @@ def _set_comment_impl(self, 
address: Address, comment: str) -> None: def _set_function_comment_impl(self, address: Address, comment: str) -> None: """Set function comment in IDA.""" idc.set_func_cmt(int(address), comment, 0) + + def _get_disassembly_impl(self, address: Address) -> Instruction: + """Backend-specific implementation for getting disassembly at a specific address.""" + ea = int(address) + + # Full disassembly text and mnemonic + text = idc.generate_disasm_line(ea, 0) or "" + mnem = (idc.print_insn_mnem(ea) or "").lower() + + # Collect operands with best-effort typing + operands: list[Operand] = [] + for i in range(8): # x86/x64 has max 4; use 8 as a safe cap + op_type_id = idc.get_operand_type(ea, i) + if op_type_id == idc.o_void: + break + + op_text = idc.print_operand(ea, i) or "" + op_kind = OperandType.OTHER + op_value = None + + try: + if op_type_id == idc.o_imm: + op_kind = OperandType.IMMEDIATE + val = idc.get_operand_value(ea, i) + op_value = Address(int(val)) + elif op_type_id == idc.o_reg: + op_kind = OperandType.REGISTER + elif op_type_id in (idc.o_mem,): + op_kind = OperandType.MEMORY + val = idc.get_operand_value(ea, i) + op_value = Address(int(val)) + elif op_type_id in (idc.o_phrase, idc.o_displ): + # Memory with computed address; keep value None (use xrefs to resolve) + op_kind = OperandType.MEMORY + else: + op_kind = OperandType.OTHER + except Exception: + op_kind = OperandType.OTHER + op_value = None + + operands.append(Operand(type=op_kind, text=op_text, value=op_value)) + # print(f"{Address(ea)}: {mnem}|\tOperand[{i}]: {operands[-1]}") + + ins = Instruction( + address=Address(ea), + mnemonic=mnem, + operands=tuple(operands), + text=text + ) + return ins diff --git a/plugins/xrefer/backend/utils.py b/plugins/xrefer/backend/utils.py index 66c34ae..a5342a7 100644 --- a/plugins/xrefer/backend/utils.py +++ b/plugins/xrefer/backend/utils.py @@ -1,8 +1,3 @@ -from typing import List - -from .base import Address, BackEnd - - def sample_path() -> str: """Return a sample 
path for the active backend.""" from . import get_current_backend @@ -79,7 +74,7 @@ def _dump_indirect_calls_ghidra(program): function_manager = program.getFunctionManager() listing = program.getListing() symbol_table = program.getSymbolTable() # Cache symbol table for performance - + for func in function_manager.getFunctions(True): instructions = listing.getInstructions(func.getBody(), True) for instr in instructions: @@ -102,51 +97,3 @@ def _dump_indirect_calls_ghidra(program): # Don't break here - continue processing remaining pcode operations return indirect_calls - - -class Mapping: - """A utility class for mapping addresses to symbols and vice-versa.""" - - def __init__(self, backend: BackEnd): - """ - Initializes the Mapping utility. - - Args: - backend: An instance of a backend (e.g., IDABackend or BNBackend). - """ - self._backend = backend - - def addr2sym(self, address: Address) -> List[str]: - """ - Resolves an address to a list of symbol names. - - Args: - address: The address to resolve. - - Returns: - A list of symbol names, which may be empty if no symbols are found. - """ - symbols = [] - func = self._backend.get_function_at(address) - if func: - symbols.append(func.name) - - # In some cases, a name might exist at an address without a function. - name = self._backend.get_name_at(address) - if name and name not in symbols: - symbols.append(name) - - return symbols - - def sym2addr(self, symbol: str) -> List[Address]: - """ - Resolves a symbol name to a list of addresses. - - Args: - symbol: The symbol name to resolve. - - Returns: - A list of addresses, which may be empty if the symbol is not found. 
- """ - address = self._backend.get_address_for_name(symbol) - return [address] if address else [] \ No newline at end of file diff --git a/plugins/xrefer/core/analyzer.py b/plugins/xrefer/core/analyzer.py index 2548144..0be9bf4 100644 --- a/plugins/xrefer/core/analyzer.py +++ b/plugins/xrefer/core/analyzer.py @@ -849,9 +849,9 @@ def _process_artifact_xrefs(self, idx: int, entity: Tuple, xrefs: Set[int], func if not func: continue - if func.is_thunk: - continue fn_start = func.start + if func.is_thunk and not self.is_simple_api_thunk(fn_start): + continue # Use orphan check that considers indirect xrefs is_orphan = self.is_orphan_function(fn_start) diff --git a/plugins/xrefer/core/helpers.py b/plugins/xrefer/core/helpers.py index 08c6c17..84dabd1 100644 --- a/plugins/xrefer/core/helpers.py +++ b/plugins/xrefer/core/helpers.py @@ -156,8 +156,8 @@ def fetch_repositories(search_string): return {"UNCATEGORIZED": {"path": "", "matched_lines": {}}} repositories = {} for hit in hits: - repo_name = hit["repo"]["raw"] - path = hit["path"]["raw"] + repo_name = hit["repo"] + path = hit["path"] snippet = hit["content"]["snippet"] matched_lines = parse_snippet(snippet) repositories[repo_name] = {"path": f"{repo_name}/{path}", "matched_lines": matched_lines} @@ -333,13 +333,11 @@ def filter_null_string(s: str, size: int) -> Tuple[str, int]: Returns: Tuple[str, int]: Filtered string and its actual length """ - ss, i = "", 0 - while i < size: - if s[i] == "\x00": - break - ss += s[i] - i += 1 - return ss, i + limit = min(size, len(s)) + for i, ch in enumerate(s[:limit]): + if ch == "\x00": + return s[:i], i + return s[:limit], limit def longest_line_length(s: Optional[str]) -> int: diff --git a/plugins/xrefer/core/settings.py b/plugins/xrefer/core/settings.py index ea48e66..b48b188 100644 --- a/plugins/xrefer/core/settings.py +++ b/plugins/xrefer/core/settings.py @@ -17,7 +17,6 @@ from time import time from typing import Any, Dict, List -from xrefer import backend from 
xrefer.core.helpers import log PathType = str @@ -49,7 +48,14 @@ def __init__(self): # IDB-specific settings - paths that can be customized per IDB self.idb_specific_paths = {"analysis", "capa", "trace", "xrefs"} - self.current_idb = backend.sample_path() + self._current_idb = None # Lazy + + @property + def current_idb(self) -> str: + if self._current_idb is None: + from ..backend import sample_path + self._current_idb = sample_path() + return self._current_idb def get_default_settings(self) -> Dict[str, Any]: """Get default settings dictionary with added display options.""" diff --git a/plugins/xrefer/lang/ida/rust.py b/plugins/xrefer/lang/ida/rust.py deleted file mode 100644 index 9c860f4..0000000 --- a/plugins/xrefer/lang/ida/rust.py +++ /dev/null @@ -1,714 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import re -import typing -from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple - -import ida_bytes -import ida_funcs -import ida_ida -import ida_name -import ida_offset -import ida_segment -import ida_ua -import idaapi -import idautils -import idc -from tabulate import tabulate - -from xrefer.core.helpers import filter_null_string, log, normalize_path -from xrefer.gui.legacy.shim import BIN_SEARCH_FORWARD, SEARCH_DOWN, find_bytes, find_code, is_32bit -from xrefer.lang.lang_base import LanguageBase -from xrefer.lang.lang_default import LangDefault - -if typing.TYPE_CHECKING: - from xrefer.core.analyzer import XRefer - - -@dataclass -class RustStringInfo: - """ - Container for Rust string information. - - Attributes: - text (str): The actual string content - length (int): Length of the string in bytes - xrefs (Optional[List[int]]): List of cross-reference addresses to this string - """ - - text: str - length: int - xrefs: Optional[List[int]] = None - - -class RustStringParser: - """ - Parser for Rust string formats in binary. - - Handles parsing of various Rust string representations including those in - .data.rel.ro, .rdata sections, and strings referenced from text section. - - Attributes: - is_64bit (bool): Whether binary is 64-bit - sizeof_rust_string (int): Size of Rust string structure (16 or 8 bytes) - next_offset (int): Offset to next string field (8 or 4 bytes) - ror_num (int): Rotation number for validation (32 or 16) - poi (Callable): Function to read pointer-sized values - """ - - def __init__(self): - self.is_64bit = not is_32bit() - self.sizeof_rust_string = 16 if self.is_64bit else 8 - self.next_offset = 8 if self.is_64bit else 4 - self.ror_num = 32 if self.is_64bit else 16 - self.poi = ida_bytes.get_qword if self.is_64bit else ida_bytes.get_dword - - def get_data_rel_ro_strings(self) -> Dict[int, RustStringInfo]: - """ - Extract Rust strings from .data.rel.ro section. 
- - Scans the .data.rel.ro section for Rust string patterns, validates them, - and converts them to RustStringInfo objects. - - Returns: - Dict[int, RustStringInfo]: Dictionary mapping addresses to RustStringInfo objects - for all valid strings found in .data.rel.ro - """ - strings = {} - - data_rel_ro = ida_segment.get_segm_by_name(".data.rel.ro") - if not data_rel_ro: - return strings - - rdata = ida_segment.get_segm_by_name(".rdata") - if not rdata: - return strings - - curr_ea = data_rel_ro.start_ea - while curr_ea < data_rel_ro.end_ea: - ea_candidate = self.poi(curr_ea) - len_candidate = self.poi(curr_ea + self.next_offset) - - if self._is_valid_string(len_candidate, ea_candidate, rdata): - try: - s = ida_bytes.get_bytes(ea_candidate, len_candidate).decode("utf-8") - s, len_s = filter_null_string(s, len_candidate) - if len_s == len_candidate and ea_candidate not in strings: - ida_offset.op_plain_offset(curr_ea, 0, 0) - strings[ea_candidate] = RustStringInfo(s, len_candidate) - curr_ea += self.sizeof_rust_string - continue - except: - pass - curr_ea += 1 - - return strings - - def get_rdata_strings(self) -> Dict[int, RustStringInfo]: - """ - Extract Rust strings from .rdata section. - - Similar to get_data_rel_ro_strings but processes the .rdata section. - Uses the same validation and extraction logic for consistency. 
- - Returns: - Dict[int, RustStringInfo]: Dictionary mapping addresses to RustStringInfo objects - for all valid strings found in .rdata - """ - strings = {} - - rdata = ida_segment.get_segm_by_name(".rdata") - if not rdata: - return strings - - curr_ea = rdata.start_ea - while curr_ea < rdata.end_ea: - ea_candidate = self.poi(curr_ea) - len_candidate = self.poi(curr_ea + self.next_offset) - - if self._is_valid_string(len_candidate, ea_candidate, rdata): - try: - s = ida_bytes.get_bytes(ea_candidate, len_candidate).decode("utf-8") - s, len_s = filter_null_string(s, len_candidate) - if len_s == len_candidate and ea_candidate not in strings: - ida_offset.op_plain_offset(curr_ea, 0, 0) - strings[ea_candidate] = RustStringInfo(s, len_candidate) - curr_ea += self.sizeof_rust_string - continue - except: - pass - curr_ea += 1 - - return strings - - def get_text_strings(self) -> Dict[int, RustStringInfo]: - """ - Extract Rust strings referenced from .text section. - - Analyzes code in .text section to find string references and extracts - corresponding strings from .rdata section. More complex than other methods - as it needs to handle various instruction patterns. 
- - Returns: - Dict[int, RustStringInfo]: Dictionary mapping addresses to RustStringInfo objects - for strings referenced from code - """ - strings = {} - text = ida_segment.get_segm_by_name(".text") - rdata = ida_segment.get_segm_by_name(".rdata") - - if not text or not rdata: - return strings - - for func in idautils.Functions(text.start_ea, text.end_ea): - # Get function bounds - start = func - end = idc.find_func_end(start) - - # Collect all instruction addresses first - addrs = [] - inst = start - while inst < end: - addrs.append(inst) - inst = find_code(inst, SEARCH_DOWN) - - # Process instructions for string references - for i in range(len(addrs) - 2): # Need at least 2 more instructions - curr_addr = addrs[i] - mnem = idc.print_insn_mnem(curr_addr) - - # Only care about lea/mov instructions - if mnem not in ("lea", "mov"): - continue - - # Skip if already matches offset - if "off_" in idc.print_operand(curr_addr, 1): - continue - - ea_candidate = idc.get_operand_value(curr_addr, 1) - - # Must be in rdata segment - if not (rdata.start_ea <= ea_candidate <= rdata.end_ea): - continue - - ea_xref = curr_addr - - # Handle case where string already exists - if ea_candidate in strings: - self._update_existing_string(strings[ea_candidate], ea_xref) - continue - - # Look ahead for length in next instructions - len_found = False - len_candidate = 0 - - for j in range(i + 1, min(i + 3, len(addrs))): # Look at next 2 instructions max - if idc.print_insn_mnem(addrs[j]) == "mov": - if idc.get_operand_type(addrs[j], 1) == idc.o_imm: - len_candidate = idc.get_operand_value(addrs[j], 1) - len_found = True - break - - if not len_found or not (0 < len_candidate <= 0x200): - continue - - try: - s = ida_bytes.get_bytes(ea_candidate, len_candidate).decode("utf-8") - s, len_s = filter_null_string(s, len_candidate) - if len_s == len_candidate: - strings[ea_candidate] = RustStringInfo(s, len_candidate, [ea_xref]) - except: - continue - - return strings - - def 
_is_valid_string(self, length: int, addr: int, rdata: ida_segment.segment_t) -> bool: - """ - Validate a potential Rust string candidate. - - Args: - length (int): Length of potential string - addr (int): Address where string content is located - rdata (ida_segment.segment_t): .rdata section segment - - Returns: - bool: True if string appears valid based on Rust string criteria - """ - return (length >> self.ror_num) == 0 and 0 < length <= 0x200 and rdata.start_ea <= addr <= rdata.end_ea - - def _update_existing_string(self, string_info: RustStringInfo, xref: int) -> None: - """ - Update cross-references for an existing string. - - Args: - string_info (RustStringInfo): String information to update - xref (int): New cross-reference address to add - """ - - -class LangRust(LanguageBase): - """ - Rust-specific language analyzer. - - Handles detection and analysis of Rust binaries, including string extraction, - library references, and thread handling. - - Attributes: - strings (Optional[Dict[int, List[str]]]): Extracted strings - ep_annotation (Optional[str]): Entry point annotation - lib_refs (List[Any]): Library references - crate_columns (List[List[str]]): Crate names and versions - user_xrefs (List[Tuple[int, int]]): User-defined cross-references - """ - - def __init__(self): - super().__init__() - self.id = "lang_rust" - self.strings = None - self.ep_annotation = None - self.lib_refs = [] - self.crate_columns = [[], []] # [names], [versions] - self.user_xrefs = [] # Store thread xrefs here - - def initialize(self) -> None: - """Initialize Rust-specific data after language matching.""" - super().initialize() - self._process_if_rust() - - def lang_match(self) -> bool: - """Check if binary is Rust.""" - search_patterns = [ - "3A 3A 75 6E 77 72 61 70 28 29 60 20", # ::unwrap()` - "5C 2E 63 61 72 67 6F 5C", # \.cargo\ - "2F 2E 63 61 72 67 6F 2F", # /.cargo/ - "2F 63 61 72 67 6F 2F", # /cargo/ - "74 68 72 65 61 64 20 70 61 6E 69 63", # thread panic - ] - - for pattern 
in search_patterns: - if find_bytes(pattern, ida_ida.inf_get_min_ea(), ida_ida.inf_get_max_ea(), BIN_SEARCH_FORWARD, 16) != idc.BADADDR: - return True - - return False - - def _process_if_rust(self) -> None: - """ - Process binary as Rust if language detection matches. - - Performs Rust-specific analysis including user cross-references, - string processing, and entry point annotation if binary is detected as Rust. - """ - if not self.lang_match(): - return - - log("Rust compiled binary detected") - self.user_xrefs = self.get_user_xrefs() or [] - self._process_strings() - self.ep_annotation = self._get_ep_annotation() - - def _process_strings(self) -> None: - """ - Process Rust strings and library references. - - Combines strings from multiple sources (Rust string parser and IDA default strings), - processes library references, and updates internal string storage. - """ - # Get Rust-specific strings - parser = RustStringParser() - rust_strings = {} - rust_strings.update(parser.get_data_rel_ro_strings()) - rust_strings.update(parser.get_rdata_strings()) - rust_strings.update(parser.get_text_strings()) - - # Get default IDA strings - default_lang = LangDefault(backend=self.backend) - default_strings = default_lang.get_strings() - - # Merge both string sets - combined_strings = {} - combined_strings.update({ea: RustStringInfo(s[0], len(s[0])) for ea, s in default_strings.items()}) - combined_strings.update(rust_strings) - - # Process library references - self._process_lib_refs(combined_strings) - - # Create final string dict - self.strings = {ea: [info.text] if info.xrefs is None else [info.text, info.length, info.xrefs] for ea, info in combined_strings.items()} - - def _process_lib_refs(self, strings: Dict[int, RustStringInfo]) -> None: - """ - Process library references from string data. - - Analyzes strings to extract and process library references, - including version information and crate details. Particularly - important for Rust binary analysis. 
- - Args: - strings: Dictionary of string information to process - - Side Effects: - - Updates crate_columns with crate information - - Updates lib_refs with processed references - - Creates new entity entries for libraries - - Note: - Processes different reference types: - - Git repository references - - Crate version information - - Local library paths - - Source file references - """ - if not strings: - return - - # Define regex patterns - lib_patterns = { - "git": ( - r"(?:github\.com-[a-z0-9]+|crates\.io(?:-[a-z0-9]+)*)[\/\\]{1,2}" - r"([^\/\\]+)-(\d[^\/\\]+?)[\/\\]{1,2}.*?[\/\\]{1,2}" - r"([^\/\\]+?)[\/\\]+([^\/\\]+)\.rs" - ), - "git_simple": ( - r"(?:github\.com-[a-z0-9]+|crates\.io(?:-[a-z0-9]+)*)[\/\\]{1,2}" - r"([^\/\\]+)-(\d[^\/\\]+?)[\/\\]{1,2}[^\/\\]+?[\/\\]+([^\/\\]+)\.rs" - ), - "lib": (r"(?:library|src)[/\\]{1,2}([^/\\]+).*?[/\\]([^/\\]+?)[/\\]+([^/\\]+)\.rs"), - "lib_simple": (r"(?:library|src)[/\\]{1,2}([^/\\]+?)[/\\]+([^/\\]+)\.rs"), - } - - patterns = {k: re.compile(v) for k, v in lib_patterns.items()} - - # Track addresses to remove (we can't modify dict during iteration) - to_remove = set() - - for str_ea, string_info in strings.items(): - string_contents = string_info.text - - # Skip non-printable strings - if not all(c.isprintable() or c.isspace() for c in string_contents): - to_remove.add(str_ea) - continue - - string_contents = normalize_path(string_contents) - string_contents_lower = string_contents.lower() - matched = False - - # Process git references - if "github." 
in string_contents or "crates.io" in string_contents: - match = patterns["git"].search(string_contents) - if match: - self._handle_git_match(match, (1, 3, 4), str_ea) - matched = True - else: - match = patterns["git_simple"].search(string_contents) - if match: - self._handle_git_match(match, (1, 3), str_ea) - matched = True - - # Process library references - elif "library" in string_contents_lower or "src" in string_contents_lower: - match = patterns["lib"].search(string_contents) - if match: - self._handle_lib_match(match, (1, 2, 3), str_ea) - matched = True - else: - match = patterns["lib_simple"].search(string_contents) - if match: - self._handle_lib_match(match, (1, 2), str_ea) - matched = True - - # If we matched either git or lib reference, remove the string - if matched: - to_remove.add(str_ea) - - # Remove processed strings - for str_ea in to_remove: - del strings[str_ea] - - def _handle_git_match(self, match: re.Match, group_ids: Tuple[int, ...], str_ea: int) -> None: - """ - Handle git repository reference matches. - - Process matched git repository references to extract crate information - and add to library references. 
- - Args: - match (re.Match): Regex match object containing git reference - group_ids (Tuple[int, ...]): Tuple of group IDs to extract from match - str_ea (int): Address where the string was found - """ - crate_name = match.group(1) - version = match.group(2) - - if crate_name not in self.crate_columns[0]: - self.crate_columns[0].append(crate_name) - self.crate_columns[1].append(version) - - self._add_lib_ref(match, group_ids, str_ea) - - def _handle_lib_match(self, match: re.Match, group_ids: Tuple[int, ...], str_ea: int): - """Handle library reference match.""" - crate_name = match.group(1) - - if crate_name not in self.crate_columns[0]: - self.crate_columns[0].append(crate_name) - self.crate_columns[1].append("n/a") - - self._add_lib_ref(match, group_ids, str_ea) - - def _add_lib_ref(self, match: re.Match, group_ids: Tuple[int, ...], str_ea: int): - """Add library reference to lib_refs list.""" - # Get base token and details - tokens = [match.group(i).replace("-", "").replace("_", "") for i in group_ids] - lib_ref = f"{tokens[0]}::{tokens[1]}" - if len(tokens) == 3: - lib_ref = f"{lib_ref}::{tokens[2]}" - - self.lib_refs.append((str_ea, lib_ref, 1, tokens[0])) - - def _get_ep_annotation(self) -> str: - """Generate entry point annotation with crate information.""" - if not self.crate_columns[0]: - return "" - - headings = ["CRATE", "VERSION"] - columns = self.crate_columns - rows = [] - - max_col_len = max(len(col) for col in columns) - for i in range(max_col_len): - row = [col[i] if i < len(col) else "" for col in columns] - rows.append(row) - - annotation = f"{tabulate(rows, headers=headings, tablefmt='github')}\n\n" - annotation = f"@ xrefer - crate listing\n\n{annotation}" - return annotation - - def get_user_xrefs(self) -> Optional[List[Tuple[int, int]]]: - """ - Parse Rust thread objects and refs. 
- - Returns: - Optional[List[Tuple[int, int]]]: List of (call address, thread function address) pairs, - or None if not a Rust binary - """ - if not self.lang_match(): - return None - - result = [] - ptr_size = 4 if is_32bit() else 8 - - # Get CreateThread import - createthread_ea = idc.get_name_ea_simple("CreateThread") - if createthread_ea == idc.BADADDR: - return result - - # Find Rust's thread creation function - xrefs = idautils.XrefsTo(createthread_ea) - mw_createthread_xref = next(xrefs, None) - if not mw_createthread_xref: - return result - - # Get the function containing the CreateThread call - mw_createthread_ea = ida_funcs.get_func(mw_createthread_xref.frm).start_ea - - # Rename Rust's thread creation function - idaapi.set_name(mw_createthread_ea, "mw_createthread", idc.SN_NOCHECK) - - # Find all calls to Rust's thread creation function - threadcall_xrefs = idautils.XrefsTo(mw_createthread_ea) - - for xref in threadcall_xrefs: - # Check if reference is a call - if xref.type == idc.fl_CN: - ref = xref.frm - _ref = ref - - # Search 10 instructions back for thread function pointer - for _ in range(10): - thread_func = None - _ref = idc.prev_head(_ref) - - if "offset" in idc.generate_disasm_line(_ref, 0): - # Thread object structure: - # [0] vtable ptr - # [1] state - # [2] name - # [3] thread function ptr - pthread_func = idc.get_operand_value(_ref, 0) + ptr_size * 3 - - # Get actual thread function pointer - thread_func = ida_bytes.get_qword(pthread_func) if ptr_size == 8 else ida_bytes.get_dword(pthread_func) - - result.append((ref, thread_func)) - break - - return result - - def get_entry_point(self) -> Optional[int]: - """Get Rust program entry point.""" - # Only perform Rust-specific analysis if this is actually a Rust binary - if not self.lang_match(): - return super().get_entry_point() - - # Try explicit rust_main first - rust_main = idc.get_name_ea_simple("rust_main") - if rust_main != idc.BADADDR: - return rust_main - # Try main/_main and analyze 
for rust_main pattern - for main_name in ("main", "_main"): - main_ea = idc.get_name_ea_simple(main_name) - if main_ea != idc.BADADDR: - rust_main = self._find_rust_main(main_ea) - if rust_main: - return rust_main - - # Try finding main via __initenv analysis. just a hack for now, fix later. - main_ea = LanguageBase.fallback_cmain_detection(self.backend) - if main_ea: - rust_main = self._find_rust_main(main_ea) - if rust_main: - return rust_main - - # Fallback to default entry point finder if everything else fails - return super().get_entry_point() - - def _find_rust_main(self, main_ea: int) -> Optional[int]: - """Find rust_main by analyzing main function.""" - start = idc.get_func_attr(main_ea, idc.FUNCATTR_START) - end = idc.prev_addr(idc.get_func_attr(main_ea, idc.FUNCATTR_END)) - - is_64 = not is_32bit() # Use different variable name - - for addr in range(start, end): - refs = idautils.XrefsFrom(addr) - - for ref in refs: - if start <= ref.to <= end: - continue - - if ref.type != idc.fl_CN: # Not a direct call - target_func = idaapi.get_func(ref.to) - if not target_func or target_func.start_ea != ref.to: - continue - - # Look for call within next 8 instructions - call_found = self._find_call_after_ref(addr, 8, is_64) # Use new variable name - if call_found: - idaapi.set_name(ref.to, "rust_main", idc.SN_NOCHECK) - return ref.to - - return None - - def _find_call_after_ref(self, start_addr: int, max_instructions: int, is_64bit: bool) -> bool: - """Find call instruction after reference.""" - ins = ida_ua.insn_t() - ins_ea = start_addr - - for _ in range(max_instructions): - ins_ea = idc.next_head(ins_ea) - idaapi.decode_insn(ins, ins_ea) - - if not ins: - break - - if ins.itype in (idaapi.NN_call, idaapi.NN_callfi, idaapi.NN_callni): - target = idc.get_operand_value(ins_ea, 0) - target_flags = idc.get_func_flags(target) - - if target_flags < 0: # Handle pointer to function - target = (ida_bytes.get_qword if is_64bit else ida_bytes.get_dword)(target) - target_flags 
= idc.get_func_flags(target) - - # Skip imports, library functions and thunks - if not (target_flags & (idc.FUNC_LIB | idc.FUNC_STATIC | idc.FUNC_THUNK)): - return True - - return False - - def rename_functions(self, xrefer_obj: "XRefer") -> None: - """ - Rename functions based on their references. - - Args: - xrefer_obj: XreferenceLLM object containing global xrefs. - """ - # de-prioritize refs that have a chance of overlapping occurrence even in non-lined methods - depriori_list = ["std", "core", "alloc", "gimli", "object"] - selected_ref = None - name_index = {} - idaapi.show_wait_box("HIDECANCEL\nRenaming...") - - for func_ea, func_ref in xrefer_obj.global_xrefs.items(): - depriori_refs = set() - priori_refs = set() - - # only rename default function labels - orig_func_name = idc.get_func_name(func_ea) - if not orig_func_name.startswith("sub_"): - log(f"Renaming skipped: {orig_func_name}") - continue - - for xref_entity in func_ref[xrefer_obj.DIRECT_XREFS]["libs"]: - xref = xrefer_obj.entities[xref_entity][1] - if xref.split("::")[0] in depriori_list: - depriori_refs.add(xref) - else: - priori_refs.add(xref) - - if len(priori_refs): - selected_ref = self.find_common_denominator(list(priori_refs)) - - else: - selected_ref = None - - method_name_index = 0 - - if selected_ref: - if selected_ref not in name_index: - name_index[selected_ref] = method_name_index - else: - name_index[selected_ref] += 1 - method_name_index = name_index[selected_ref] - - orig_method_name = idc.get_func_name(func_ea) - method_name = f"{selected_ref}_{method_name_index}" - log(f"Renaming {orig_method_name} to {method_name}") - idaapi.set_name(func_ea, method_name, ida_name.SN_NOWARN | ida_name.SN_AUTO) - - idaapi.hide_wait_box() - - @staticmethod - def find_common_denominator(lib_refs: List[str]) -> Optional[str]: - """ - Find the common denominator among library references. - - Args: - lib_refs (List[str]): List of library references. 
- - Returns: - Optional[str]: Common denominator if found, None otherwise. - """ - if not lib_refs: - return None - - zipped_parts = zip(*[s.split("::") for s in lib_refs]) - common_parts = [parts[0] for parts in zipped_parts if all(p == parts[0] for p in parts)] - - if not common_parts: - return None - - return "::".join(common_parts) diff --git a/plugins/xrefer/lang/lang_base.py b/plugins/xrefer/lang/lang_base.py index f324c66..e8eeaad 100644 --- a/plugins/xrefer/lang/lang_base.py +++ b/plugins/xrefer/lang/lang_base.py @@ -15,7 +15,7 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional -from xrefer.backend import BackEnd, get_current_backend +from xrefer.backend import Address, BackEnd, SectionType, get_current_backend from xrefer.core.helpers import log @@ -124,19 +124,120 @@ def get_entry_point(self) -> Optional[int]: if address is not None: return address.value + # Prefer the program entry symbol if present (e.g., 'entry' in PE binaries) + entry_sym = self.backend.get_address_for_name("entry") + if entry_sym is not None: + ghidra_fallback = self._resolve_ghidra_user_entry(entry_sym) + if ghidra_fallback is not None: + return ghidra_fallback + return entry_sym.value + + ghidra_fallback = self._resolve_ghidra_user_entry(None) + if ghidra_fallback is not None: + return ghidra_fallback + # Fallback: try to find main function through common patterns fallback = self.fallback_cmain_detection(self.backend) if fallback: return fallback - else: - exports = self.backend.get_exports() - # If no main function found, return the first export as a last resort - first_export = next(exports, None) - if first_export: - return first_export[1].value + exports = self.backend.get_exports() + first_export = next(exports, None) + if first_export: + return first_export[1].value return None + def preferred_entry_name(self) -> str: + """Preferred symbol name for the resolved user entry point.""" + + return "main" + + def _resolve_ghidra_user_entry(self, 
entry_symbol: Optional[Address]) -> Optional[int]: + """Attempt to resolve the user entry point via Ghidra-specific heuristics.""" + + if self.backend.name != "ghidra": + return None + + entry_address_obj: Optional[Address] + if entry_symbol is not None: + entry_address_obj = entry_symbol if isinstance(entry_symbol, Address) else Address(int(entry_symbol)) + else: + resolved = self.backend.get_address_for_name("entry") + entry_address_obj = resolved if resolved is not None else None + + if entry_address_obj is None: + return None + + text_section = self.backend.get_section_by_name(".text") + if text_section is None: + return None + + exec_sections = [section for section in self.backend.get_sections() if getattr(section, "type", None) == SectionType.CODE] + + def _in_exec_section(addr: Address) -> bool: + if text_section.contains(addr): + if getattr(text_section, "type", None) == SectionType.CODE: + return True + for section in exec_sections: + if section.contains(addr): + return True + return False + + start = max(text_section.start.value, entry_address_obj.value - 0x400) + end = min(text_section.end.value, entry_address_obj.value + 0x2000) + + candidates: list[tuple[int, int]] = [] + try: + for inst_addr in self.backend.instructions(Address(start), Address(end)): + inst = self.backend.disassemble(inst_addr) + if inst.mnemonic != "call" or not inst.operands: + continue + target_val = inst.operands[0].value if inst.operands[0].value is not None else None + if target_val is None: + continue + if not _in_exec_section(target_val): + continue + if target_val == entry_address_obj: + continue + candidates.append((int(target_val), int(inst_addr))) + except Exception: + return None + + if not candidates: + return None + + # Prefer calls that land before the CRT stub; fall back to the lowest candidate. 
+ before_stub = [pair for pair in candidates if pair[0] < entry_address_obj.value] + target_value, _ = min(before_stub or candidates, key=lambda item: item[0]) + + self._maybe_rename_entry_function(Address(target_value)) + log(f"Ghidra fallback resolved entry point at 0x{target_value:x}") + return target_value + + def _maybe_rename_entry_function(self, address: Address) -> None: + """Rename the resolved entry function if it uses a placeholder name.""" + + preferred = self.preferred_entry_name() + if not preferred: + return + + try: + entry_function = self.backend.get_function_at(address) + except Exception: + return + + if not entry_function: + return + + current_name = entry_function.name + normalized = current_name.lower() + placeholder_prefixes = ("fun_", "sub_", "entry") + if current_name and not normalized.startswith(placeholder_prefixes): + return + if current_name == preferred: + return + entry_function.name = preferred + def get_strings(self, filters: Optional[List[str]] = None) -> Dict[int, List[str]]: """ Extract strings from the binary with optional filtering. diff --git a/plugins/xrefer/lang/lang_registry.py b/plugins/xrefer/lang/lang_registry.py index 3f46b68..778e66d 100644 --- a/plugins/xrefer/lang/lang_registry.py +++ b/plugins/xrefer/lang/lang_registry.py @@ -18,56 +18,42 @@ import os from typing import Any, List, Type -from xrefer.backend import list_available_backends from xrefer.core.helpers import log from xrefer.lang.lang_base import LanguageBase from xrefer.lang.lang_default import LangDefault def get_language_modules() -> List[Type[LanguageBase]]: - """Get all available language module classes.""" - lang_classes = [] + """Discover all backend-neutral language module classes. + + Recursively scans the xrefer.lang package for files named 'lang_*.py' + (excluding base/registry/default) and loads classes deriving from LanguageBase. 
+ """ + lang_classes: List[Type[LanguageBase]] = [] lang_dir = os.path.dirname(__file__) - # First, check for legacy lang_*.py files - lang_files = [f[:-3] for f in os.listdir(lang_dir) if f.startswith("lang_") and f.endswith(".py") and f not in ("lang_base.py", "lang_default.py", "lang_registry.py")] + exclude_files = {"lang_base.py", "lang_default.py", "lang_registry.py", "__init__.py"} - for module_name in lang_files: - try: - # Import module - module = importlib.import_module(f".{module_name}", package="xrefer.lang") + for root, _dirs, files in os.walk(lang_dir): + for filename in files: + if not (filename.startswith("lang_") and filename.endswith(".py")): + continue + if filename in exclude_files: + continue - # Find language class (subclass of LanguageBase) - for name, obj in inspect.getmembers(module, inspect.isclass): - if name == "LangDefault": - continue - if issubclass(obj, LanguageBase) and obj != LanguageBase: - lang_classes.append(obj) - except Exception as e: - log(f"[-] Error loading language module {module_name}: {e}") + rel_path = os.path.relpath(os.path.join(root, filename), lang_dir) + module_name = rel_path[:-3].replace(os.sep, ".") # strip .py and convert to package path - # Now check for new backend-organized structure using available backends - available_backends = list_available_backends() - for backend_name in available_backends.keys(): - backend_path = os.path.join(lang_dir, backend_name) - if os.path.isdir(backend_path): - # Scan for language modules in backend directory - for lang_file in os.listdir(backend_path): - if lang_file.endswith(".py") and not lang_file.startswith("__"): - module_name = lang_file[:-3] - log(f"Loading language module {backend_name}.{module_name}...") - try: - # Import module from backend subdirectory - module = importlib.import_module(f".{backend_name}.{module_name}", package="xrefer.lang") + try: + module = importlib.import_module(f".{module_name}", package="xrefer.lang") + for name, obj in 
inspect.getmembers(module, inspect.isclass): + if name == "LangDefault": + continue + if issubclass(obj, LanguageBase) and obj is not LanguageBase: + lang_classes.append(obj) + except Exception as e: + log(f"[-] Error loading language module {module_name}: {e}") - # Find language class (subclass of LanguageBase) - for name, obj in inspect.getmembers(module, inspect.isclass): - if name == "LangDefault": - continue - if issubclass(obj, LanguageBase) and obj != LanguageBase: - lang_classes.append(obj) - except Exception as e: - log(f"[-] Error loading language module {backend_name}.{module_name}: {e}") lang_str = ", ".join([cls.__name__ for cls in lang_classes]) log(f"Found language modules: {lang_str}") return lang_classes diff --git a/plugins/xrefer/lang/lang_rust.py b/plugins/xrefer/lang/lang_rust.py new file mode 100644 index 0000000..8a019f4 --- /dev/null +++ b/plugins/xrefer/lang/lang_rust.py @@ -0,0 +1,955 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import re +import typing +from collections import deque +from dataclasses import dataclass +from typing import Deque, Dict, List, Optional, Tuple + +from tabulate import tabulate +from xrefer.backend import Address, FunctionType, OperandType, Section, SectionType, XrefType +from xrefer.core.helpers import filter_null_string, log, normalize_path +from xrefer.lang.lang_base import LanguageBase +from xrefer.lang.lang_default import LangDefault + +if typing.TYPE_CHECKING: + from xrefer.core.analyzer import BackEnd, XRefer + + +STATIC_DATA_SECTIONS: Tuple[str, ...] = (".rdata", ".data.rel.ro", ".rodata", ".data") +ADDRESS_TOKEN_RE = re.compile(r"(?:0x)?[0-9A-Fa-f]{5,}") +PLACEHOLDER_PREFIXES = ("sub_", "fun_", "lab_", "nullsub_", "thunk_") + + +def operand_address(inst, operand_index: int) -> Optional[int]: + """Best-effort resolve operand address across backends.""" + + operands = getattr(inst, "operands", None) + if not operands or operand_index >= len(operands): + return None + + operand = operands[operand_index] + + # Prefer explicit value if backend populated one + value = getattr(operand, "value", None) + if value is not None: + try: + resolved = int(value) + except Exception: + resolved = None + else: + if not isinstance(value, Address) or value.is_valid(): + return resolved + + texts = [] + op_text = getattr(operand, "text", None) + if op_text: + texts.append(op_text) + inst_text = getattr(inst, "text", None) + if inst_text: + texts.append(inst_text) + + for text in texts: + for match in ADDRESS_TOKEN_RE.finditer(text): + token = match.group(0) + # Skip small immediates (e.g., stack offsets) + cleaned = token[2:] if token.lower().startswith("0x") else token + if len(cleaned) < 5: + continue + try: + return int(token, 16) + except ValueError: + continue + return None + + +def address_in_sections(backend: "BackEnd", addr: Optional[int], section_names: Tuple[str, ...] 
= STATIC_DATA_SECTIONS) -> bool: + if addr is None or addr < 0: + return False + + try: + address_obj = Address(int(addr)) + except Exception: + return False + + for name in section_names: + section = backend.get_section_by_name(name) + if section and section.contains(address_obj): + return True + return False + + +def address_in_code_sections(backend: "BackEnd", addr: Optional[int]) -> bool: + if addr is None or addr < 0: + return False + + try: + address_obj = Address(int(addr)) + except Exception: + return False + + for section in backend.get_sections(): + section_type = getattr(section, "type", None) + if section_type == SectionType.CODE and section.contains(address_obj): + return True + return False + + +@dataclass +class RustStringInfo: + """ + Container for Rust string information. + + Attributes: + text (str): The actual string content + length (int): Length of the string in bytes + xrefs (Optional[List[int]]): List of cross-reference addresses to this string + """ + + text: str + length: int + xrefs: Optional[List[int]] = None + + +class RustStringParser: + """ + Parser for Rust string formats in binary. + + Handles parsing of various Rust string representations including those in + .data.rel.ro, .rdata sections, and strings referenced from text section. 
+ + Attributes: + is_64bit (bool): Whether binary is 64-bit + sizeof_rust_string (int): Size of Rust string structure (16 or 8 bytes) + next_offset (int): Offset to next string field (8 or 4 bytes) + ror_num (int): Rotation number for validation (32 or 16) + poi (Callable): Function to read pointer-sized values + """ + + def __init__(self, backend): + self.backend: BackEnd = backend + self.ptr_size = self._guess_ptr_size() + self.is_64bit = self.ptr_size == 8 + self.sizeof_rust_string = 16 if self.is_64bit else 8 + self.next_offset = 8 if self.is_64bit else 4 + self.ror_num = 32 if self.is_64bit else 16 + + def get_data_rel_ro_strings(self) -> Dict[int, RustStringInfo]: + """ + Extract Rust strings from .data.rel.ro section. + + Scans the .data.rel.ro section for Rust string patterns, validates them, + and converts them to RustStringInfo objects. + + Returns: + Dict[int, RustStringInfo]: Dictionary mapping addresses to RustStringInfo objects + for all valid strings found in .data.rel.ro + """ + strings: Dict[int, RustStringInfo] = {} + + data_rel_ro = self.backend.get_section_by_name(".data.rel.ro") + if not data_rel_ro: + return strings + + rdata = self.backend.get_section_by_name(".rdata") + if not rdata: + return strings + + curr_ea = data_rel_ro.start.value + while curr_ea < data_rel_ro.end.value: + ea_candidate = self._read_ptr(curr_ea) + len_candidate = self._read_ptr(curr_ea + self.next_offset) + + if self._is_valid_string(len_candidate, ea_candidate, rdata): + try: + raw = self.backend.read_bytes(Address(ea_candidate), len_candidate) + if raw: + s = raw.decode("utf-8", errors="strict") + s, len_s = filter_null_string(s, len_candidate) + if len_s == len_candidate and ea_candidate not in strings: + strings[ea_candidate] = RustStringInfo(s, len_candidate) + curr_ea += self.sizeof_rust_string + continue + except UnicodeDecodeError: + pass + curr_ea += 1 + return strings + + def get_rdata_strings(self) -> Dict[int, RustStringInfo]: + """ + Extract Rust strings 
from .rdata section. + + Similar to get_data_rel_ro_strings but processes the .rdata section. + Uses the same validation and extraction logic for consistency. + + Returns: + Dict[int, RustStringInfo]: Dictionary mapping addresses to RustStringInfo objects + for all valid strings found in .rdata + """ + strings = {} + + rdata = self.backend.get_section_by_name(".rdata") + if not rdata: + return strings + + curr_ea = rdata.start.value + while curr_ea < rdata.end.value: + ea_candidate = self._read_ptr(curr_ea) + len_candidate = self._read_ptr(curr_ea + self.next_offset) + if ea_candidate is None or len_candidate is None: + curr_ea += 1 + continue + + if self._is_valid_string(len_candidate, ea_candidate, rdata): + try: + raw = self.backend.read_bytes(Address(ea_candidate), len_candidate) + if raw: + s = raw.decode("utf-8", errors="strict") + s, len_s = filter_null_string(s, len_candidate) + if len_s == len_candidate and ea_candidate not in strings: + strings[ea_candidate] = RustStringInfo(s, len_candidate) + curr_ea += self.sizeof_rust_string + continue + except UnicodeDecodeError: + pass + curr_ea += 1 + return strings + + def get_text_strings(self) -> Dict[int, RustStringInfo]: + """ + Extract Rust strings referenced from .text section. + + Analyzes code in .text section to find string references and extracts + corresponding strings from .rdata section. More complex than other methods + as it needs to handle various instruction patterns. 
+ + Returns: + Dict[int, RustStringInfo]: Dictionary mapping addresses to RustStringInfo objects + for strings referenced from code + """ + strings = {} + text = self.backend.get_section_by_name(".text") + rdata = self.backend.get_section_by_name(".rdata") + if not text or not rdata: + return strings + + # for func in idautils.Functions(text.start_ea, text.end_ea): + for fn in self.backend.functions(): + if not text.contains(fn.start): + continue + + # Process instructions for string references + for bb in fn.basic_blocks: + for ins in self.backend.instructions(bb.start, bb.end): # TODO: This is ugly. Fix design in backend/. + curr_addr = ins.value + + inst = self.backend.disassemble(curr_addr) + # Only care about lea/mov instructions + if inst.mnemonic not in ("lea", "mov"): + continue + if not inst.operands or len(inst.operands) < 2: + continue + + # Skip if already matches offset + # TODO: This used to check for IDA's `off_` strings. We now resolve operand + # addresses directly so Binary Ninja/Ghidra paths share the same code. + ea_candidate = operand_address(inst, 1) + if ea_candidate is None: + continue + try: + ea_candidate_addr = Address(ea_candidate) + except Exception: + continue + + operand = inst.operands[1] + if operand.value is not None: + try: + assert Address(int(operand.value)) == ea_candidate_addr + except Exception: + pass + + # Must be in rdata segment (matches legacy behavior) + if not rdata.contains(ea_candidate_addr): + continue + + ea_xref = curr_addr + # Handle case where string already exists + if ea_candidate in strings: + self._update_existing_string(strings[ea_candidate], ea_xref) + continue + + # Look ahead for length in next instructions + len_found = False + len_candidate = 0 + + for cnt, j in enumerate(self.backend.instructions(curr_addr + 1, curr_addr + 20)): # TODO: Look at next 20 bytes max (design issue. 2 ins -> 20 bytes heuristics. 
)
+                    # just 2 ins
+                    if cnt >= 2:
+                        break
+                    j = j.value
+                    inst2 = self.backend.disassemble(j)
+                    if inst2.mnemonic == "mov":
+                        if inst2.operands and inst2.operands[1].value and inst2.operands[1].type == OperandType.IMMEDIATE:
+                            len_candidate = inst2.operands[1].value
+                            len_found = True
+                            break
+
+                if not len_found or not (0 < len_candidate <= 0x200):
+                    continue
+
+                try:
+                    s = self.backend.read_bytes(ea_candidate_addr, len_candidate).decode("utf-8")
+                    s, len_s = filter_null_string(s, len_candidate)
+                    if len_s == len_candidate:
+                        strings[ea_candidate] = RustStringInfo(s, len_candidate, [ea_xref])
+                except Exception:
+                    continue
+        return strings
+
+    def _is_valid_string(self, length: int, addr: int, rdata: "Section") -> bool:
+        """
+        Validate a potential Rust string candidate.
+
+        Args:
+            length (int): Length of potential string
+            addr (int): Address where string content is located
+            rdata (Section): .rdata section segment
+
+        Returns:
+            bool: True if string appears valid based on Rust string criteria
+        """
+        return (length >> self.ror_num) == 0 and 0 < length <= 0x200 and rdata.contains(Address(addr))
+
+    def _guess_ptr_size(self) -> int:
+        # Heuristic: if any section end exceeds 32-bit, assume 64-bit
+        try:
+            max_end = max(sec.end.value for sec in self.backend.get_sections())
+            return 8 if max_end > 0xFFFFFFFF else 4
+        except Exception:
+            # Safe default
+            return 8
+
+    def _read_ptr(self, ea: int) -> Optional[int]:
+        raw = self.backend.read_bytes(Address(ea), self.ptr_size)
+        if not raw or len(raw) != self.ptr_size:
+            return None
+        return int.from_bytes(raw, byteorder="little", signed=False)
+
+    def _update_existing_string(self, string_info: RustStringInfo, xref: int) -> None:
+        """Record an additional cross-reference address for an existing string.
+
+        Args:
+            string_info (RustStringInfo): String information to update
+            xref (int): New cross-reference address to add
+        """
+        string_info.xrefs = (string_info.xrefs or []) + [xref]
+
+
+class LangRust(LanguageBase):
+    """
+    Rust-specific language analyzer.
+ + Handles detection and analysis of Rust binaries, including string extraction, + library references, and thread handling. + + Attributes: + strings (Optional[Dict[int, List[str]]]): Extracted strings + ep_annotation (Optional[str]): Entry point annotation + lib_refs (List[Any]): Library references + crate_columns (List[List[str]]): Crate names and versions + user_xrefs (List[Tuple[int, int]]): User-defined cross-references + """ + + def __init__(self): + super().__init__() + self.id = "lang_rust" + self.strings = None + self.ep_annotation = None + self.lib_refs = [] + self.crate_columns = [[], []] # [names], [versions] + self.user_xrefs = [] # Store thread xrefs here + + def initialize(self) -> None: + """Initialize Rust-specific data after language matching.""" + super().initialize() + self._process_if_rust() + + def lang_match(self) -> bool: + string_markers = [ + "::unwrap()", + ".cargo", + "/cargo/", + "thread panic", + ] + hits = 0 + for s in self.backend.strings(min_length=5): + sc = s.content + if any(tok in sc for tok in string_markers): + hits += 1 + if hits >= 2: + return True + return False + + def _process_if_rust(self) -> None: + """ + Process binary as Rust if language detection matches. + + Performs Rust-specific analysis including user cross-references, + string processing, and entry point annotation if binary is detected as Rust. + """ + if not self.lang_match(): + return + + log("Rust compiled binary detected") + self.user_xrefs = self.get_user_xrefs() or [] + self._process_strings() + self._ensure_rust_entry_alias() + self.ep_annotation = self._get_ep_annotation() + + def _process_strings(self) -> None: + """ + Process Rust strings and library references. + + Combines strings from multiple sources (Rust string parser and IDA default strings), + processes library references, and updates internal string storage. 
+ """ + # Get Rust-specific strings + parser = RustStringParser(self.backend) + rust_strings = {} + rust_strings.update(parser.get_data_rel_ro_strings()) + rust_strings.update(parser.get_rdata_strings()) + rust_strings.update(parser.get_text_strings()) + + # Get default IDA strings + default_lang = LangDefault(backend=self.backend) + default_strings = default_lang.get_strings() + + # Merge both string sets + combined_strings = {} + combined_strings.update({ea: RustStringInfo(s[0], len(s[0])) for ea, s in default_strings.items()}) + combined_strings.update(rust_strings) + + # Process library references + self._process_lib_refs(combined_strings) + + # Create final string dict + self.strings = {ea: [info.text] if info.xrefs is None else [info.text, info.length, info.xrefs] for ea, info in combined_strings.items()} + + def _process_lib_refs(self, strings: Dict[int, RustStringInfo]) -> None: + """ + Process library references from string data. + + Analyzes strings to extract and process library references, + including version information and crate details. Particularly + important for Rust binary analysis. 
+ + Args: + strings: Dictionary of string information to process + + Side Effects: + - Updates crate_columns with crate information + - Updates lib_refs with processed references + - Creates new entity entries for libraries + + Note: + Processes different reference types: + - Git repository references + - Crate version information + - Local library paths + - Source file references + """ + if not strings: + return + + # Define regex patterns + lib_patterns = { + "git": ( + r"(?:github\.com-[a-z0-9]+|crates\.io(?:-[a-z0-9]+)*)[\/\\]{1,2}" + r"([^\/\\]+)-(\d[^\/\\]+?)[\/\\]{1,2}.*?[\/\\]{1,2}" + r"([^\/\\]+?)[\/\\]+([^\/\\]+)\.rs" + ), + "git_simple": ( + r"(?:github\.com-[a-z0-9]+|crates\.io(?:-[a-z0-9]+)*)[\/\\]{1,2}" + r"([^\/\\]+)-(\d[^\/\\]+?)[\/\\]{1,2}[^\/\\]+?[\/\\]+([^\/\\]+)\.rs" + ), + "lib": (r"(?:library|src)[/\\]{1,2}([^/\\]+).*?[/\\]([^/\\]+?)[/\\]+([^/\\]+)\.rs"), + "lib_simple": (r"(?:library|src)[/\\]{1,2}([^/\\]+?)[/\\]+([^/\\]+)\.rs"), + } + + patterns = {k: re.compile(v) for k, v in lib_patterns.items()} + + # Track addresses to remove (we can't modify dict during iteration) + to_remove = set() + + for str_ea, string_info in strings.items(): + string_contents = string_info.text + + # Skip non-printable strings + if not all(c.isprintable() or c.isspace() for c in string_contents): + to_remove.add(str_ea) + continue + + string_contents = normalize_path(string_contents) + string_contents_lower = string_contents.lower() + matched = False + + # Process git references + if "github." 
in string_contents or "crates.io" in string_contents: + match = patterns["git"].search(string_contents) + if match: + self._handle_git_match(match, (1, 3, 4), str_ea) + matched = True + else: + match = patterns["git_simple"].search(string_contents) + if match: + self._handle_git_match(match, (1, 3), str_ea) + matched = True + + # Process library references + elif "library" in string_contents_lower or "src" in string_contents_lower: + match = patterns["lib"].search(string_contents) + if match: + self._handle_lib_match(match, (1, 2, 3), str_ea) + matched = True + else: + match = patterns["lib_simple"].search(string_contents) + if match: + self._handle_lib_match(match, (1, 2), str_ea) + matched = True + + # If we matched either git or lib reference, remove the string + if matched: + to_remove.add(str_ea) + + # Remove processed strings + for str_ea in to_remove: + del strings[str_ea] + + def _handle_git_match(self, match: re.Match, group_ids: Tuple[int, ...], str_ea: int) -> None: + """ + Handle git repository reference matches. + + Process matched git repository references to extract crate information + and add to library references. 
        Args:
            match (re.Match): Regex match object containing git reference
            group_ids (Tuple[int, ...]): Tuple of group IDs to extract from match
            str_ea (int): Address where the string was found
        """
        crate_name = match.group(1)
        version = match.group(2)

        # crate_columns is a pair of parallel lists: [0] crate names,
        # [1] versions; consumed by _get_ep_annotation below. Record each
        # crate only once.
        if crate_name not in self.crate_columns[0]:
            self.crate_columns[0].append(crate_name)
            self.crate_columns[1].append(version)

        self._add_lib_ref(match, group_ids, str_ea)

    def _handle_lib_match(self, match: re.Match, group_ids: Tuple[int, ...], str_ea: int):
        """
        Handle a library reference match that carries no version information.

        Args:
            match (re.Match): Regex match object containing the library reference.
            group_ids (Tuple[int, ...]): Tuple of group IDs to extract from match.
            str_ea (int): Address where the string was found.
        """
        crate_name = match.group(1)

        # No version is available for plain library matches, so record "n/a"
        # to keep the parallel name/version columns aligned.
        if crate_name not in self.crate_columns[0]:
            self.crate_columns[0].append(crate_name)
            self.crate_columns[1].append("n/a")

        self._add_lib_ref(match, group_ids, str_ea)

    def _add_lib_ref(self, match: re.Match, group_ids: Tuple[int, ...], str_ea: int):
        """
        Append a normalized "crate::module[::item]" reference to lib_refs.

        Args:
            match (re.Match): Regex match whose groups form the reference parts.
            group_ids (Tuple[int, ...]): Two or three group IDs joined with "::".
            str_ea (int): Address where the string was found.
        """
        # Strip separators so e.g. "foo-bar" and "foo_bar" normalize to "foobar".
        tokens = [match.group(i).replace("-", "").replace("_", "") for i in group_ids]
        lib_ref = f"{tokens[0]}::{tokens[1]}"
        if len(tokens) == 3:
            lib_ref = f"{lib_ref}::{tokens[2]}"

        # Tuple layout: (string address, reference, count, crate name).
        self.lib_refs.append((str_ea, lib_ref, 1, tokens[0]))

    def _get_ep_annotation(self) -> str:
        """
        Generate the entry-point annotation listing detected crates.

        Returns:
            str: A "@ xrefer - crate listing" header followed by a
                github-style table of (crate, version) rows, or "" when no
                crates were recorded.
        """
        if not self.crate_columns[0]:
            return ""

        headings = ["CRATE", "VERSION"]
        columns = self.crate_columns
        rows = []

        # Transpose the parallel column lists into per-crate rows, padding
        # short columns with "".
        max_col_len = max(len(col) for col in columns)
        for i in range(max_col_len):
            row = [col[i] if i < len(col) else "" for col in columns]
            rows.append(row)

        annotation = f"{tabulate(rows, headers=headings, tablefmt='github')}\n\n"
        annotation = f"@ xrefer - crate listing\n\n{annotation}"
        return annotation

    def _ensure_rust_entry_alias(self) -> None:
        """Ensure the real Rust entry point is labeled `rust_main` when possible."""

        # Ghidra-only heuristic: other backends handle this elsewhere.
        if not self.entry_point or self.backend.name != "ghidra":
            return

        text_section = self.backend.get_section_by_name(".text")
        if not text_section:
            return

        try:
            main_function = self.backend.get_function_at(Address(self.entry_point))
        except Exception:
            return

        if not main_function:
            return

        # Scan the entry function for a `lea` that loads a .text address:
        # Rust's lang_start wrapper passes the real main as a code pointer.
        for bb in main_function.basic_blocks:
            for inst_addr in self.backend.instructions(bb.start, bb.end):
                inst = self.backend.disassemble(inst_addr)
                if inst.mnemonic != "lea" or len(inst.operands) < 2:
                    continue

                candidate = operand_address(inst, 1)
                if candidate is None:
                    continue

                candidate_addr = Address(candidate)
                if not text_section.contains(candidate_addr):
                    continue
                fn = self.backend.get_function_at(candidate_addr)

                if not fn:
                    continue

                # Only relabel auto-generated placeholder names.
                placeholder_prefixes = ("fun_", "sub_", "entry", "nullsub_", "se_func")
                if fn.name.lower().startswith(placeholder_prefixes):
                    fn.name = "rust_main"
                    return

    def get_user_xrefs(self) -> Optional[List[Tuple[int, int]]]:
        """
        Parse Rust thread objects and refs.

        Returns:
            Optional[List[Tuple[int, int]]]: List of (call address, thread function address) pairs,
                or None if not a Rust binary
        """
        if not self.lang_match():
            return None

        result: List[Tuple[int, int]] = []
        ptr_size = self._ptr_size()

        # Get CreateThread import
        createthread_addr: Optional[Address] = None
        for addr, full, module in self.backend.get_imports():
            # Normalize to "kernel32.createthread"
            if full.lower().endswith(".createthread") and module.lower() in ("kernel32", "unknown"):
                createthread_addr = addr
                break

        if not createthread_addr:
            return result
        first_ref = next(iter(self.backend.get_xrefs_to(createthread_addr)), None)
        if not first_ref:
            return result

        # Get the function containing the CreateThread call
        wrapper_func = self.backend.get_function_at(first_ref.source)
        if not wrapper_func:
            return result
        # Rename Rust's thread creation function
        wrapper_func.name = "mw_createthread"
        # Find all calls to Rust's thread creation function
        for xref in self.backend.get_xrefs_to(wrapper_func.start):
            # Check if reference is a call
            if xref.type == XrefType.CALL:
                ref = xref.source
                _ref = ref.value

                # Search 10 instructions back for thread function pointer
                caller_fn = self.backend.get_function_containing(_ref)
                caller_bb = [bb for bb in caller_fn.basic_blocks if bb.contains(_ref)]
                assert len(caller_bb) == 1, "There are cases where #bb>=2, but ignore for now. open issue when this is the case"
                ins = list(self.backend.instructions(caller_bb[0].start, _ref))
                for prev_ea in reversed(ins[-10:]):
                    _ref = prev_ea
                    thread_func = None

                    disasm = self.backend.disassemble(_ref)
                    base_pointer = self._extract_thread_object_base(disasm)
                    if base_pointer is not None:
                        # Thread object structure (assumed layout — mirrors the
                        # legacy IDA implementation):
                        # [0] vtable ptr
                        # [1] state
                        # [2] name
                        # [3] thread function ptr
                        try:
                            pthread_func_addr_int = base_pointer + ptr_size * 3
                            pthread_func_addr = Address(pthread_func_addr_int)
                        except Exception:
                            continue

                        thread_func = self._read_ptr(pthread_func_addr, ptr_size)
                        if thread_func is None:
                            continue

                        # Double-check the dereference for parity with the legacy code.
                        func_ptr = self._read_ptr(pthread_func_addr, ptr_size)
                        if func_ptr != thread_func:
                            continue

                        result.append((ref.value, thread_func))
                        break
        return result

    def _extract_thread_object_base(self, inst) -> Optional[int]:
        """
        Backend-neutral helper replacing the old IDA-only "offset" text check.

        Attempts operand 1 first (Binja/Ghidra place the address here), then
        operand 0 (IDA sometimes emits OFFSET there). Only addresses that fall
        inside read-only data sections are considered valid.
        """

        for candidate_idx in (1, 0):
            addr = operand_address(inst, candidate_idx)
            if addr is None:
                continue
            if address_in_sections(self.backend, addr):
                return addr
        return None

    def get_entry_point(self) -> Optional[int]:
        """Get Rust program entry point.

        Resolution order: explicit `rust_main` symbol, then `main`/`_main`
        analyzed via _find_rust_main, then the base entry, then a CRT
        fallback probe; returns the base entry as a last resort.
        """
        # Only perform Rust-specific analysis if this is actually a Rust binary
        if not self.lang_match():
            return super().get_entry_point()

        base_entry = super().get_entry_point()

        # Try explicit rust_main first
        rust_main = self.backend.get_address_for_name("rust_main")
        if rust_main:
            return rust_main.value

        # Try main/_main and analyze for rust_main pattern (after super() has
        # already triggered Ghidra's entry rename heuristics).
        for main_name in ("main", "_main"):
            main_ea = self.backend.get_address_for_name(main_name)
            if main_ea:
                candidate = self._find_rust_main(main_ea)
                if candidate:
                    return candidate

        if base_entry:
            candidate = self._find_rust_main(base_entry)
            if candidate:
                return candidate

        # Fallback: probe CRT init pattern directly if we still have nothing.
        fallback_main = LanguageBase.fallback_cmain_detection(self.backend)
        if fallback_main:
            candidate = self._find_rust_main(fallback_main)
            if candidate:
                return candidate

        # Last resort: hand back the base entry (keeps parity with other backends).
        return base_entry

    def _find_rust_main(self, main_addr: int) -> Optional[int]:
        """Find rust_main by analyzing main function."""
        # main_ea = main_addr
        if isinstance(main_addr, int):
            main_addr = Address(main_addr)
        fn = self.backend.get_function_at(main_addr)
        # TODO: In ghidra, this value is wrong cause the `main` isn't automatically set (i.e.
we need to manually set `main` from `__scrt_common_main_seh`) + + # start = fn.start + # # end = idc.prev_addr(idc.get_func_attr(main_ea, idc.FUNCATTR_END)) + + # is_64 = not is_32bit() # Use different variable name + block_ranges = sorted( + ((bb.start, bb.end) for bb in fn.basic_blocks), + key=lambda pair: pair[0].value, + ) + instruction_window: Deque[Address] = deque(maxlen=12) + + for start, end in block_ranges: + for ins in self.backend.instructions(start, end): + instruction_window.append(ins) + inst = None + try: + inst = self.backend.disassemble(ins) + except Exception: + pass + + inst_mnemonic = getattr(inst, "mnemonic", "") + inst_is_call = inst and inst_mnemonic.lower() == "call" + + for xr in self.backend.get_xrefs_from(ins): + if fn.contains(xr.target): + continue + + is_call = xr.type == XrefType.CALL or inst_is_call + if not is_call: + continue + + wrapper_fn = self.backend.get_function_at(xr.target) + if not wrapper_fn: + continue + if wrapper_fn.start == fn.start: + continue + + candidate_addr = self._extract_rust_closure_address(instruction_window, xr.target.value) + if candidate_addr is None: + candidate_addr = xr.target.value + + candidate_fn = self.backend.get_function_at(Address(candidate_addr)) + if not candidate_fn: + self._define_function_if_absent(candidate_addr) + candidate_fn = self.backend.get_function_at(Address(candidate_addr)) + if not candidate_fn: + if candidate_addr != xr.target.value: + continue + candidate_fn = wrapper_fn + + if candidate_fn.type in (FunctionType.IMPORT, FunctionType.LIBRARY, FunctionType.THUNK, FunctionType.EXPORT, FunctionType.EXTERN): + continue + + current_name = (candidate_fn.name or "").lower() + if current_name and not current_name.startswith(("fun_", "sub_", "replace_me_", "lab_")): + return candidate_fn.start.value + + try: + candidate_fn.name = "rust_main" + except Exception: + pass + return candidate_fn.start.value + return None + + def _extract_rust_closure_address(self, instruction_window: 
Deque[Address], fallback_target: Optional[int]) -> Optional[int]: + """Scan preceding instructions for a code pointer stored before the wrapper call.""" + + if not instruction_window: + return None + + window_without_call = list(instruction_window)[:-1] + + for ins_addr in reversed(window_without_call): + try: + inst = self.backend.disassemble(ins_addr) + except Exception: + continue + + for idx, _ in enumerate(getattr(inst, "operands", ())): + addr = operand_address(inst, idx) + if addr is None: + continue + if fallback_target is not None and addr == fallback_target: + continue + if address_in_code_sections(self.backend, addr): + return addr + + return None + + def _define_function_if_absent(self, addr: int) -> None: + """Ensure a function exists at `addr` when backends defer closure emission. + HACK: Ghidra is the only backend as of now that requires this (i.e. ) + """ + + if self.backend.name != "ghidra": + return + + program = self.backend._get_actual_program() # type: ignore[attr-defined] + addr_factory = program.getAddressFactory() + gh_addr = addr_factory.getAddress(f"{addr:x}") + from ghidra.program.flatapi import FlatProgramAPI + + flat_api = FlatProgramAPI(program) + existing = program.getFunctionManager().getFunctionAt(gh_addr) + if existing is None: + flat_api.createFunction(gh_addr, f"FUN_{addr:x}") + + def rename_functions(self, xrefer_obj: "XRefer") -> None: + """ + Rename functions based on their references. + + Args: + xrefer_obj: XreferenceLLM object containing global xrefs. 
+ """ + # de-prioritize refs that have a chance of overlapping occurrence even in non-lined methods + depriori_list = ["std", "core", "alloc", "gimli", "object"] + selected_ref = None + name_index = {} + for func_ea, func_ref in xrefer_obj.global_xrefs.items(): + depriori_refs = set() + priori_refs = set() + fn = self.backend.get_function_at(Address(func_ea)) + if not fn: + continue + + # # only rename default function labels + if not any(fn.name.startswith(x) for x in ("sub_", "FUN_")): # TODO: limit the logic depending on the backend. In practice, no one manually names a function like FUN_addr, so just ignore for now. + # TODO: expose property in backend/ to detect if auto-named + log(f"Renaming skipped: {fn.name}") + if fn.type in (FunctionType.IMPORT, FunctionType.LIBRARY, FunctionType.THUNK, FunctionType.EXPORT, FunctionType.EXTERN): + log(f"Renaming skipped (type): {fn.name}") + continue + + for xref_entity in func_ref[xrefer_obj.DIRECT_XREFS]["libs"]: + xref = xrefer_obj.entities[xref_entity][1] + if xref.split("::")[0] in depriori_list: + depriori_refs.add(xref) + else: + priori_refs.add(xref) + + selected_ref = self.find_common_denominator(list(priori_refs)) if len(priori_refs) else None + if not selected_ref: + continue + + idx = name_index.get(selected_ref, 0) + name_index[selected_ref] = idx + 1 + method_name = f"{selected_ref}_{name_index[selected_ref]}" + log(f"Renaming {fn.name} to {method_name}") + fn.name = method_name + + @staticmethod + def find_common_denominator(lib_refs: List[str]) -> Optional[str]: + """ + Find the common denominator among library references. + + Args: + lib_refs (List[str]): List of library references. + + Returns: + Optional[str]: Common denominator if found, None otherwise. 
+ """ + if not lib_refs: + return None + zipped_parts = zip(*[s.split("::") for s in lib_refs]) + common_parts = [parts[0] for parts in zipped_parts if all(p == parts[0] for p in parts)] + if not common_parts: + return None + return "::".join(common_parts) + + def _ptr_size(self) -> int: + max_end = max(sec.end.value for sec in self.backend.get_sections()) + return 8 if max_end > 0xFFFFFFFF else 4 + + def _read_ptr(self, addr: Address, size: int) -> Optional[int]: + raw = self.backend.read_bytes(addr, size) + if not raw or len(raw) != size: + return None + return int.from_bytes(raw, "little") diff --git a/scripts/test.py b/scripts/test.py new file mode 100644 index 0000000..f1bfbd9 --- /dev/null +++ b/scripts/test.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python3 +""" +Unified XRefer testing script +""" + +import argparse +import os +import sys +import traceback +from importlib.util import find_spec +from pathlib import Path +from typing import Any, Literal + +BACKEND = Literal["ida", "binaryninja", "ghidra"] + +sys.stdout.reconfigure(line_buffering=True) +sys.stderr.reconfigure(line_buffering=True) + +_prjdir = os.environ.get("PROJECT") +if _prjdir is None: + raise OSError("set PROJECT to the plugins dir") +PROJECT_DIR = Path(_prjdir) +assert PROJECT_DIR.exists(), f"PROJECT_DIR does not exist: {PROJECT_DIR}" +sys.path.insert(0, str(PROJECT_DIR.absolute())) + +pkg_path = Path(find_spec("xrefer").origin).resolve().parent + + +class BackendNotAvailableError(Exception): + """Raised when a requested backend is not available.""" + + +def detect_available_backends() -> list[str]: + """Detect which backends are available on the system.""" + backends = [] + for spec, name in [("idapro", "ida"), ("binaryninja", "binaryninja"), ("pyghidra", "ghidra")]: + if find_spec(spec) is not None: + backends.append(name) + return backends + + +def get_backend_extensions(backend: BACKEND) -> list[str]: + """Get file extensions associated with each backend.""" + # Note: Ghidra uses project 
directories, not simple file extensions + extensions = { + "ida": [".id0", ".id1", ".id2", ".nam", ".til", ".i64"], + "binaryninja": [".bndb"], + # Ghidra handled specially in cleanup_previous_analysis + } + return extensions.get(backend, []) + + +def cleanup_previous_analysis(file_path: Path, backend: str, force: bool = False) -> None: + """Clean up previous analysis artifacts for the specified backend.""" + if not force: + return + + if backend == "ghidra": + # Common pyghidra project layouts to remove: + # - _ghidra (directory) + # Also clear any stale .xrefer next to the binary path + import shutil + + candidates = [ + file_path.parent / f"{file_path.name}_ghidra", + file_path.parent / f"{file_path.stem}.rep", + file_path.with_suffix(".xrefer"), + ] + + for path in candidates: + try: + if path.exists(): + if path.is_dir(): + print(f"[+] Removing previous Ghidra project: {path}") + shutil.rmtree(path) + else: + print(f"[+] Removing previous artifact: {path}") + path.unlink() + except Exception as e: + print(f"[!] 
Warning: Failed to remove {path}: {e}") + else: + # Generic cleanup via known extensions + extensions = get_backend_extensions(backend) + for ext in extensions: + artifact_file = file_path.with_suffix(ext) + if artifact_file.exists(): + print(f"[+] Removing previous artifact: {artifact_file}") + artifact_file.unlink() + # Remove .xrefer output files + xrefer_file = file_path.with_suffix(".xrefer") + if xrefer_file.exists(): + print(f"[+] Removing previous XRefer output: {xrefer_file}") + xrefer_file.unlink() + + +def setup_ida_backend(): + """Set up IDA Pro backend requirements.""" + try: + try: + import idapro + except ImportError: + raise ImportError("Please ensure IDA Pro is installed and the idapro module is available.") + import ida_undo + from PyQt5.QtWidgets import QApplication + except ImportError as e: + raise BackendNotAvailableError(f"IDA Pro backend not available: {e}") + + def ensure_qapplication(): + """Ensures a QApplication instance exists.""" + if QApplication.instance(): + return QApplication.instance() + app = QApplication(sys.argv if sys.argv else ["idaclixrefer_headless"]) + return app + + app = ensure_qapplication() + return {"app": app, "ida_undo": ida_undo, "idapro": idapro} + + +def setup_binaryninja_backend(): + """Set up Binary Ninja backend requirements.""" + try: + import binaryninja as bn + except ImportError as e: + raise BackendNotAvailableError(f"Binary Ninja backend not available: {e}") + + import xrefer.backend as backend_module + from xrefer.backend.factory import BackendManager + + backend_module.Backend = None # Force re-initialization + + return {"bn": bn, "backend_module": backend_module, "BackendManager": BackendManager} + + +def setup_ghidra_backend(): + """Set up Ghidra backend requirements.""" + try: + import pyghidra + except ImportError as e: + raise BackendNotAvailableError(f"Ghidra backend not available: {e}") + + import xrefer.backend as backend_module + from xrefer.backend.factory import BackendManager + + 
backend_module.Backend = None # Force re-initialization + + return {"pyghidra": pyghidra, "backend_module": backend_module, "BackendManager": BackendManager} + + +def analysis_ida(filepath: Path, modules: dict[str, Any] | None = None): + """Run XRefer analysis with IDA Pro backend.""" + import idapro + + idapro.get_library_version() + + from xrefer.core.analyzer import XRefer + + try: + xrefer_obj = XRefer(auto_analyze=True) # This automatically calls load_analysis() + print(f"[+] XRefer analysis complete, results saved to {xrefer_obj.settings['paths']['analysis']}") + return xrefer_obj + except Exception as e: + print(f"[x] Analysis failed: {e}") + traceback.print_exc() + raise + + +def analysis_binaryninja(bv, modules: dict[str, Any] | None = None): + """Run XRefer analysis with Binary Ninja backend.""" + backend_module = modules["backend_module"] + BackendManager = modules["BackendManager"] + + backend_manager = BackendManager() + backend = backend_manager.create_backend("binaryninja", bv=bv) + backend_manager.set_active_backend(backend) + backend_module.Backend = backend + from xrefer.core.analyzer import XRefer + + xrefer_obj = XRefer(auto_analyze=True) # This automatically calls load_analysis() + print(f"[+] XRefer analysis complete, results saved to {xrefer_obj.settings['paths']['analysis']}") + return xrefer_obj + + +def analysis_ghidra(_filepath: Path, modules: dict[str, Any] | None = None): + """Run XRefer analysis with Ghidra backend.""" + backend_module = modules["backend_module"] + BackendManager = modules["BackendManager"] + + backend_manager = BackendManager() + backend = backend_manager.create_backend("ghidra") + backend_manager.set_active_backend(backend) + + backend_module.Backend = backend + from xrefer.core.analyzer import XRefer + + xrefer_obj = XRefer(auto_analyze=True) # This automatically calls load_analysis() + print(f"[+] XRefer analysis complete, results saved to {xrefer_obj.settings['paths']['analysis']}") + return xrefer_obj + + +def 
_analyze_ida(file_path: Path, auto_analysis: bool = True, save_changes: bool = False, force_analysis: bool = False) -> None: + """Analyze with IDA Pro backend.""" + modules = setup_ida_backend() + idapro = modules["idapro"] + + cleanup_previous_analysis(file_path, "ida", force_analysis) + + project_exists = any(file_path.with_suffix(ext).exists() for ext in [".id0", ".i64"]) + if project_exists and not force_analysis: + print(f"[+] Opening existing IDA project for {file_path}") + else: + print(f"[+] Creating new IDA project for {file_path}") + + try: + idapro.open_database(str(file_path), run_auto_analysis=auto_analysis) + analysis_ida(file_path, modules=modules) + finally: + idapro.close_database(save=save_changes) + + +def _analyze_binaryninja(file_path: Path, auto_analysis: bool = True, save_changes: bool = False, force_analysis: bool = False) -> None: + """Analyze with Binary Ninja backend.""" + import binaryninja + + modules = setup_binaryninja_backend() + bn: "binaryninja" = modules["bn"] + + cleanup_previous_analysis(file_path, "binaryninja", force_analysis) + # Determine BN database path alongside the input file + bndb_path = file_path.with_suffix(".bndb") + print(f"[+] Loading binary file: {file_path}") + bn.disable_default_log() + bv = bn.load(str(file_path), options={"analysis.mode": "full" if auto_analysis else "basic"}) + + if bv is None: + raise Exception(f"Failed to load binary: {file_path}") + + try: + if auto_analysis and not bndb_path.exists(): + print("[+] Waiting for auto-analysis...") + bv.update_analysis_and_wait() + + if save_changes and not bndb_path.exists(): + print(f"[+] Creating Binary Ninja database: {bndb_path}") + bv.create_database(str(bndb_path)) + + # Save snapshot before analysis + if save_changes: + bv.save_auto_snapshot() + + analysis_binaryninja(bv, modules=modules) + + if save_changes: + bv.save_auto_snapshot() + print(f"[+] Saved Binary Ninja database: {bndb_path}") + + finally: + bv.file.close() + + +def 
_analyze_ghidra(file_path: Path, auto_analysis: bool = True, save_changes: bool = False, force_analysis: bool = False) -> None: + """Analyze with Ghidra backend.""" + modules = setup_ghidra_backend() + pyghidra = modules["pyghidra"] + + cleanup_previous_analysis(file_path, "ghidra", force_analysis) + + pyghidra.start() + + with pyghidra.open_program(str(file_path), analyze=auto_analysis) as flat_api: + from xrefer.backend.factory import backend_manager + + ghidra_backend = backend_manager.create_backend("ghidra", program=flat_api.getCurrentProgram()) + backend_manager.set_active_backend(ghidra_backend) + analysis_ghidra(file_path, modules=modules) + if save_changes: + print("[+] Saving Ghidra project...") + try: + program = flat_api.getCurrentProgram() + # End any active transaction before saving + if program.hasActiveTrxs(): + program.endTrx() + flat_api.saveProgram(program) + except Exception as save_error: + print(f"[!] Save failed: {save_error}") + # Continue without failing the analysis + + +def cli(): + """Command line interface.""" + available_backends = detect_available_backends() + + parser = argparse.ArgumentParser(description="Unified XRefer testing script for multiple backends", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=__doc__) + + parser.add_argument("file", type=Path, help="Path to the file to analyze") + parser.add_argument("--backend", choices=available_backends, required=True, help=f"Analysis backend to use (available: {', '.join(available_backends)})") + parser.add_argument("--save", action="store_true", help="Save changes to database/project") + parser.add_argument("--auto-analysis", action="store_true", help="Run auto analysis (default: False)") + parser.add_argument("--force", action="store_true", help="Remove previous artifacts and re-analyze") + parser.add_argument("-L", "--logfile", help="Output log file path") + + args = parser.parse_args() + + if not available_backends: + print("[x] Error: No analysis backends 
available. Please install IDA Pro, Binary Ninja, or Ghidra.") + sys.exit(1) + + file_path = args.file.resolve() + if not file_path.exists(): + print(f"[x] Error: File not found: {file_path}") + sys.exit(1) + + # Store original streams for cleanup + original_stdout = sys.stdout + original_stderr = sys.stderr + log_file_handle = None + + # Redirect logs if specified + if args.logfile: + log_file = Path(args.logfile).resolve() + print(f"[+] Redirecting logs to: {log_file}") + log_file_handle = open(log_file, "w") + sys.stdout = log_file_handle + sys.stderr = log_file_handle + + try: + print(f"[+] Starting XRefer analysis with {args.backend} backend") + print(f"[+] File: {file_path}") + print(f"[+] Auto-analysis: {args.auto_analysis}") + print(f"[+] Save changes: {args.save}") + print(f"[+] Force re-analysis: {args.force}") + + try: + if args.backend == "ida": + _analyze_ida(file_path, args.auto_analysis, args.save, args.force) + elif args.backend == "binaryninja": + _analyze_binaryninja(file_path, args.auto_analysis, args.save, args.force) + elif args.backend == "ghidra": + _analyze_ghidra(file_path, args.auto_analysis, args.save, args.force) + else: + print(f"[x] Error: Unknown backend: {args.backend}") + sys.exit(1) + print("[+] Analysis completed successfully") + except KeyboardInterrupt: + print("\n[!] Analysis interrupted by user") + sys.exit(1) + except Exception as e: + print(f"\n[x] Analysis failed: {e}") + traceback.print_exc() + sys.exit(1) + finally: + # Restore original streams and close log file + sys.stdout = original_stdout + sys.stderr = original_stderr + if log_file_handle: + log_file_handle.close() + + +def main(): + cli() + + +if __name__ == "__main__": + main()