1212# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313# See the License for the specific language governing permissions and
1414# limitations under the License.
15-
1615import os
1716
1817from vuln_analysis .runtime_context import ctx_state
2625from langchain .docstore .document import Document
2726
2827from vuln_analysis .data_models .state import AgentMorpheusEngineState
29- from ..utils .chain_of_calls_retriever import ChainOfCallsRetriever
30- from vuln_analysis .utils .dep_tree import Ecosystem
3128from vuln_analysis .utils .document_embedding import DocumentEmbedding
29+ from ..data_models .input import SourceDocumentsInfo
30+ from ..utils .chain_of_calls_retriever_base import ChainOfCallsRetrieverBase
31+ from ..utils .chain_of_calls_retriever_factory import get_chain_of_calls_retriever
32+ from ..utils .dep_tree import Ecosystem
3233from ..utils .error_handling_decorator import catch_pipeline_errors_async , catch_tool_errors
3334from ..utils .function_name_extractor import FunctionNameExtractor
3435from ..utils .function_name_locator import FunctionNameLocator
3536
3637from vuln_analysis .logging .loggers_factory import LoggingFactory
38+ from ..utils .java_chain_of_calls_retriever import JavaChainOfCallsRetriever
39+
3740
3841PACKAGE_AND_FUNCTION_LOCATOR_TOOL_NAME = "package_and_function_locator"
3942
@@ -60,29 +63,33 @@ class PackageAndFunctionLocatorToolConfig(FunctionBaseConfig, name=("%s" % PACKA
6063 Package and function locator tool used to validate package names and find function names using fuzzy matching.
6164 """
6265
63-
def get_call_of_chains_retriever(documents_embedder, si, query: str):
    """Build a chain-of-calls retriever for the code source found in ``si``.

    Every entry of ``si`` whose ``type`` equals ``"code"`` is considered; when
    several entries match, the last one wins because each match overwrites
    the values collected from the previous one.

    Args:
        documents_embedder: Object providing ``get_repo_path(source_info)``
            and ``collect_documents(source_info)``.
        si: Iterable of source-info entries exposing a ``type`` attribute.
        query: The tool query, forwarded unchanged to
            ``get_chain_of_calls_retriever``.

    Returns:
        Whatever ``get_chain_of_calls_retriever`` returns for the detected
        ecosystem, the collected documents and the repository path.

    Raises:
        ValueError: If ``si`` holds no ``"code"`` entry, or the repository
            path resolved for it is ``None``.
        OSError: If ``ecosystem_data.txt`` cannot be opened in the repository
            path.
        KeyError: Presumably, if the file content does not name an
            ``Ecosystem`` member (assumes ``Ecosystem`` is an ``Enum`` --
            TODO confirm).
    """
    documents: list[Document]
    code_source_info: SourceDocumentsInfo
    git_repo = None
    for source_info in si:
        if source_info.type == "code":
            code_source_info = source_info
            git_repo = documents_embedder.get_repo_path(source_info)
            documents = documents_embedder.collect_documents(source_info)
    if git_repo is None:
        raise ValueError("No code source info found")

    ecosystem_file = os.path.join(git_repo, 'ecosystem_data.txt')
    with open(ecosystem_file, 'r', encoding='utf-8') as file:
        # Strip surrounding whitespace: a trailing newline in the file would
        # otherwise become part of the lookup key and make the name lookup
        # below fail.
        ecosystem_name = file.read().strip()
    ecosystem = Ecosystem[ecosystem_name]

    return get_chain_of_calls_retriever(ecosystem=ecosystem,
                                        documents=documents,
                                        manifest_path=git_repo,
                                        query=query,
                                        code_source_info=code_source_info)
7886
79-
def get_transitive_code_searcher(query: str):
    """Return the ``TransitiveCodeSearcher`` stored on the current engine state.

    The searcher held in ``state.transitive_code_searcher`` is reused unless
    it is missing or its ``chain_of_calls_retriever`` is a
    ``JavaChainOfCallsRetriever``. In those two cases a new retriever is
    built for ``query``, wrapped in a new searcher, and stored back on the
    state before being returned.

    Args:
        query: The tool query, forwarded to ``get_call_of_chains_retriever``
            when a new retriever has to be built.

    Returns:
        The searcher stored on the state after this call.
    """
    state: AgentMorpheusEngineState = ctx_state.get()
    cached = state.transitive_code_searcher

    # Fast path: a searcher exists and is not backed by the Java retriever.
    if cached is not None and not isinstance(cached.chain_of_calls_retriever,
                                             JavaChainOfCallsRetriever):
        return state.transitive_code_searcher

    source_infos = state.original_input.input.image.source_info
    embedder = DocumentEmbedding(embedding=None)
    retriever = get_call_of_chains_retriever(embedder, source_infos, query)
    state.transitive_code_searcher = TransitiveCodeSearcher(chain_of_calls_retriever=retriever)
    return state.transitive_code_searcher
@@ -108,16 +115,22 @@ async def transitive_search(config: TransitiveCodeSearchToolConfig,
108115 @catch_tool_errors (TRANSITIVE_CODE_SEARCH_TOOL_NAME )
109116 async def _arun (query : str ) -> tuple :
110117 transitive_code_searcher : TransitiveCodeSearcher
111- transitive_code_searcher = get_transitive_code_searcher ()
118+ transitive_code_searcher = get_transitive_code_searcher (query )
112119 result = transitive_code_searcher .search (query )
113120 return result
114121
115122 yield FunctionInfo .from_fn (
116123 _arun ,
117124 description = ("""
118- Checks if a function from a package is reachable from application code through the call chain.
119- Input format: 'package_name,function_name'.
120- Example: 'urllib,parse'.
125+ Checks if a function from a package is reachable from application code through the call chain.
126+ Make sure the input format is matching exactly one of the following formats:
127+
128+ Input format 1: 'package_name,function_name'.
129+ Example 1: 'urllib,parse'.
130+
131+ Input format 2(java): 'maven_gav,class_name.function_name'.
132+ Example 2(java): 'commons-beanutils:commons-beanutils:1.0.0,PropertyUtilsBean.setSimpleProperty'.
133+
121134 Returns: (is_reachable: bool, call_hierarchy_path: list).
122135""" ))
123136
@@ -131,9 +144,9 @@ async def functions_usage_search(config: CallingFunctionNameExtractorToolConfig,
131144 """
132145 @catch_tool_errors (FUNCTION_NAME_EXTRACTOR_TOOL_NAME )
133146 async def _arun (query : str ) -> list :
134- coc_retriever : ChainOfCallsRetriever
147+ coc_retriever : ChainOfCallsRetrieverBase
135148 transitive_code_searcher : TransitiveCodeSearcher
136- transitive_code_searcher = get_transitive_code_searcher ()
149+ transitive_code_searcher = get_transitive_code_searcher (query )
137150 coc_retriever = transitive_code_searcher .chain_of_calls_retriever
138151 function_name_extractor = FunctionNameExtractor (coc_retriever )
139152 result = function_name_extractor .fetch_list (query )
@@ -154,33 +167,38 @@ async def package_and_function_locator(config: PackageAndFunctionLocatorToolConf
154167 builder : Builder ): # pylint: disable=unused-argument
155168 """
156169 Function Locator tool used to validate package names and find function names using fuzzy matching.
157- Mandatory first step for code path analysis.
170+ Mandatory first step for code path analysis.
158171 """
159172
160173 @catch_tool_errors (PACKAGE_AND_FUNCTION_LOCATOR_TOOL_NAME )
161174 async def _arun (query : str ) -> dict :
162- coc_retriever : ChainOfCallsRetriever
175+ coc_retriever : ChainOfCallsRetrieverBase
163176 transitive_code_searcher : TransitiveCodeSearcher
164- transitive_code_searcher = get_transitive_code_searcher ()
177+ transitive_code_searcher = get_transitive_code_searcher (query )
165178 coc_retriever = transitive_code_searcher .chain_of_calls_retriever
166179 locator = FunctionNameLocator (coc_retriever )
167180 result = await locator .locate_functions (query )
168181 pkg_msg = "Package is valid."
169- if not locator .is_package_valid and not locator .is_std_package :
170- pkg_msg = "Package is not valid."
171-
172-
182+ if not locator .is_package_valid and not locator .is_std_package :
183+ pkg_msg = "Package is not valid."
184+
173185 return {
174186 "ecosystem" : coc_retriever .ecosystem .name ,
175- "package_msg" : pkg_msg ,
187+ "package_msg" : pkg_msg ,
176188 "result" : result
177189 }
178190
179191 yield FunctionInfo .from_fn (
180192 _arun ,
181193 description = ("""
182194 Mandatory first step for code path analysis. Validates package names, locates functions using fuzzy matching, and provides ecosystem type (GO/Python/Java/JavaScript/C/C++).
183- Input format: 'package_name,function_name' or 'package_name,class_name.method_name'
184- Example: 'libxml2,xmlParseDocument'
195+ Make sure the input format is matching exactly one of the following formats:
196+
197+ Input format 1: 'package_name,function_name' or 'package_name,class_name.method_name'.
198+ Example 1: 'libxml2,xmlParseDocument'.
199+
200+ Input format 2(java): 'maven_gav,class_name.method_name'.
201+ Example 2(java): 'commons-beanutils:commons-beanutils:1.0.0,PropertyUtilsBean.setSimpleProperty'.
202+
185203 Returns: {'ecosystem': str, 'package_msg': str, 'result': [function_names]}.
186204""" ))
0 commit comments