Skip to content

Commit 4dd6e4b

Browse files
committed
feat: refine python module check (#2952)
<!-- **Thanks for contributing to Apache Fory™.** **If this is your first time opening a PR on fory, you can refer to [CONTRIBUTING.md](https://github.com/apache/fory/blob/main/CONTRIBUTING.md).** Contribution Checklist - The **Apache Fory™** community has requirements on the naming of pr titles. You can also find instructions in [CONTRIBUTING.md](https://github.com/apache/fory/blob/main/CONTRIBUTING.md). - Apache Fory™ has a strong focus on performance. If the PR you submit will have an impact on performance, please benchmark it first and provide the benchmark result here. --> ## Why? <!-- Describe the purpose of this PR. --> ## What does this PR do? <!-- Describe the details of this PR. --> ## Related issues <!-- Is there any related issue? If this PR closes them you say say fix/closes: - #xxxx0 - #xxxx1 - Fixes #xxxx2 --> ## Does this PR introduce any user-facing change? <!-- If any user-facing interface changes, please [open an issue](https://github.com/apache/fory/issues/new/choose) describing the need to do so and update the document if necessary. Delete section if not applicable. --> - [ ] Does this PR introduce any public API change? - [ ] Does this PR introduce any binary protocol compatibility change? ## Benchmark <!-- When the PR has an impact on performance (if you don't know whether the PR will have an impact on performance, you can submit the PR first, and if it will have impact on performance, the code reviewer will explain it), be sure to attach a benchmark data here. Delete section if not applicable. -->
1 parent 4cda5d8 commit 4dd6e4b

File tree

2 files changed

+39
-5
lines changed

2 files changed

+39
-5
lines changed

python/pyfory/serializer.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1476,12 +1476,14 @@ def write(self, buffer, value):
14761476
buffer.write_string(value.__name__)
14771477

14781478
def read(self, buffer):
1479-
mod = buffer.read_string()
1480-
mod = importlib.import_module(mod)
1481-
result = self.fory.policy.validate_module(mod.__name__)
1479+
mod_name = buffer.read_string()
1480+
result = self.fory.policy.validate_module(mod_name)
14821481
if result is not None:
1483-
mod = result
1484-
return mod
1482+
if isinstance(result, types.ModuleType):
1483+
return result
1484+
assert isinstance(result, str), f"validate_module must return module, str, or None, got {type(result)}"
1485+
mod_name = result
1486+
return importlib.import_module(mod_name)
14851487

14861488

14871489
class MappingProxySerializer(Serializer):

python/pyfory/tests/test_policy.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,3 +259,35 @@ def __reduce__(self):
259259

260260
with pytest.raises(ValueError, match="Inner is blocked"):
261261
fory.deserialize(data)
262+
263+
264+
def test_validate_module():
265+
"""Test validate_module policy hook for module deserialization."""
266+
import json
267+
import collections
268+
269+
# Test 1: Return module object directly
270+
class ReturnModulePolicy(DeserializationPolicy):
271+
def validate_module(self, module_name, **kwargs):
272+
return collections
273+
274+
fory1 = Fory(ref=True, strict=False, policy=ReturnModulePolicy())
275+
data = fory1.serialize(json)
276+
assert fory1.deserialize(data) is collections
277+
278+
# Test 2: Return string to redirect import
279+
class RedirectPolicy(DeserializationPolicy):
280+
def validate_module(self, module_name, **kwargs):
281+
return "collections" if module_name == "json" else None
282+
283+
fory2 = Fory(ref=True, strict=False, policy=RedirectPolicy())
284+
assert fory2.deserialize(fory2.serialize(json)).__name__ == "collections"
285+
286+
# Test 3: Raise to block module
287+
class BlockPolicy(DeserializationPolicy):
288+
def validate_module(self, module_name, **kwargs):
289+
raise ValueError(f"Module {module_name} blocked")
290+
291+
fory3 = Fory(ref=True, strict=False, policy=BlockPolicy())
292+
with pytest.raises(ValueError, match="blocked"):
293+
fory3.deserialize(fory3.serialize(json))

0 commit comments

Comments
 (0)