Skip to content

Commit 6be931b

Browse files
Fix filesystem tool (#816)
1 parent c15810d commit 6be931b

File tree

3 files changed

+45
-2
lines changed

3 files changed

+45
-2
lines changed

ms_agent/config/config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,15 @@ def parse_args() -> Dict[str, Any]:
124124
_dict_config[key[2:]] = value
125125
return _dict_config
126126

127+
@staticmethod
128+
def safe_get_config(config: DictConfig, keys: str) -> Any:
129+
node = config
130+
for key in keys.split('.'):
131+
if not hasattr(node, key):
132+
return None
133+
node = getattr(node, key)
134+
return node
135+
127136
@staticmethod
128137
def _update_config(config: Union[DictConfig, ListConfig],
129138
extra: Dict[str, str] = None):

ms_agent/tools/filesystem_tool.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Optional
99

1010
import json
11+
from ms_agent.config import Config
1112
from ms_agent.llm import LLM
1213
from ms_agent.llm.utils import Message, Tool
1314
from ms_agent.tools.base import ToolBase
@@ -56,8 +57,10 @@ def __init__(self, config, **kwargs):
5657
index_dir = getattr(config, 'index_cache_dir', DEFAULT_INDEX_DIR)
5758
self.index_dir = os.path.join(self.output_dir, index_dir)
5859
self.system = self.SYSTEM_FOR_ABBREVIATIONS
59-
if hasattr(self.config.tools.file_system, 'system_for_abbreviations'):
60-
self.system = self.config.tools.file_system.system_for_abbreviations
60+
system = Config.safe_get_config(
61+
self.config, 'tools.file_system.system_for_abbreviations')
62+
if system:
63+
self.system = system
6164

6265
async def connect(self):
6366
logger.warning_once(

tests/config/test_config.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright (c) Alibaba, Inc. and its affiliates.
2+
import unittest
3+
4+
from ms_agent.config import Config
5+
from omegaconf import DictConfig
6+
7+
from modelscope.utils.test_utils import test_level
8+
9+
10+
class TestConfig(unittest.TestCase):
11+
12+
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
13+
def test_safe_get_config(self):
14+
config = DictConfig(
15+
{'tools': {
16+
'file_system': {
17+
'system_for_abbreviations': 'test'
18+
}
19+
}})
20+
self.assertEqual(
21+
'test',
22+
Config.safe_get_config(
23+
config, 'tools.file_system.system_for_abbreviations'))
24+
delattr(config.tools, 'file_system')
25+
self.assertTrue(
26+
Config.safe_get_config(
27+
config, 'tools.file_system.system_for_abbreviations') is None)
28+
29+
30+
if __name__ == '__main__':
31+
unittest.main()

0 commit comments

Comments
 (0)