diff --git a/kustomize/README.md b/kustomize/README.md index 3ca9e4b7..e98827c8 100644 --- a/kustomize/README.md +++ b/kustomize/README.md @@ -131,13 +131,16 @@ EOF 6. Create the `oauth-secret.env` file containing the `client-secret` and `openshift-domain` values required by the [ExploitIQ Client](./base/exploit_iq_client.yaml) configuration. -Replace `some-long-secret-used-by-the-oauth-client` with a more secure, unique secret +If openshift resource of kind `OAuthClient` named `exploit-iq-client` exists, just get the secret from there: +```shell +export OAUTH_CLIENT_SECRET=$(oc get oauthclient exploit-iq-client -o jsonpath='{..secret}') +``` +Otherwise, Replace `some-long-secret-used-by-the-oauth-client` with a more secure, unique secret of your own: ```shell export OAUTH_CLIENT_SECRET="some-long-secret-used-by-the-oauth-client" ``` - ```shell cat > base/oauth-secrets.env << EOF client-secret=$OAUTH_CLIENT_SECRET @@ -184,7 +187,7 @@ redirectURIs: - "http://$(oc get route exploit-iq-client -o jsonpath='{.spec.host}')/app/index.html" EOF ``` -Otherwise, just add your route to the existing `OAuthClient` CR object: +Otherwise ( if creating `OAuthClient` instance got error because it's already exists in the cluster) , just add your route to the existing `OAuthClient` CR object: ```shell export HTTPS_ROUTE=https://$(oc get route exploit-iq-client -o jsonpath='{.spec.host}')/app/index.html export HTTP_ROUTE=http://$(oc get route exploit-iq-client -o jsonpath='{.spec.host}')/app/index.html diff --git a/src/vuln_analysis/functions/cve_agent.py b/src/vuln_analysis/functions/cve_agent.py index 93eed149..37a65cb1 100644 --- a/src/vuln_analysis/functions/cve_agent.py +++ b/src/vuln_analysis/functions/cve_agent.py @@ -33,6 +33,7 @@ from vuln_analysis.data_models.state import AgentMorpheusEngineState from vuln_analysis.tools.tool_names import ToolNames from vuln_analysis.utils.error_handling_decorator import ToolRaisedException +from vuln_analysis.utils.functions_parsers.lang_functions_parsers_factory import get_language_function_parser from vuln_analysis.utils.prompting import get_agent_prompt from vuln_analysis.logging.loggers_factory import LoggingFactory, trace_id @@ -74,7 +75,17 @@ async def _create_agent(config: CVEAgentExecutorToolConfig, builder: Builder, tools = builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN) llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN) - + + ecosystem = state.original_input.input.image.ecosystem + transitive_search_tool_supports_ecosystem: bool = True + # I + if ecosystem: + try: + get_language_function_parser(ecosystem, None) + except NotImplementedError: + transitive_search_tool_supports_ecosystem = False + logger.warning(f"Transitive code search tool doesn't support programming language {ecosystem}," + f" disabling tool...") # Filter tools that are not available based on state tools = [ tool for tool in tools diff --git a/src/vuln_analysis/functions/cve_generate_vdbs.py b/src/vuln_analysis/functions/cve_generate_vdbs.py index 53e4ca6f..d342dc09 100644 --- a/src/vuln_analysis/functions/cve_generate_vdbs.py +++ b/src/vuln_analysis/functions/cve_generate_vdbs.py @@ -29,6 +29,7 @@ from vuln_analysis.data_models.common import AnalysisType from vuln_analysis.logging.loggers_factory import LoggingFactory, trace_id from vuln_analysis.tools.tool_names import ToolNames +from vuln_analysis.utils.dep_tree import Ecosystem logger = LoggingFactory.get_agent_logger(__name__) @@ -71,7 +72,7 @@ async def generate_vdb(config: CVEGenerateVDBsToolConfig, builder: Builder): from vuln_analysis.utils.source_rpm_downloader import RPMDependencyManager from vuln_analysis.data_models.input import ManualSBOMInfoInput from vuln_analysis.utils.standard_library_cache import StandardLibraryCache - + agent_config = builder.get_function_config(config.agent_name) assert isinstance(agent_config, CVEAgentExecutorToolConfig) @@ -86,8 +87,8 @@ async def generate_vdb(config: CVEGenerateVDBsToolConfig, builder: Builder): config.ignore_code_index = True embedding = await builder.get_embedder(embedder_name=config.embedder_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN) - - # Configure RPM singleton with cache directory from config + + # Configure RPM singleton with cache directory from config rpm_manager = RPMDependencyManager.get_instance() rpm_manager.set_rpm_cache_dir(config.base_rpm_dir) cache_std = StandardLibraryCache.get_instance() @@ -99,7 +100,7 @@ async def generate_vdb(config: CVEGenerateVDBsToolConfig, builder: Builder): pickle_cache_directory=config.base_pickle_dir) def _create_code_index(source_infos: list[SourceDocumentsInfo], embedder: DocumentEmbedding, - output_path: Path) -> bool : + output_path: Path) -> bool: logger.info("Collecting documents from git repos. Source Infos: %s", json.dumps([x.model_dump(mode="json") for x in source_infos])) @@ -136,7 +137,8 @@ def _create_code_index(source_infos: list[SourceDocumentsInfo], embedder: Docume logger.info("Completed code indexing in %.2f seconds for '%s'", time.time() - indexing_start_time, output_path) return True - def _build_code_index(source_infos: list[SourceDocumentsInfo]) -> Path | None: + def _build_code_index(source_infos: list[SourceDocumentsInfo], ecosystem: Ecosystem = None, + manifest_path: str = None) -> Path | None: code_index_path: Path | None = None # Filter to only code sources @@ -148,7 +150,9 @@ def _build_code_index(source_infos: list[SourceDocumentsInfo]) -> Path | None: embedder = DocumentEmbedding(embedding=None, vdb_directory=config.base_vdb_dir, git_directory=config.base_git_dir, - pickle_cache_directory=config.base_pickle_dir) + pickle_cache_directory=config.base_pickle_dir, + manifest_path=manifest_path, + ecosystem=ecosystem) # Determine code index path for either loading from cache or creating new index # Need to add support for configurable base path @@ -189,6 +193,8 @@ async def _arun(message: AgentMorpheusInput) -> AgentMorpheusEngineInput: base_image = message.image.name source_infos = message.image.source_info sbom_infos = message.image.sbom_info + ecosystem = message.image.ecosystem + manifest_path = message.image.manifest_path try: trace_id.set(message.scan.id) @@ -209,13 +215,13 @@ async def _arun(message: AgentMorpheusInput) -> AgentMorpheusEngineInput: # Build code index if not ignored if not config.ignore_code_index: - logger.info("analysis type: %s", message.image.analysis_type) + logger.info("analysis type: %s", message.image.analysis_type) if message.image.analysis_type == AnalysisType.IMAGE and isinstance(sbom_infos, ManualSBOMInfoInput): RPMDependencyManager.get_instance().sbom = sbom_infos.packages image = f"{message.image.name}:{message.image.tag}" RPMDependencyManager.get_instance().container_image = image - - code_index_path = _build_code_index(source_infos) + + code_index_path = _build_code_index(source_infos, ecosystem, manifest_path) if code_index_path is None: logger.warning(("Failed to generate code index for image '%s'. " diff --git a/src/vuln_analysis/tools/tests/test_transitive_code_search.py b/src/vuln_analysis/tools/tests/test_transitive_code_search.py index 5e1a5611..5c4be707 100644 --- a/src/vuln_analysis/tools/tests/test_transitive_code_search.py +++ b/src/vuln_analysis/tools/tests/test_transitive_code_search.py @@ -130,6 +130,7 @@ async def get_transitive_code_runner_function(): async for function in transitive_code_search.gen: return function.single_fn + python_dependency_tree_mock_output = ( 'deptree==0.0.12 # deptree\n' ' importlib-metadata==8.7.0 # importlib-metadata\n' @@ -147,6 +148,7 @@ async def get_transitive_code_runner_function(): ' mock-package==1.1.1' ) + def mock_file_open(*args, **kwargs): file_path = args[0] if args else kwargs.get('file', '') mock_file = MagicMock() @@ -179,14 +181,16 @@ def mock_file_open(*args, **kwargs): "search_query": "werkzeug,formparser.MultiPartParser.parse", "expected_path_found": False, "expected_list_length": 0, - "mock_documents": [python_script_example, python_init_function_example, python_full_document_example, python_parse_function_example] + "mock_documents": [python_script_example, python_init_function_example, python_full_document_example, + python_parse_function_example] }, { "name": "python_3", "search_query": "mock_package,mock_function_in_use", "expected_path_found": True, "expected_list_length": 3, - "mock_documents": [python_script_example, python_init_function_example, python_full_document_example, python_parse_function_example, python_mock_function_in_use, python_mock_file] + "mock_documents": [python_script_example, python_init_function_example, python_full_document_example, + python_parse_function_example, python_mock_function_in_use, python_mock_file] } ]) @patch('vuln_analysis.utils.dep_tree.run_command', return_value=python_dependency_tree_mock_output) @@ -204,7 +208,7 @@ async def test_transitive_search_python_parameterized(mock_open, mock_run_comman ) with patch('vuln_analysis.utils.document_embedding.retrieve_from_cache', - return_value=(test_case["mock_documents"], True)): + return_value=(test_case["mock_documents"], True)): result = await transitive_code_search_runner_coroutine(test_case["search_query"]) (path_found, list_path) = result @@ -215,6 +219,7 @@ async def test_transitive_search_python_parameterized(mock_open, mock_run_comman assert path_found == test_case["expected_path_found"] assert len(list_path) == test_case["expected_list_length"] + @pytest.mark.asyncio async def test_python_transitive_search(): """Test that runs with a real repository""" @@ -234,7 +239,7 @@ async def test_python_transitive_search(): print(f"DEBUG: path_found = {path_found}") print(f"DEBUG: list_path = {list_path}") print(f"DEBUG: len(list_path) = {len(list_path)}") - assert path_found == True + assert path_found is True assert len(list_path) == 2 @pytest.mark.asyncio @@ -395,4 +400,4 @@ async def test_transitive_search_java_2(): # assert path_found is True # assert len(list_path) > 1 # document = list_path[-1] -# assert 'src/main/java/io/cryostat' in document.metadata['source'] \ No newline at end of file +# assert 'src/main/java/io/cryostat' in document.metadata['source'] diff --git a/src/vuln_analysis/tools/transitive_code_search.py b/src/vuln_analysis/tools/transitive_code_search.py index 64bae400..6c312113 100644 --- a/src/vuln_analysis/tools/transitive_code_search.py +++ b/src/vuln_analysis/tools/transitive_code_search.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import os from vuln_analysis.runtime_context import ctx_state @@ -25,11 +26,11 @@ from langchain.docstore.document import Document from vuln_analysis.data_models.state import AgentMorpheusEngineState -from vuln_analysis.utils.document_embedding import DocumentEmbedding from ..data_models.input import SourceDocumentsInfo -from ..utils.chain_of_calls_retriever_base import ChainOfCallsRetrieverBase +from ..utils.chain_of_calls_retriever import ChainOfCallsRetriever +from vuln_analysis.utils.dep_tree import Ecosystem +from vuln_analysis.utils.document_embedding import DocumentEmbedding from ..utils.chain_of_calls_retriever_factory import get_chain_of_calls_retriever -from ..utils.dep_tree import Ecosystem from ..utils.error_handling_decorator import catch_pipeline_errors_async, catch_tool_errors from ..utils.function_name_extractor import FunctionNameExtractor from ..utils.function_name_locator import FunctionNameLocator @@ -37,7 +38,6 @@ from vuln_analysis.logging.loggers_factory import LoggingFactory from ..utils.java_chain_of_calls_retriever import JavaChainOfCallsRetriever - PACKAGE_AND_FUNCTION_LOCATOR_TOOL_NAME = "package_and_function_locator" FUNCTION_NAME_EXTRACTOR_TOOL_NAME = "calling_function_name_extractor" @@ -46,7 +46,7 @@ logger = LoggingFactory.get_agent_logger(__name__) -class TransitiveCodeSearchToolConfig(FunctionBaseConfig, name=("%s" % TRANSITIVE_CODE_SEARCH_TOOL_NAME)): +class TransitiveCodeSearchToolConfig(FunctionBaseConfig, name=TRANSITIVE_CODE_SEARCH_TOOL_NAME): """ Transitive code search tool used to search source code. """ @@ -58,12 +58,12 @@ class CallingFunctionNameExtractorToolConfig(FunctionBaseConfig, name=FUNCTION_N """ -class PackageAndFunctionLocatorToolConfig(FunctionBaseConfig, name=("%s" % PACKAGE_AND_FUNCTION_LOCATOR_TOOL_NAME)): +class PackageAndFunctionLocatorToolConfig(FunctionBaseConfig, name=PACKAGE_AND_FUNCTION_LOCATOR_TOOL_NAME): """ Package and function locator tool used to validate package names and find function names using fuzzy matching. """ -def get_call_of_chains_retriever(documents_embedder, si, query: str): +def get_call_of_chains_retriever(documents_embedder, si, query: str, ecosystem, manifest_path : str): documents: list[Document] git_repo = None code_source_info: SourceDocumentsInfo @@ -72,24 +72,34 @@ def get_call_of_chains_retriever(documents_embedder, si, query: str): code_source_info = source_info git_repo = documents_embedder.get_repo_path(source_info) documents = documents_embedder.collect_documents(source_info) + if git_repo is None: raise ValueError("No code source info found") - with open(os.path.join(git_repo, 'ecosystem_data.txt'), 'r', encoding='utf-8') as file: - ecosystem = file.read() - ecosystem = Ecosystem[ecosystem] + if not ecosystem: + with open(os.path.join(git_repo, 'ecosystem_data.txt'), 'r', encoding='utf-8') as file: + ecosystem = file.read() + ecosystem = Ecosystem[ecosystem.upper()] + # path_to_manifest = git_repo_path.joinpath(manifest_path) + + if manifest_path: + git_repo_with_manifest = git_repo.joinpath(manifest_path) + else: + git_repo_with_manifest = git_repo coc_retriever = get_chain_of_calls_retriever(ecosystem=ecosystem, documents=documents, - manifest_path=git_repo, + manifest_path=git_repo_with_manifest, query=query, code_source_info=code_source_info) return coc_retriever + def get_transitive_code_searcher(query: str): state: AgentMorpheusEngineState = ctx_state.get() - if state.transitive_code_searcher is None or isinstance(state.transitive_code_searcher.chain_of_calls_retriever, JavaChainOfCallsRetriever): + if (state.transitive_code_searcher is None or + isinstance(state.transitive_code_searcher.chain_of_calls_retriever, JavaChainOfCallsRetriever)): si = state.original_input.input.image.source_info documents_embedder = DocumentEmbedding(embedding=None) - coc_retriever = get_call_of_chains_retriever(documents_embedder, si, query) + coc_retriever = get_call_of_chains_retriever(documents_embedder, si, query, state.original_input.input.image.ecosystem , state.original_input.input.image.manifest_path) transitive_code_searcher = TransitiveCodeSearcher(chain_of_calls_retriever=coc_retriever) state.transitive_code_searcher = transitive_code_searcher return state.transitive_code_searcher @@ -167,7 +177,7 @@ async def package_and_function_locator(config: PackageAndFunctionLocatorToolConf builder: Builder): # pylint: disable=unused-argument """ Function Locator tool used to validate package names and find function names using fuzzy matching. - Mandatory first step for code path analysis. + Mandatory first step for code path analysis. """ @catch_tool_errors(PACKAGE_AND_FUNCTION_LOCATOR_TOOL_NAME) diff --git a/src/vuln_analysis/utils/document_embedding.py b/src/vuln_analysis/utils/document_embedding.py index 9ccd55d9..bf392a74 100644 --- a/src/vuln_analysis/utils/document_embedding.py +++ b/src/vuln_analysis/utils/document_embedding.py @@ -15,7 +15,9 @@ import copy import json +import logging import os +import pickle import sys import time import typing @@ -37,12 +39,14 @@ from vuln_analysis.data_models.input import SourceDocumentsInfo from vuln_analysis.utils.data_utils import retrieve_from_cache, save_to_cache, DEFAULT_PICKLE_CACHE_DIRECTORY, DEFAULT_GIT_DIRECTORY, \ VDB_DIRECTORY, PathLike +from vuln_analysis.utils.dep_tree import Ecosystem from vuln_analysis.utils.go_segmenters_with_methods import GoSegmenterWithMethods from vuln_analysis.utils.python_segmenters_with_classes_methods import PythonSegmenterWithClassesMethods from vuln_analysis.utils.java_segmenters_with_methods import JavaSegmenterWithMethods from vuln_analysis.utils.js_extended_parser import ExtendedJavaScriptSegmenter from vuln_analysis.utils.source_code_git_loader import SourceCodeGitLoader from vuln_analysis.utils.git_utils import sanitize_git_url_for_path +from vuln_analysis.utils.transitive_code_searcher_tool import TransitiveCodeSearcher from vuln_analysis.logging.loggers_factory import LoggingFactory from vuln_analysis.utils.c_segmenter_custom import CSegmenterExtended @@ -50,6 +54,9 @@ if typing.TYPE_CHECKING: from langchain_core.embeddings import Embeddings # pragma: no cover + + + logger = LoggingFactory.get_agent_logger(__name__) class MultiLanguageRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): @@ -227,7 +234,8 @@ class DocumentEmbedding: def __init__(self, *, embedding: "Embeddings", vdb_directory: PathLike = VDB_DIRECTORY, git_directory: PathLike = DEFAULT_GIT_DIRECTORY, chunk_size: int = 800, chunk_overlap: int = 160, - pickle_cache_directory: PathLike = DEFAULT_PICKLE_CACHE_DIRECTORY): + pickle_cache_directory: PathLike = DEFAULT_PICKLE_CACHE_DIRECTORY, ecosystem: Ecosystem = None, + manifest_path: str = None): """ Create a new DocumentEmbedding instance. @@ -253,6 +261,8 @@ def __init__(self, *, embedding: "Embeddings", vdb_directory: PathLike = VDB_DIR self._chunk_size = chunk_size self._chunk_overlap = chunk_overlap self._pickle_cache_directory = Path(pickle_cache_directory) + self._ecosystem = ecosystem + self._manifest_path = manifest_path @property def embedding(self): @@ -366,7 +376,9 @@ def collect_documents(self, source_info: SourceDocumentsInfo) -> list[Document]: clone_url=source_info.git_repo, ref=source_info.ref, include=source_info.include, - exclude=source_info.exclude) + exclude=source_info.exclude, + manifest_path=self._manifest_path, + ecosystem=self._ecosystem) blob_parser = ExtendedLanguageParser() loader = GenericLoader(blob_loader=blob_loader, blob_parser=blob_parser) diff --git a/src/vuln_analysis/utils/functions_parsers/lang_functions_parsers_factory.py b/src/vuln_analysis/utils/functions_parsers/lang_functions_parsers_factory.py index 4f1178c6..4e3f3386 100644 --- a/src/vuln_analysis/utils/functions_parsers/lang_functions_parsers_factory.py +++ b/src/vuln_analysis/utils/functions_parsers/lang_functions_parsers_factory.py @@ -12,8 +12,7 @@ def get_language_function_parser(ecosystem: Ecosystem, tree: DependencyTree | No :param ecosystem: the desired programming language parser. :param tree: the dependency tree for the ecosystem, can be None. :return: - The right language functions parser associated to the ecosystem, if not exists, return ABC parent that - doesn't do anything + The right language functions parser associated to the ecosystem, if not exists, throw an NotImplementedError """ if ecosystem == Ecosystem.GO: return GoLanguageFunctionsParser() @@ -27,4 +26,4 @@ def get_language_function_parser(ecosystem: Ecosystem, tree: DependencyTree | No elif ecosystem == Ecosystem.JAVA: return JavaLanguageFunctionsParser() else: - return LanguageFunctionsParser() + raise NotImplementedError(f"Language functions parser for {ecosystem} not implemented.") diff --git a/src/vuln_analysis/utils/source_code_git_loader.py b/src/vuln_analysis/utils/source_code_git_loader.py index ae2b089d..0fb677b2 100644 --- a/src/vuln_analysis/utils/source_code_git_loader.py +++ b/src/vuln_analysis/utils/source_code_git_loader.py @@ -25,6 +25,7 @@ from langchain_core.document_loaders.blob_loaders import Blob from tqdm import tqdm +from vuln_analysis.utils.dep_tree import Ecosystem from vuln_analysis.utils.transitive_code_searcher_tool import TransitiveCodeSearcher from vuln_analysis.logging.loggers_factory import LoggingFactory @@ -47,14 +48,11 @@ class SourceCodeGitLoader(BlobLoader): files from. By default, it loads from the `main` branch. """ - def __init__( - self, - repo_path: PathLike, - clone_url: str | None = None, - ref: typing.Optional[str] = "main", - include: typing.Optional[typing.Iterable[str]] = None, - exclude: typing.Optional[typing.Iterable[str]] = None, - ): + def __init__(self, repo_path: PathLike, clone_url: str | None = None, ref: typing.Optional[str] = "main", + include: typing.Optional[typing.Iterable[str]] = None, + exclude: typing.Optional[typing.Iterable[str]] = None, + manifest_path: str = None, + ecosystem: Ecosystem = None): """ Initialize the Git loader. @@ -70,6 +68,8 @@ def __init__( A list of file patterns to include. Uses the glob syntax, by default None exclude : typing.Optional[typing.Iterable[str]], optional A list of file patterns to exclude. Uses the glob syntax, by default None + :param manifest_path: + :param ecosystem: """ self.repo_path = Path(repo_path) @@ -80,6 +80,8 @@ def __init__( self.exclude = exclude self._repo: Repo | None = None + self._manifest_path = manifest_path + self._ecosystem = ecosystem def load_repo(self): """ @@ -132,13 +134,14 @@ def load_repo(self): repo.git.fetch("origin", self.ref, "--depth=1", "--force") tag_refspec = f"refs/tags/{self.ref}:refs/tags/{self.ref}" try: - repo.git.fetch("origin", tag_refspec, "--depth=1" , "--force") + repo.git.fetch("origin", tag_refspec, "--depth=1", "--force") except GitCommandError: pass repo.git.checkout(self.ref, "--force") logger.info("Loaded Git repository at path: '%s' @ '%s'", self.repo_path, self.ref) - TransitiveCodeSearcher.download_dependencies(self.repo_path) + TransitiveCodeSearcher.download_dependencies(self.repo_path, manifest_path= self._manifest_path, + the_ecosystem=self._ecosystem) self._repo = repo return repo diff --git a/src/vuln_analysis/utils/transitive_code_searcher_tool.py b/src/vuln_analysis/utils/transitive_code_searcher_tool.py index 7c5053f0..21aeb34f 100644 --- a/src/vuln_analysis/utils/transitive_code_searcher_tool.py +++ b/src/vuln_analysis/utils/transitive_code_searcher_tool.py @@ -37,6 +37,13 @@ logger = LoggingFactory.get_agent_logger(f"morpheus.{__name__}") +def determine_manifest_name_by_ecosystem(the_ecosystem): + for manifest_name, ecosystem in MANIFESTS_TO_ECOSYSTEMS.items(): + if ecosystem == the_ecosystem: + return manifest_name + + return None + class TransitiveCodeSearcher: """ Transitive code Searcher for code using a Chain Of Calls Retriever object @@ -52,19 +59,37 @@ def __init__(self, chain_of_calls_retriever: ChainOfCallsRetrieverBase): self.chain_of_calls_retriever = chain_of_calls_retriever @staticmethod - def download_dependencies(git_repo_path: Path) -> bool: + def download_dependencies(git_repo_path: Path, manifest_path: str = None, the_ecosystem: Ecosystem = None) -> bool: """ Download all dependencies according to manifest file in the Git repository Parameters ---------- git_repo_path : Path Git repository path to fetch the application manifests from + manifest_path: str + path to manifest file within the Git repository Returns whether dependencies were downloaded or not. + :param the_ecosystem: + :param git_repo_path: + :param manifest_path: """ ecosystem: Ecosystem # Check the root dir of the repo for existence of manifests, the precedence of which manifest file t o check is # according to the order from top to bottom - if os.path.isfile(git_repo_path / GOLANG_MANIFEST): + path_to_manifest: Path + # If manifest path is supplied in input, override the default root repo dir as dir of manifest file with this value. + if manifest_path: + path_to_manifest = git_repo_path.joinpath(manifest_path) + logger.info(f"manifest_path field supplied in request payload, overriding default value of " + f"root directory of repository." + f" relative manifest_path value => {manifest_path}, path_to_manifest=>{path_to_manifest}") + else: + path_to_manifest = git_repo_path + # If ecosystem is supplied in input, then override default of first found ecosystem manifest in the repo. + if the_ecosystem and os.path.isfile(path_to_manifest / determine_manifest_name_by_ecosystem(the_ecosystem)): + ecosystem = the_ecosystem + logger.info(f"Ecosystem field supplied in request payload, ecosystem value => {ecosystem}") + elif os.path.isfile(path_to_manifest / GOLANG_MANIFEST): ecosystem = MANIFESTS_TO_ECOSYSTEMS[GOLANG_MANIFEST] elif os.path.isfile(git_repo_path / PYTHON_MANIFEST): ecosystem = MANIFESTS_TO_ECOSYSTEMS[PYTHON_MANIFEST] @@ -72,6 +97,7 @@ def download_dependencies(git_repo_path: Path) -> bool: ecosystem = MANIFESTS_TO_ECOSYSTEMS[JS_MANIFEST] elif os.path.isfile(git_repo_path / JAVA_MANIFEST): ecosystem = MANIFESTS_TO_ECOSYSTEMS[JAVA_MANIFEST] + # Search for C/C++ manifest else: # 1. Direct checks candidates = [ @@ -83,7 +109,7 @@ def download_dependencies(git_repo_path: Path) -> bool: C_CPLUSPLUS_MANIFEST_4) ] found = [p for p in candidates if p.is_file()] - if found: + if found: ecosystem = MANIFESTS_TO_ECOSYSTEMS[C_CPLUSPLUS_MANIFEST_1] else: logger.info(f"Didn't find manifest to install, skipping.. {git_repo_path}") @@ -93,8 +119,9 @@ def download_dependencies(git_repo_path: Path) -> bool: logger.info(f"Started installing packages for {ecosystem}") tree_builder = get_dependency_tree_builder(ecosystem) tree_builder.install_dependencies(git_repo_path) - with open(os.path.join(git_repo_path, 'ecosystem_data.txt'), 'w') as file: - file.write(ecosystem.name) + if not the_ecosystem: + with open(os.path.join(git_repo_path, 'ecosystem_data.txt'), 'w') as file: + file.write(ecosystem.name) logger.info(f"Finished installing packages for {ecosystem}") return True except NotImplementedError as err: @@ -140,4 +167,4 @@ def search(self, query: str) -> tuple[bool, list[Document]]: f"-------------------------------------------\n{function_method.page_content}\n" logger.debug(content_of_files_in_path) logger.debug(content_of_files_in_path, extra=MULTI_LINE_MESSAGE_TRUE) - return found_path, call_hierarchy_list \ No newline at end of file + return found_path, call_hierarchy_list