1+ import copy
2+ import logging
13import subprocess
24import tempfile
35
6+ from typing import Any
7+
8+ from memos .configs .mem_cube import GeneralMemCubeConfig
9+
10+
11+ logger = logging .getLogger (__name__ )
12+
413
514def download_repo (repo : str , base_url : str , dir : str | None = None ) -> str :
615 """Download a repository from a remote source.
@@ -22,3 +31,96 @@ def download_repo(repo: str, base_url: str, dir: str | None = None) -> str:
2231 subprocess .run (["git" , "clone" , repo_url , dir ], check = True )
2332
2433 return dir
34+
35+
36+ def merge_config_with_default (
37+ existing_config : GeneralMemCubeConfig , default_config : GeneralMemCubeConfig
38+ ) -> GeneralMemCubeConfig :
39+ """
40+ Merge existing cube config with default config, preserving critical fields.
41+
42+ This method updates general configuration fields (like API keys, model parameters)
43+ while preserving critical user-specific fields (like user_id, cube_id, graph_db settings).
44+
45+ Args:
46+ existing_config (GeneralMemCubeConfig): The existing cube configuration loaded from file
47+ default_config (GeneralMemCubeConfig): The default configuration to merge from
48+
49+ Returns:
50+ GeneralMemCubeConfig: Merged configuration
51+ """
52+
53+ def deep_merge_dicts (
54+ existing : dict [str , Any ], default : dict [str , Any ], preserve_keys : set [str ] | None = None
55+ ) -> dict [str , Any ]:
56+ """Recursively merge dictionaries, preserving specified keys from existing dict."""
57+ if preserve_keys is None :
58+ preserve_keys = set ()
59+
60+ result = copy .deepcopy (existing )
61+
62+ for key , default_value in default .items ():
63+ if key in preserve_keys :
64+ # Preserve existing value for critical keys
65+ continue
66+
67+ if key in result and isinstance (result [key ], dict ) and isinstance (default_value , dict ):
68+ # Recursively merge nested dictionaries
69+ result [key ] = deep_merge_dicts (result [key ], default_value , preserve_keys )
70+ elif key not in result or result [key ] is None :
71+ # Use default value if key doesn't exist or is None
72+ result [key ] = copy .deepcopy (default_value )
73+ # For non-dict values, keep existing value unless it's None
74+
75+ return result
76+
77+ # Convert configs to dictionaries
78+ existing_dict = existing_config .model_dump (mode = "json" )
79+ default_dict = default_config .model_dump (mode = "json" )
80+
81+ # Merge text_mem config
82+ if "text_mem" in existing_dict and "text_mem" in default_dict :
83+ existing_text_config = existing_dict ["text_mem" ].get ("config" , {})
84+ default_text_config = default_dict ["text_mem" ].get ("config" , {})
85+
86+ # Handle nested graph_db config specially
87+ if "graph_db" in existing_text_config and "graph_db" in default_text_config :
88+ existing_graph_config = existing_text_config ["graph_db" ].get ("config" , {})
89+ default_graph_config = default_text_config ["graph_db" ].get ("config" , {})
90+
91+ # Merge graph_db config, preserving critical keys
92+ merged_graph_config = deep_merge_dicts (
93+ existing_graph_config ,
94+ default_graph_config ,
95+ preserve_keys = {"uri" , "user" , "password" , "db_name" , "auto_create" },
96+ )
97+
98+ # Update the configs
99+ existing_text_config ["graph_db" ]["config" ] = merged_graph_config
100+ default_text_config ["graph_db" ]["config" ] = merged_graph_config
101+
102+ # Merge other text_mem config fields
103+ merged_text_config = deep_merge_dicts (existing_text_config , default_text_config )
104+ existing_dict ["text_mem" ]["config" ] = merged_text_config
105+
106+ # Merge act_mem config
107+ if "act_mem" in existing_dict and "act_mem" in default_dict :
108+ existing_act_config = existing_dict ["act_mem" ].get ("config" , {})
109+ default_act_config = default_dict ["act_mem" ].get ("config" , {})
110+ merged_act_config = deep_merge_dicts (existing_act_config , default_act_config )
111+ existing_dict ["act_mem" ]["config" ] = merged_act_config
112+
113+ # Merge para_mem config
114+ if "para_mem" in existing_dict and "para_mem" in default_dict :
115+ existing_para_config = existing_dict ["para_mem" ].get ("config" , {})
116+ default_para_config = default_dict ["para_mem" ].get ("config" , {})
117+ merged_para_config = deep_merge_dicts (existing_para_config , default_para_config )
118+ existing_dict ["para_mem" ]["config" ] = merged_para_config
119+
120+ # Create new config from merged dictionary
121+ merged_config = GeneralMemCubeConfig .model_validate (existing_dict )
122+ logger .info (
123+ f"Merged cube config for user { merged_config .user_id } , cube { merged_config .cube_id } "
124+ )
125+
126+ return merged_config
0 commit comments