From e641d7644ad74c5e36f9c66580dd21bae84f3c74 Mon Sep 17 00:00:00 2001 From: haeter525 Date: Sat, 23 Apr 2022 17:44:12 +0800 Subject: [PATCH] Update parser to support Multidex --- quark/core/rzapkinfo.py | 104 ++++++++++++++++------------------------ 1 file changed, 41 insertions(+), 63 deletions(-) diff --git a/quark/core/rzapkinfo.py b/quark/core/rzapkinfo.py index 446ab39a..97ddc987 100644 --- a/quark/core/rzapkinfo.py +++ b/quark/core/rzapkinfo.py @@ -20,7 +20,7 @@ from quark.core.struct.methodobject import MethodObject from quark.utils.tools import descriptor_to_androguard_format, remove_dup_list -RizinCache = namedtuple("rizin_cache", "address dexindex is_imported") +RizinCache = namedtuple("rizin_cache", "address is_imported") PRIMITIVE_TYPE_MAPPING = { "void": "V", @@ -54,9 +54,10 @@ def __init__( ): super().__init__(apk_filepath, "rizin") + self._binary_path = apk_filepath + if self.ret_type == "DEX": self._tmp_dir = None - self._dex_list = [apk_filepath] elif self.ret_type == "APK": self._tmp_dir = tempfile.mkdtemp() if tmp_dir is None else tmp_dir @@ -64,28 +65,21 @@ def __init__( with zipfile.ZipFile(self.apk_filepath) as apk: apk.extract("AndroidManifest.xml", path=self._tmp_dir) - self._manifest = os.path.join(self._tmp_dir, "AndroidManifest.xml") - - dex_files = [ - file - for file in apk.namelist() - if file.startswith("classes") and file.endswith(".dex") - ] - - for dex in dex_files: - apk.extract(dex, path=self._tmp_dir) - - self._dex_list = [os.path.join(self._tmp_dir, dex) for dex in dex_files] + self._manifest = os.path.join( + self._tmp_dir, "AndroidManifest.xml" + ) else: raise ValueError("Unsupported File type.") - self._number_of_dex = len(self._dex_list) - @functools.lru_cache - def _get_rz(self, index): - rz = rzpipe.open(self._dex_list[index]) - rz.cmd("aa") + def _get_rz(self): + if self.ret_type == "DEX": + rz = rzpipe.open(self._binary_path) + elif self.ret_type == "APK": + rz = rzpipe.open(f"apk://{self._binary_path}") + rz.cmd("aaa") + return rz def _convert_type_to_type_signature(self, raw_type: str): @@ -112,8 +106,8 @@ def _escape_str_in_rizin_manner(raw_str: str): return raw_str @functools.lru_cache - def _get_methods_classified(self, dexindex): - rz = self._get_rz(dexindex) + def _get_methods_classified(self): + rz = self._get_rz() method_json_list = rz.cmdj("isj") method_dict = defaultdict(list) @@ -182,7 +176,7 @@ def _get_methods_classified(self, dexindex): class_name="", name="clone", descriptor="()Ljava/lang/Object;", - cache=RizinCache(json_obj["vaddr"], dexindex, is_imported), + cache=RizinCache(json_obj["vaddr"], is_imported), ) method_dict[""].append(method) continue @@ -215,7 +209,7 @@ def _get_methods_classified(self, dexindex): class_name=class_name, name=method_name, descriptor=descriptor, - cache=RizinCache(json_obj["vaddr"], dexindex, is_imported), + cache=RizinCache(json_obj["vaddr"], is_imported), ) method_dict[class_name].append(method) @@ -256,9 +250,8 @@ def custom_methods(self) -> Set[MethodObject]: @functools.cached_property def all_methods(self) -> Set[MethodObject]: method_set = set() - for dex_index in range(self._number_of_dex): - for method_list in self._get_methods_classified(dex_index).values(): - method_set.update(method_list) + for method_list in self._get_methods_classified().values(): + method_set.update(method_list) return method_set @@ -273,21 +266,15 @@ def method_filter(method): not descriptor or descriptor == method.descriptor ) - dex_list = range(self._number_of_dex) - - for dex_index in dex_list: - method_dict = self._get_methods_classified(dex_index) - filtered_methods = filter(method_filter, method_dict[class_name]) - try: - return next(filtered_methods) - except StopIteration: - continue + method_dict = self._get_methods_classified() + filtered_methods = filter(method_filter, method_dict[class_name]) + return next(filtered_methods, None) @functools.lru_cache def upperfunc(self, method_object: MethodObject) -> Set[MethodObject]: cache = method_object.cache - r2 = self._get_rz(cache.dexindex) + r2 = self._get_rz() xrefs = r2.cmdj(f"axtj @ {cache.address}") @@ -317,7 +304,7 @@ def upperfunc(self, method_object: MethodObject) -> Set[MethodObject]: def lowerfunc(self, method_object: MethodObject) -> Set[MethodObject]: cache = method_object.cache - r2 = self._get_rz(cache.dexindex) + r2 = self._get_rz() xrefs = r2.cmdj(f"axffj @ {cache.address}") @@ -360,7 +347,7 @@ def get_method_bytecode( if not cache.is_imported: - rz = self._get_rz(cache.dexindex) + rz = self._get_rz() instruct_flow = rz.cmdj(f"pdfj @ {cache.address}")["ops"] @@ -369,16 +356,11 @@ def get_method_bytecode( yield self._parse_smali(ins["disasm"]) def get_strings(self) -> Set[str]: - strings = set() - for dex_index in range(self._number_of_dex): - rz = self._get_rz(dex_index) - string_detail_list = rz.cmdj("izzj") - strings.update( - [string_detail["string"] for string_detail in string_detail_list] - ) + rz = self._get_rz() + string_detail_list = rz.cmdj("izzj") - return strings + return {string_obj["string"] for string_obj in string_detail_list} def get_wrapper_smali( self, @@ -413,7 +395,7 @@ def convert_bytecode_to_list(bytecode): if cache.is_imported: return {} - rz = self._get_rz(cache.dexindex) + rz = self._get_rz() instruction_flow = rz.cmdj(f"pdfj @ {cache.address}")["ops"] @@ -451,16 +433,14 @@ def convert_bytecode_to_list(bytecode): def superclass_relationships(self) -> Dict[str, Set[str]]: hierarchy_dict = defaultdict(set) - for dex_index in range(self._number_of_dex): + rz = self._get_rz() - rz = self._get_rz(dex_index) + class_info_list = rz.cmdj("icj") + for class_info in class_info_list: + class_name = class_info["classname"] + super_class = class_info["super"] - class_info_list = rz.cmdj("icj") - for class_info in class_info_list: - class_name = class_info["classname"] - super_class = class_info["super"] - - hierarchy_dict[class_name].add(super_class) + hierarchy_dict[class_name].add(super_class) return hierarchy_dict @@ -468,16 +448,14 @@ def superclass_relationships(self) -> Dict[str, Set[str]]: def subclass_relationships(self) -> Dict[str, Set[str]]: hierarchy_dict = defaultdict(set) - for dex_index in range(self._number_of_dex): - - rz = self._get_rz(dex_index) + rz = self._get_rz() - class_info_list = rz.cmdj("icj") - for class_info in class_info_list: - class_name = class_info["classname"] - super_class = class_info["super"] + class_info_list = rz.cmdj("icj") + for class_info in class_info_list: + class_name = class_info["classname"] + super_class = class_info["super"] - hierarchy_dict[super_class].add(class_name) + hierarchy_dict[super_class].add(class_name) return hierarchy_dict