Skip to content

Commit 4f96c83

Browse files
V0XNIHILIawaelchli
andauthored
Sanitize argument-free object params before logging (#19771)
Co-authored-by: awaelchli <[email protected]>
1 parent a611de0 commit 4f96c83

File tree

3 files changed

+21
-2
lines changed

3 files changed

+21
-2
lines changed

src/lightning/fabric/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99

1010
### Added
1111

12+
- Added sanitization for classes before logging them as hyperparameters ([#19771](https://github.com/Lightning-AI/pytorch-lightning/pull/19771))
13+
1214
- Enabled consolidating distributed checkpoints through `fabric consolidate` in the new CLI ([#19560](https://github.com/Lightning-AI/pytorch-lightning/pull/19560))
1315

1416
- Added the ability to explicitly mark forward methods in Fabric via `_FabricModule.mark_forward_method()` ([#19690](https://github.com/Lightning-AI/pytorch-lightning/pull/19690))

src/lightning/fabric/utilities/logger.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
15+
import inspect
1416
import json
1517
from argparse import Namespace
1618
from dataclasses import asdict, is_dataclass
@@ -52,8 +54,11 @@ def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]:
5254
"""
5355

5456
def _sanitize_callable(val: Any) -> Any:
55-
# Give them one chance to return a value. Don't go rabbit hole of recursive call
57+
if inspect.isclass(val):
58+
# If it's a class, don't try to instantiate it, just return the name
59+
return val.__name__
5660
if callable(val):
61+
# Callables get a chance to return a name
5762
try:
5863
_val = val()
5964
if callable(_val):

tests/tests_fabric/utilities/test_logger.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class B:
9292

9393

9494
def test_sanitize_callable_params():
95-
"""Callback function are not serializiable.
95+
"""Callback functions are not serializable.
9696
9797
Therefore, we get them a chance to return something and if the returned type is not accepted, return None.
9898
@@ -104,11 +104,21 @@ def return_something():
104104
def wrapper_something():
105105
return return_something
106106

107+
class ClassNoArgs:
108+
def __init__(self):
109+
pass
110+
111+
class ClassWithCall:
112+
def __call__(self):
113+
return "name"
114+
107115
params = Namespace(
108116
foo="bar",
109117
something=return_something,
110118
wrapper_something_wo_name=(lambda: lambda: "1"),
111119
wrapper_something=wrapper_something,
120+
class_no_args=ClassNoArgs,
121+
class_with_call=ClassWithCall,
112122
)
113123

114124
params = _convert_params(params)
@@ -118,6 +128,8 @@ def wrapper_something():
118128
assert params["something"] == "something"
119129
assert params["wrapper_something"] == "wrapper_something"
120130
assert params["wrapper_something_wo_name"] == "<lambda>"
131+
assert params["class_no_args"] == "ClassNoArgs"
132+
assert params["class_with_call"] == "ClassWithCall"
121133

122134

123135
def test_sanitize_params():

0 commit comments

Comments
 (0)