Skip to content

Commit 36818b1

Browse files
committed
feat: raise an error for uncopyable objects and remove the memo parameter
1 parent cd07418 commit 36818b1

File tree

2 files changed

+55
-46
lines changed

2 files changed

+55
-46
lines changed

scrapegraphai/utils/copy.py

Lines changed: 22 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -1,78 +1,73 @@
11
import copy
22
from typing import Any, Dict, Optional
3+
from pydantic.v1 import BaseModel
34

45

5-
def safe_deepcopy(obj: Any, memo: Optional[Dict[int, Any]] = None) -> Any:
6+
def safe_deepcopy(obj: Any) -> Any:
67
"""
78
Attempts to create a deep copy of the object using `copy.deepcopy`
89
whenever possible. If that fails, it falls back to custom deep copy
910
logic or returns the original object.
1011
1112
Args:
1213
obj (Any): The object to be copied, which can be of any type.
13-
memo (Optional[Dict[int, Any]]): A dictionary used to track objects
14-
that have already been copied to handle circular references.
15-
If None, a new dictionary is created.
1614
1715
Returns:
1816
Any: A deep copy of the object if possible; otherwise, a shallow
1917
copy if deep copying fails; if neither is possible, the original
2018
object is returned.
2119
"""
2220

23-
if memo is None:
24-
memo = {}
25-
26-
if id(obj) in memo:
27-
return memo[id(obj)]
28-
2921
try:
22+
3023
# Try to use copy.deepcopy first
31-
return copy.deepcopy(obj, memo)
32-
except (TypeError, AttributeError):
24+
if isinstance(obj,BaseModel):
25+
# Handle BaseModel separately because __fields_set__ needs compatibility handling
26+
copied_obj = obj.copy(deep=True)
27+
else:
28+
copied_obj = copy.deepcopy(obj)
29+
30+
return copied_obj
31+
except (TypeError, AttributeError) as e:
3332
# If deepcopy fails, handle specific types manually
3433

3534
# Handle dictionaries
3635
if isinstance(obj, dict):
3736
new_obj = {}
38-
memo[id(obj)] = new_obj
37+
3938
for k, v in obj.items():
40-
new_obj[k] = safe_deepcopy(v, memo)
39+
new_obj[k] = safe_deepcopy(v)
4140
return new_obj
4241

4342
# Handle lists
4443
elif isinstance(obj, list):
4544
new_obj = []
46-
memo[id(obj)] = new_obj
45+
4746
for v in obj:
48-
new_obj.append(safe_deepcopy(v, memo))
47+
new_obj.append(safe_deepcopy(v))
4948
return new_obj
5049

5150
# Handle tuples (immutable, but might contain mutable objects)
5251
elif isinstance(obj, tuple):
53-
new_obj = tuple(safe_deepcopy(v, memo) for v in obj)
54-
memo[id(obj)] = new_obj
52+
new_obj = tuple(safe_deepcopy(v) for v in obj)
53+
5554
return new_obj
5655

5756
# Handle frozensets (immutable, but might contain mutable objects)
5857
elif isinstance(obj, frozenset):
59-
new_obj = frozenset(safe_deepcopy(v, memo) for v in obj)
60-
memo[id(obj)] = new_obj
58+
new_obj = frozenset(safe_deepcopy(v) for v in obj)
6159
return new_obj
6260

6361
# Handle objects with attributes
6462
elif hasattr(obj, "__dict__"):
6563
new_obj = obj.__new__(obj.__class__)
6664
for attr in obj.__dict__:
67-
setattr(new_obj, attr, safe_deepcopy(getattr(obj, attr), memo))
68-
memo[id(obj)] = new_obj
65+
setattr(new_obj, attr, safe_deepcopy(getattr(obj, attr)))
66+
6967
return new_obj
70-
68+
7169
# Attempt shallow copy as a fallback
7270
try:
7371
return copy.copy(obj)
7472
except (TypeError, AttributeError):
75-
pass
76-
77-
# If all else fails, return the original object
78-
return obj
73+
raise TypeError(f"Failed to create a deep copy obj") from e

tests/utils/copy_utils_test.py

Lines changed: 33 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -3,16 +3,20 @@
33

44
# The safe_deepcopy function under test is imported on the next line
55
from scrapegraphai.utils.copy import safe_deepcopy
6+
from pydantic.v1 import BaseModel
7+
from pydantic import BaseModel as BaseModelV2
68

9+
class PydantObject(BaseModel):
10+
value: int
11+
12+
class PydantObjectV2(BaseModelV2):
13+
value: int
714

815
class NormalObject:
916
def __init__(self, value):
1017
self.value = value
1118
self.nested = [1, 2, 3]
1219

13-
def __deepcopy__(self, memo):
14-
raise TypeError("Forcing fallback")
15-
1620

1721
class NonDeepcopyable:
1822
def __init__(self, value):
@@ -109,11 +113,6 @@ def test_circular_reference():
109113
assert copy_obj[0] is copy_obj
110114

111115

112-
def test_memoization():
113-
original = {"a": 1, "b": 2}
114-
memo = {}
115-
copy_obj = safe_deepcopy(original, memo)
116-
assert copy_obj is memo[id(original)]
117116

118117

119118
def test_deepcopy_object_without_dict():
@@ -154,17 +153,32 @@ def test_deepcopy_object_without_dict():
154153
assert copy_obj_item.value == original_item.value
155154
assert copy_obj_item is original_item
156155

157-
def test_memo():
158-
obj = NormalObject(10)
159-
original = {"origin": obj}
160-
memo = {id(original):obj}
161-
copy_obj = safe_deepcopy(original, memo)
162-
assert copy_obj is memo[id(original)]
163-
164156
def test_unhandled_type():
165-
original = {"origin": NonCopyableObject(10)}
157+
with pytest.raises(TypeError):
158+
original = {"origin": NonCopyableObject(10)}
159+
copy_obj = safe_deepcopy(original)
160+
161+
def test_client():
162+
llm_instance_config = {
163+
"model": "moonshot-v1-8k",
164+
"base_url": "https://api.moonshot.cn/v1",
165+
"api_key": "xxx",
166+
}
167+
168+
from langchain_community.chat_models.moonshot import MoonshotChat
169+
170+
llm_model_instance = MoonshotChat(**llm_instance_config)
171+
172+
copy_obj = safe_deepcopy(llm_model_instance)
173+
assert copy_obj
174+
175+
176+
def test_circular_reference_in_dict():
177+
original = {}
178+
original['self'] = original # Create a circular reference
166179
copy_obj = safe_deepcopy(original)
167-
assert copy_obj["origin"].value == original["origin"].value
180+
181+
# Check that the copy is a different object
168182
assert copy_obj is not original
169-
assert copy_obj["origin"] is original["origin"]
170-
assert hasattr(copy_obj, "__dict__") is False # Ensure __dict__ is not present
183+
# Check that the circular reference is maintained in the copy
184+
assert copy_obj['self'] is copy_obj

0 commit comments

Comments
 (0)