@@ -155,6 +155,9 @@ def __init__(self):
155155
156156 # Modify existing attribute on original_module
157157 fabric_module .attribute = 101
158+ # "attribute" is only in the original_module, so it shouldn't get set in the fabric_module
159+ assert "attribute" not in fabric_module .__dict__
160+ assert fabric_module .attribute == 101 # returns it from original_module
158161 assert original_module .attribute == 101
159162
160163 # Check setattr of original_module
@@ -170,6 +173,23 @@ def __init__(self):
170173 assert linear in fabric_module .modules ()
171174 assert linear in original_module .modules ()
172175
176+ # Check monkeypatching of methods
177+ fabric_module = _FabricModule (Mock (), Mock ())
178+ original = id (fabric_module .forward )
179+ fabric_module .forward = lambda * _ : None
180+ assert id (fabric_module .forward ) != original
181+ # Check special methods
182+ assert "__repr__" in dir (fabric_module )
183+ assert "__repr__" not in fabric_module .__dict__
184+ assert "__repr__" not in _FabricModule .__dict__
185+ fabric_module .__repr__ = lambda * _ : "test"
186+ assert fabric_module .__repr__ () == "test"
187+ # needs to be monkeypatched on the class for `repr()` to change
188+ assert repr (fabric_module ) == "_FabricModule()"
189+ with mock .patch .object (_FabricModule , "__repr__" , return_value = "test" ):
190+ assert fabric_module .__repr__ () == "test"
191+ assert repr (fabric_module ) == "test"
192+
173193
174194def test_fabric_module_state_dict_access ():
175195 """Test that state_dict access passes through to the original module."""
0 commit comments