Skip to content
Open
9 changes: 6 additions & 3 deletions kustomize/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,16 @@ EOF

6. Create the `oauth-secret.env` file containing the `client-secret` and `openshift-domain` values required by the [ExploitIQ Client](./base/exploit_iq_client.yaml) configuration.

Replace `some-long-secret-used-by-the-oauth-client` with a more secure, unique secret
If openshift resource of kind `OAuthClient` named `exploit-iq-client` exists, just get the secret from there:
```shell
export OAUTH_CLIENT_SECRET=$(oc get oauthclient exploit-iq-client -o jsonpath='{..secret}')
```
Otherwise, Replace `some-long-secret-used-by-the-oauth-client` with a more secure, unique secret of your own:

```shell
export OAUTH_CLIENT_SECRET="some-long-secret-used-by-the-oauth-client"
```


```shell
cat > base/oauth-secrets.env << EOF
client-secret=$OAUTH_CLIENT_SECRET
Expand Down Expand Up @@ -184,7 +187,7 @@ redirectURIs:
- "http://$(oc get route exploit-iq-client -o jsonpath='{.spec.host}')/app/index.html"
EOF
```
Otherwise, just add your route to the existing `OAuthClient` CR object:
Otherwise ( if creating `OAuthClient` instance got error because it's already exists in the cluster) , just add your route to the existing `OAuthClient` CR object:
```shell
export HTTPS_ROUTE=https://$(oc get route exploit-iq-client -o jsonpath='{.spec.host}')/app/index.html
export HTTP_ROUTE=http://$(oc get route exploit-iq-client -o jsonpath='{.spec.host}')/app/index.html
Expand Down
13 changes: 12 additions & 1 deletion src/vuln_analysis/functions/cve_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from vuln_analysis.data_models.state import AgentMorpheusEngineState
from vuln_analysis.tools.tool_names import ToolNames
from vuln_analysis.utils.error_handling_decorator import ToolRaisedException
from vuln_analysis.utils.functions_parsers.lang_functions_parsers_factory import get_language_function_parser
from vuln_analysis.utils.prompting import get_agent_prompt
from vuln_analysis.logging.loggers_factory import LoggingFactory, trace_id

Expand Down Expand Up @@ -74,7 +75,17 @@ async def _create_agent(config: CVEAgentExecutorToolConfig, builder: Builder,

tools = builder.get_tools(tool_names=config.tool_names, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
llm = await builder.get_llm(llm_name=config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)


ecosystem = state.original_input.input.image.ecosystem
transitive_search_tool_supports_ecosystem: bool = True
# I
if ecosystem:
try:
get_language_function_parser(ecosystem, None)
except NotImplementedError:
transitive_search_tool_supports_ecosystem = False
logger.warning(f"Transitive code search tool doesn't support programming language {ecosystem},"
f" disabling tool...")
# Filter tools that are not available based on state
tools = [
tool for tool in tools
Expand Down
24 changes: 15 additions & 9 deletions src/vuln_analysis/functions/cve_generate_vdbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from vuln_analysis.data_models.common import AnalysisType
from vuln_analysis.logging.loggers_factory import LoggingFactory, trace_id
from vuln_analysis.tools.tool_names import ToolNames
from vuln_analysis.utils.dep_tree import Ecosystem

logger = LoggingFactory.get_agent_logger(__name__)

Expand Down Expand Up @@ -71,7 +72,7 @@ async def generate_vdb(config: CVEGenerateVDBsToolConfig, builder: Builder):
from vuln_analysis.utils.source_rpm_downloader import RPMDependencyManager
from vuln_analysis.data_models.input import ManualSBOMInfoInput
from vuln_analysis.utils.standard_library_cache import StandardLibraryCache

agent_config = builder.get_function_config(config.agent_name)
assert isinstance(agent_config, CVEAgentExecutorToolConfig)

Expand All @@ -86,8 +87,8 @@ async def generate_vdb(config: CVEGenerateVDBsToolConfig, builder: Builder):
config.ignore_code_index = True

embedding = await builder.get_embedder(embedder_name=config.embedder_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
# Configure RPM singleton with cache directory from config

# Configure RPM singleton with cache directory from config
rpm_manager = RPMDependencyManager.get_instance()
rpm_manager.set_rpm_cache_dir(config.base_rpm_dir)
cache_std = StandardLibraryCache.get_instance()
Expand All @@ -99,7 +100,7 @@ async def generate_vdb(config: CVEGenerateVDBsToolConfig, builder: Builder):
pickle_cache_directory=config.base_pickle_dir)

def _create_code_index(source_infos: list[SourceDocumentsInfo], embedder: DocumentEmbedding,
output_path: Path) -> bool :
output_path: Path) -> bool:
logger.info("Collecting documents from git repos. Source Infos: %s",
json.dumps([x.model_dump(mode="json") for x in source_infos]))

Expand Down Expand Up @@ -136,7 +137,8 @@ def _create_code_index(source_infos: list[SourceDocumentsInfo], embedder: Docume
logger.info("Completed code indexing in %.2f seconds for '%s'", time.time() - indexing_start_time, output_path)
return True

def _build_code_index(source_infos: list[SourceDocumentsInfo]) -> Path | None:
def _build_code_index(source_infos: list[SourceDocumentsInfo], ecosystem: Ecosystem = None,
manifest_path: str = None) -> Path | None:
code_index_path: Path | None = None

# Filter to only code sources
Expand All @@ -148,7 +150,9 @@ def _build_code_index(source_infos: list[SourceDocumentsInfo]) -> Path | None:
embedder = DocumentEmbedding(embedding=None,
vdb_directory=config.base_vdb_dir,
git_directory=config.base_git_dir,
pickle_cache_directory=config.base_pickle_dir)
pickle_cache_directory=config.base_pickle_dir,
manifest_path=manifest_path,
ecosystem=ecosystem)

# Determine code index path for either loading from cache or creating new index
# Need to add support for configurable base path
Expand Down Expand Up @@ -189,6 +193,8 @@ async def _arun(message: AgentMorpheusInput) -> AgentMorpheusEngineInput:
base_image = message.image.name
source_infos = message.image.source_info
sbom_infos = message.image.sbom_info
ecosystem = message.image.ecosystem
manifest_path = message.image.manifest_path

try:
trace_id.set(message.scan.id)
Expand All @@ -209,13 +215,13 @@ async def _arun(message: AgentMorpheusInput) -> AgentMorpheusEngineInput:

# Build code index if not ignored
if not config.ignore_code_index:
logger.info("analysis type: %s", message.image.analysis_type)
logger.info("analysis type: %s", message.image.analysis_type)
if message.image.analysis_type == AnalysisType.IMAGE and isinstance(sbom_infos, ManualSBOMInfoInput):
RPMDependencyManager.get_instance().sbom = sbom_infos.packages
image = f"{message.image.name}:{message.image.tag}"
RPMDependencyManager.get_instance().container_image = image
code_index_path = _build_code_index(source_infos)

code_index_path = _build_code_index(source_infos, ecosystem, manifest_path)

if code_index_path is None:
logger.warning(("Failed to generate code index for image '%s'. "
Expand Down
15 changes: 10 additions & 5 deletions src/vuln_analysis/tools/tests/test_transitive_code_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ async def get_transitive_code_runner_function():
async for function in transitive_code_search.gen:
return function.single_fn


python_dependency_tree_mock_output = (
'deptree==0.0.12 # deptree\n'
' importlib-metadata==8.7.0 # importlib-metadata\n'
Expand All @@ -147,6 +148,7 @@ async def get_transitive_code_runner_function():
' mock-package==1.1.1'
)


def mock_file_open(*args, **kwargs):
file_path = args[0] if args else kwargs.get('file', '')
mock_file = MagicMock()
Expand Down Expand Up @@ -179,14 +181,16 @@ def mock_file_open(*args, **kwargs):
"search_query": "werkzeug,formparser.MultiPartParser.parse",
"expected_path_found": False,
"expected_list_length": 0,
"mock_documents": [python_script_example, python_init_function_example, python_full_document_example, python_parse_function_example]
"mock_documents": [python_script_example, python_init_function_example, python_full_document_example,
python_parse_function_example]
},
{
"name": "python_3",
"search_query": "mock_package,mock_function_in_use",
"expected_path_found": True,
"expected_list_length": 3,
"mock_documents": [python_script_example, python_init_function_example, python_full_document_example, python_parse_function_example, python_mock_function_in_use, python_mock_file]
"mock_documents": [python_script_example, python_init_function_example, python_full_document_example,
python_parse_function_example, python_mock_function_in_use, python_mock_file]
}
])
@patch('vuln_analysis.utils.dep_tree.run_command', return_value=python_dependency_tree_mock_output)
Expand All @@ -204,7 +208,7 @@ async def test_transitive_search_python_parameterized(mock_open, mock_run_comman
)

with patch('vuln_analysis.utils.document_embedding.retrieve_from_cache',
return_value=(test_case["mock_documents"], True)):
return_value=(test_case["mock_documents"], True)):
result = await transitive_code_search_runner_coroutine(test_case["search_query"])

(path_found, list_path) = result
Expand All @@ -215,6 +219,7 @@ async def test_transitive_search_python_parameterized(mock_open, mock_run_comman
assert path_found == test_case["expected_path_found"]
assert len(list_path) == test_case["expected_list_length"]


@pytest.mark.asyncio
async def test_python_transitive_search():
"""Test that runs with a real repository"""
Expand All @@ -234,7 +239,7 @@ async def test_python_transitive_search():
print(f"DEBUG: path_found = {path_found}")
print(f"DEBUG: list_path = {list_path}")
print(f"DEBUG: len(list_path) = {len(list_path)}")
assert path_found == True
assert path_found is True
assert len(list_path) == 2

@pytest.mark.asyncio
Expand Down Expand Up @@ -395,4 +400,4 @@ async def test_transitive_search_java_2():
# assert path_found is True
# assert len(list_path) > 1
# document = list_path[-1]
# assert 'src/main/java/io/cryostat' in document.metadata['source']
# assert 'src/main/java/io/cryostat' in document.metadata['source']
38 changes: 24 additions & 14 deletions src/vuln_analysis/tools/transitive_code_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from vuln_analysis.runtime_context import ctx_state
Expand All @@ -25,19 +26,18 @@
from langchain.docstore.document import Document

from vuln_analysis.data_models.state import AgentMorpheusEngineState
from vuln_analysis.utils.document_embedding import DocumentEmbedding
from ..data_models.input import SourceDocumentsInfo
from ..utils.chain_of_calls_retriever_base import ChainOfCallsRetrieverBase
from ..utils.chain_of_calls_retriever import ChainOfCallsRetriever
from vuln_analysis.utils.dep_tree import Ecosystem
from vuln_analysis.utils.document_embedding import DocumentEmbedding
from ..utils.chain_of_calls_retriever_factory import get_chain_of_calls_retriever
from ..utils.dep_tree import Ecosystem
from ..utils.error_handling_decorator import catch_pipeline_errors_async, catch_tool_errors
from ..utils.function_name_extractor import FunctionNameExtractor
from ..utils.function_name_locator import FunctionNameLocator

from vuln_analysis.logging.loggers_factory import LoggingFactory
from ..utils.java_chain_of_calls_retriever import JavaChainOfCallsRetriever


PACKAGE_AND_FUNCTION_LOCATOR_TOOL_NAME = "package_and_function_locator"

FUNCTION_NAME_EXTRACTOR_TOOL_NAME = "calling_function_name_extractor"
Expand All @@ -46,7 +46,7 @@
logger = LoggingFactory.get_agent_logger(__name__)


class TransitiveCodeSearchToolConfig(FunctionBaseConfig, name=("%s" % TRANSITIVE_CODE_SEARCH_TOOL_NAME)):
class TransitiveCodeSearchToolConfig(FunctionBaseConfig, name=TRANSITIVE_CODE_SEARCH_TOOL_NAME):
"""
Transitive code search tool used to search source code.
"""
Expand All @@ -58,12 +58,12 @@ class CallingFunctionNameExtractorToolConfig(FunctionBaseConfig, name=FUNCTION_N
"""


class PackageAndFunctionLocatorToolConfig(FunctionBaseConfig, name=("%s" % PACKAGE_AND_FUNCTION_LOCATOR_TOOL_NAME)):
class PackageAndFunctionLocatorToolConfig(FunctionBaseConfig, name=PACKAGE_AND_FUNCTION_LOCATOR_TOOL_NAME):
"""
Package and function locator tool used to validate package names and find function names using fuzzy matching.
"""

def get_call_of_chains_retriever(documents_embedder, si, query: str):
def get_call_of_chains_retriever(documents_embedder, si, query: str, ecosystem, manifest_path : str):
documents: list[Document]
git_repo = None
code_source_info: SourceDocumentsInfo
Expand All @@ -72,24 +72,34 @@ def get_call_of_chains_retriever(documents_embedder, si, query: str):
code_source_info = source_info
git_repo = documents_embedder.get_repo_path(source_info)
documents = documents_embedder.collect_documents(source_info)

if git_repo is None:
raise ValueError("No code source info found")
with open(os.path.join(git_repo, 'ecosystem_data.txt'), 'r', encoding='utf-8') as file:
ecosystem = file.read()
ecosystem = Ecosystem[ecosystem]
if not ecosystem:
with open(os.path.join(git_repo, 'ecosystem_data.txt'), 'r', encoding='utf-8') as file:
ecosystem = file.read()
ecosystem = Ecosystem[ecosystem.upper()]
# path_to_manifest = git_repo_path.joinpath(manifest_path)

if manifest_path:
git_repo_with_manifest = git_repo.joinpath(manifest_path)
else:
git_repo_with_manifest = git_repo
coc_retriever = get_chain_of_calls_retriever(ecosystem=ecosystem,
documents=documents,
manifest_path=git_repo,
manifest_path=git_repo_with_manifest,
query=query,
code_source_info=code_source_info)
return coc_retriever


def get_transitive_code_searcher(query: str):
state: AgentMorpheusEngineState = ctx_state.get()
if state.transitive_code_searcher is None or isinstance(state.transitive_code_searcher.chain_of_calls_retriever, JavaChainOfCallsRetriever):
if (state.transitive_code_searcher is None or
isinstance(state.transitive_code_searcher.chain_of_calls_retriever, JavaChainOfCallsRetriever)):
si = state.original_input.input.image.source_info
documents_embedder = DocumentEmbedding(embedding=None)
coc_retriever = get_call_of_chains_retriever(documents_embedder, si, query)
coc_retriever = get_call_of_chains_retriever(documents_embedder, si, query, state.original_input.input.image.ecosystem , state.original_input.input.image.manifest_path)
transitive_code_searcher = TransitiveCodeSearcher(chain_of_calls_retriever=coc_retriever)
state.transitive_code_searcher = transitive_code_searcher
return state.transitive_code_searcher
Expand Down Expand Up @@ -167,7 +177,7 @@ async def package_and_function_locator(config: PackageAndFunctionLocatorToolConf
builder: Builder): # pylint: disable=unused-argument
"""
Function Locator tool used to validate package names and find function names using fuzzy matching.
Mandatory first step for code path analysis.
Mandatory first step for code path analysis.
"""

@catch_tool_errors(PACKAGE_AND_FUNCTION_LOCATOR_TOOL_NAME)
Expand Down
16 changes: 14 additions & 2 deletions src/vuln_analysis/utils/document_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

import copy
import json
import logging
import os
import pickle
import sys
import time
import typing
Expand All @@ -37,19 +39,24 @@
from vuln_analysis.data_models.input import SourceDocumentsInfo
from vuln_analysis.utils.data_utils import retrieve_from_cache, save_to_cache, DEFAULT_PICKLE_CACHE_DIRECTORY, DEFAULT_GIT_DIRECTORY, \
VDB_DIRECTORY, PathLike
from vuln_analysis.utils.dep_tree import Ecosystem
from vuln_analysis.utils.go_segmenters_with_methods import GoSegmenterWithMethods
from vuln_analysis.utils.python_segmenters_with_classes_methods import PythonSegmenterWithClassesMethods
from vuln_analysis.utils.java_segmenters_with_methods import JavaSegmenterWithMethods
from vuln_analysis.utils.js_extended_parser import ExtendedJavaScriptSegmenter
from vuln_analysis.utils.source_code_git_loader import SourceCodeGitLoader
from vuln_analysis.utils.git_utils import sanitize_git_url_for_path
from vuln_analysis.utils.transitive_code_searcher_tool import TransitiveCodeSearcher
from vuln_analysis.logging.loggers_factory import LoggingFactory

from vuln_analysis.utils.c_segmenter_custom import CSegmenterExtended

if typing.TYPE_CHECKING:
from langchain_core.embeddings import Embeddings # pragma: no cover




logger = LoggingFactory.get_agent_logger(__name__)

class MultiLanguageRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
Expand Down Expand Up @@ -227,7 +234,8 @@ class DocumentEmbedding:

def __init__(self, *, embedding: "Embeddings", vdb_directory: PathLike = VDB_DIRECTORY,
git_directory: PathLike = DEFAULT_GIT_DIRECTORY, chunk_size: int = 800, chunk_overlap: int = 160,
pickle_cache_directory: PathLike = DEFAULT_PICKLE_CACHE_DIRECTORY):
pickle_cache_directory: PathLike = DEFAULT_PICKLE_CACHE_DIRECTORY, ecosystem: Ecosystem = None,
manifest_path: str = None):
"""
Create a new DocumentEmbedding instance.

Expand All @@ -253,6 +261,8 @@ def __init__(self, *, embedding: "Embeddings", vdb_directory: PathLike = VDB_DIR
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._pickle_cache_directory = Path(pickle_cache_directory)
self._ecosystem = ecosystem
self._manifest_path = manifest_path

@property
def embedding(self):
Expand Down Expand Up @@ -366,7 +376,9 @@ def collect_documents(self, source_info: SourceDocumentsInfo) -> list[Document]:
clone_url=source_info.git_repo,
ref=source_info.ref,
include=source_info.include,
exclude=source_info.exclude)
exclude=source_info.exclude,
manifest_path=self._manifest_path,
ecosystem=self._ecosystem)
blob_parser = ExtendedLanguageParser()

loader = GenericLoader(blob_loader=blob_loader, blob_parser=blob_parser)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ def get_language_function_parser(ecosystem: Ecosystem, tree: DependencyTree | No
:param ecosystem: the desired programming language parser.
:param tree: the dependency tree for the ecosystem, can be None.
:return:
The right language functions parser associated to the ecosystem, if not exists, return ABC parent that
doesn't do anything
The right language functions parser associated to the ecosystem, if not exists, throw an NotImplementedError
"""
if ecosystem == Ecosystem.GO:
return GoLanguageFunctionsParser()
Expand All @@ -27,4 +26,4 @@ def get_language_function_parser(ecosystem: Ecosystem, tree: DependencyTree | No
elif ecosystem == Ecosystem.JAVA:
return JavaLanguageFunctionsParser()
else:
return LanguageFunctionsParser()
raise NotImplementedError(f"Language functions parser for {ecosystem} not implemented.")
Loading