Skip to content

Commit afc5510

Browse files
nipung90facebook-github-bot
authored andcommitted
Create the function decorator to enable logging in torchrec (pytorch#3145)
Summary: Pull Request resolved: pytorch#3145 This diff creates the logging decorator that will use the torchrec logger to record the function's input, output/error and other job identifying parameters. This is also where we will perform the JK check to see if torchrec logging is enabled. Reviewed By: saumishr, kausv Differential Revision: D76294270 fbshipit-source-id: 58c174787b76ea57324fa469cf2fa465d88c9673
1 parent 53ee6b6 commit afc5510

File tree

4 files changed

+425
-0
lines changed

4 files changed

+425
-0
lines changed

torchrec/distributed/logger.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# mypy: allow-untyped-defs
9+
import functools
10+
import inspect
11+
from typing import Any, Callable, TypeVar
12+
13+
import torchrec.distributed.torchrec_logger as torchrec_logger
14+
from torchrec.distributed.torchrec_logging_handlers import TORCHREC_LOGGER_NAME
15+
from typing_extensions import ParamSpec
16+
17+
18+
__all__: list[str] = []
19+
20+
global _torchrec_logger
21+
_torchrec_logger = torchrec_logger._get_or_create_logger(TORCHREC_LOGGER_NAME)
22+
23+
_T = TypeVar("_T")
24+
_P = ParamSpec("_P")
25+
26+
27+
def _torchrec_method_logger(
28+
**wrapper_kwargs: Any,
29+
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: # pyre-ignore
30+
"""This method decorator logs the input, output, and exception of wrapped events."""
31+
32+
def decorator(func: Callable[_P, _T]): # pyre-ignore
33+
@functools.wraps(func)
34+
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
35+
msg_dict = torchrec_logger._get_msg_dict(func.__name__, **kwargs)
36+
try:
37+
## Add function input to log message
38+
msg_dict["input"] = _get_input_from_func(func, *args, **kwargs)
39+
# exceptions
40+
result = func(*args, **kwargs)
41+
except BaseException as error:
42+
msg_dict["error"] = f"{error}"
43+
_torchrec_logger.error(msg_dict)
44+
raise
45+
msg_dict["output"] = str(result)
46+
_torchrec_logger.debug(msg_dict)
47+
return result
48+
49+
return wrapper
50+
51+
return decorator
52+
53+
54+
def _get_input_from_func(
55+
func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
56+
) -> str:
57+
signature = inspect.signature(func)
58+
bound_args = signature.bind_partial(*args, **kwargs)
59+
bound_args.apply_defaults()
60+
input_vars = {param.name: param.default for param in signature.parameters.values()}
61+
for key, value in bound_args.arguments.items():
62+
if isinstance(value, (int, float)):
63+
input_vars[key] = value
64+
else:
65+
input_vars[key] = str(value)
66+
return str(input_vars)
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import unittest
11+
from typing import Any
12+
from unittest import mock
13+
14+
from torchrec.distributed.logger import _get_input_from_func, _torchrec_method_logger
15+
16+
17+
class TestLogger(unittest.TestCase):
18+
def setUp(self) -> None:
19+
super().setUp()
20+
21+
# Mock torchrec_logger._get_msg_dict
22+
self.get_msg_dict_patcher = mock.patch(
23+
"torchrec.distributed.torchrec_logger._get_msg_dict"
24+
)
25+
self.mock_get_msg_dict = self.get_msg_dict_patcher.start()
26+
self.mock_get_msg_dict.return_value = {}
27+
28+
# Mock _torchrec_logger
29+
self.logger_patcher = mock.patch("torchrec.distributed.logger._torchrec_logger")
30+
self.mock_logger = self.logger_patcher.start()
31+
32+
def tearDown(self) -> None:
33+
self.get_msg_dict_patcher.stop()
34+
self.logger_patcher.stop()
35+
super().tearDown()
36+
37+
def test_get_input_from_func_no_args(self) -> None:
38+
"""Test _get_input_from_func with a function that has no arguments."""
39+
40+
def test_func() -> None:
41+
pass
42+
43+
result = _get_input_from_func(test_func)
44+
self.assertEqual(result, "{}")
45+
46+
def test_get_input_from_func_with_args(self) -> None:
47+
"""Test _get_input_from_func with a function that has positional arguments."""
48+
49+
def test_func(_a: int, _b: str) -> None:
50+
pass
51+
52+
result = _get_input_from_func(test_func, 42, "hello")
53+
self.assertEqual(result, "{'_a': 42, '_b': 'hello'}")
54+
55+
def test_get_input_from_func_with_kwargs(self) -> None:
56+
"""Test _get_input_from_func with a function that has keyword arguments."""
57+
58+
def test_func(_a: int = 0, _b: str = "default") -> None:
59+
pass
60+
61+
result = _get_input_from_func(test_func, _b="world")
62+
self.assertEqual(result, "{'_a': 0, '_b': 'world'}")
63+
64+
def test_get_input_from_func_with_args_and_kwargs(self) -> None:
65+
"""Test _get_input_from_func with a function that has both positional and keyword arguments."""
66+
67+
def test_func(
68+
_a: int, _b: str = "default", *_args: Any, **_kwargs: Any
69+
) -> None:
70+
pass
71+
72+
result = _get_input_from_func(test_func, 42, "hello", "extra", key="value")
73+
self.assertEqual(
74+
result,
75+
"{'_a': 42, '_b': 'hello', '_args': \"('extra',)\", '_kwargs': \"{'key': 'value'}\"}",
76+
)
77+
78+
def test_torchrec_method_logger_success(self) -> None:
79+
"""Test _torchrec_method_logger with a successful function execution when logging is enabled."""
80+
# Create a mock function that returns a value
81+
mock_func = mock.MagicMock(return_value="result")
82+
mock_func.__name__ = "mock_func"
83+
84+
# Apply the decorator
85+
decorated_func = _torchrec_method_logger()(mock_func)
86+
87+
# Call the decorated function
88+
result = decorated_func(42, key="value")
89+
90+
# Verify the result
91+
self.assertEqual(result, "result")
92+
93+
# Verify that _get_msg_dict was called with the correct arguments
94+
self.mock_get_msg_dict.assert_called_once_with("mock_func", key="value")
95+
96+
# Verify that the logger was called with the correct message
97+
self.mock_logger.debug.assert_called_once()
98+
msg_dict = self.mock_logger.debug.call_args[0][0]
99+
self.assertEqual(msg_dict["output"], "result")
100+
101+
def test_torchrec_method_logger_exception(self) -> None:
102+
"""Test _torchrec_method_logger with a function that raises an exception when logging is enabled."""
103+
# Create a mock function that raises an exception
104+
mock_func = mock.MagicMock(side_effect=ValueError("test error"))
105+
mock_func.__name__ = "mock_func"
106+
107+
# Apply the decorator
108+
decorated_func = _torchrec_method_logger()(mock_func)
109+
110+
# Call the decorated function and expect an exception
111+
with self.assertRaises(ValueError):
112+
decorated_func(42, key="value")
113+
114+
# Verify that _get_msg_dict was called with the correct arguments
115+
self.mock_get_msg_dict.assert_called_once_with("mock_func", key="value")
116+
117+
# Verify that the logger was called with the correct message
118+
self.mock_logger.error.assert_called_once()
119+
msg_dict = self.mock_logger.error.call_args[0][0]
120+
self.assertEqual(msg_dict["error"], "test error")
121+
122+
def test_torchrec_method_logger_with_wrapper_kwargs(self) -> None:
123+
"""Test _torchrec_method_logger with wrapper kwargs."""
124+
# Create a mock function that returns a value
125+
mock_func = mock.MagicMock(return_value="result")
126+
mock_func.__name__ = "mock_func"
127+
128+
# Apply the decorator with wrapper kwargs
129+
decorated_func = _torchrec_method_logger(custom_kwarg="value")(mock_func)
130+
131+
# Call the decorated function
132+
result = decorated_func(42, key="value")
133+
134+
# Verify the result
135+
self.assertEqual(result, "result")
136+
137+
# Verify that _get_msg_dict was called with the correct arguments
138+
self.mock_get_msg_dict.assert_called_once_with("mock_func", key="value")
139+
140+
# Verify that the logger was called with the correct message
141+
self.mock_logger.debug.assert_called_once()
142+
msg_dict = self.mock_logger.debug.call_args[0][0]
143+
self.assertEqual(msg_dict["output"], "result")
144+
145+
146+
if __name__ == "__main__":
147+
unittest.main()
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import logging
11+
import unittest
12+
from unittest import mock
13+
14+
import torch.distributed as dist
15+
16+
from torchrec.distributed.logging_handlers import _log_handlers
17+
from torchrec.distributed.torchrec_logger import (
18+
_DEFAULT_DESTINATION,
19+
_get_logging_handler,
20+
_get_msg_dict,
21+
_get_or_create_logger,
22+
)
23+
24+
25+
class TestTorchrecLogger(unittest.TestCase):
26+
def setUp(self) -> None:
27+
super().setUp()
28+
# Save the original _log_handlers to restore it after tests
29+
self.original_log_handlers = _log_handlers.copy()
30+
31+
# Create a mock logging handler
32+
self.mock_handler = mock.MagicMock(spec=logging.Handler)
33+
_log_handlers[_DEFAULT_DESTINATION] = self.mock_handler
34+
35+
# Mock print function
36+
self.print_patcher = mock.patch("builtins.print")
37+
self.mock_print = self.print_patcher.start()
38+
39+
def tearDown(self) -> None:
40+
# Restore the original _log_handlers
41+
_log_handlers.clear()
42+
_log_handlers.update(self.original_log_handlers)
43+
44+
# Stop print patcher
45+
self.print_patcher.stop()
46+
47+
super().tearDown()
48+
49+
def test_get_logging_handler(self) -> None:
50+
"""Test _get_logging_handler function."""
51+
# Test with default destination
52+
handler, name = _get_logging_handler()
53+
54+
self.assertEqual(handler, self.mock_handler)
55+
self.assertEqual(
56+
name, f"{type(self.mock_handler).__name__}-{_DEFAULT_DESTINATION}"
57+
)
58+
59+
# Test with custom destination
60+
custom_dest = "custom_dest"
61+
custom_handler = mock.MagicMock(spec=logging.Handler)
62+
_log_handlers[custom_dest] = custom_handler
63+
64+
handler, name = _get_logging_handler(custom_dest)
65+
66+
self.assertEqual(handler, custom_handler)
67+
self.assertEqual(name, f"{type(custom_handler).__name__}-{custom_dest}")
68+
69+
@mock.patch("logging.getLogger")
70+
def test_get_or_create_logger(self, mock_get_logger: mock.MagicMock) -> None:
71+
"""Test _get_or_create_logger function."""
72+
mock_logger = mock.MagicMock(spec=logging.Logger)
73+
mock_get_logger.return_value = mock_logger
74+
75+
# Test with default destination
76+
_get_or_create_logger()
77+
78+
# Verify logger was created with the correct name
79+
handler_name = f"{type(self.mock_handler).__name__}-{_DEFAULT_DESTINATION}"
80+
mock_get_logger.assert_called_once_with(f"torchrec-{handler_name}")
81+
82+
# Verify logger was configured correctly
83+
mock_logger.setLevel.assert_called_once_with(logging.DEBUG)
84+
mock_logger.addHandler.assert_called_once_with(self.mock_handler)
85+
self.assertFalse(mock_logger.propagate)
86+
87+
# Verify formatter was set on the handler
88+
self.mock_handler.setFormatter.assert_called_once()
89+
formatter = self.mock_handler.setFormatter.call_args[0][0]
90+
self.assertIsInstance(formatter, logging.Formatter)
91+
92+
# Test with custom destination
93+
mock_get_logger.reset_mock()
94+
self.mock_handler.reset_mock()
95+
96+
custom_dest = "custom_dest"
97+
custom_handler = mock.MagicMock(spec=logging.Handler)
98+
_log_handlers[custom_dest] = custom_handler
99+
100+
_get_or_create_logger(custom_dest)
101+
102+
# Verify logger was created with the correct name
103+
handler_name = f"{type(custom_handler).__name__}-{custom_dest}"
104+
mock_get_logger.assert_called_once_with(f"torchrec-{handler_name}")
105+
106+
# Verify custom handler was used
107+
mock_logger.addHandler.assert_called_once_with(custom_handler)
108+
109+
def test_get_msg_dict_without_dist(self) -> None:
110+
"""Test _get_msg_dict function without dist initialized."""
111+
# Mock dist.is_initialized to return False
112+
with mock.patch("torch.distributed.is_initialized", return_value=False):
113+
msg_dict = _get_msg_dict("test_func", kwarg1="val1")
114+
115+
# Verify msg_dict contains only func_name
116+
self.assertEqual(len(msg_dict), 1)
117+
self.assertEqual(msg_dict["func_name"], "test_func")
118+
119+
def test_get_msg_dict_with_dist(self) -> None:
120+
"""Test _get_msg_dict function with dist initialized."""
121+
# Mock dist functions
122+
with mock.patch.multiple(
123+
dist,
124+
is_initialized=mock.MagicMock(return_value=True),
125+
get_world_size=mock.MagicMock(return_value=4),
126+
get_rank=mock.MagicMock(return_value=2),
127+
):
128+
# Test with group in kwargs
129+
mock_group = mock.MagicMock()
130+
msg_dict = _get_msg_dict("test_func", group=mock_group)
131+
132+
# Verify msg_dict contains all expected keys
133+
self.assertEqual(len(msg_dict), 4)
134+
self.assertEqual(msg_dict["func_name"], "test_func")
135+
self.assertEqual(msg_dict["group"], str(mock_group))
136+
self.assertEqual(msg_dict["world_size"], "4")
137+
self.assertEqual(msg_dict["rank"], "2")
138+
139+
# Verify get_world_size and get_rank were called with the group
140+
dist.get_world_size.assert_called_once_with(mock_group)
141+
dist.get_rank.assert_called_once_with(mock_group)
142+
143+
# Test with process_group in kwargs
144+
dist.get_world_size.reset_mock()
145+
dist.get_rank.reset_mock()
146+
147+
mock_process_group = mock.MagicMock()
148+
msg_dict = _get_msg_dict("test_func", process_group=mock_process_group)
149+
150+
# Verify msg_dict contains all expected keys
151+
self.assertEqual(msg_dict["group"], str(mock_process_group))
152+
153+
# Verify get_world_size and get_rank were called with the process_group
154+
dist.get_world_size.assert_called_once_with(mock_process_group)
155+
dist.get_rank.assert_called_once_with(mock_process_group)
156+
157+
158+
if __name__ == "__main__":
159+
unittest.main()

0 commit comments

Comments
 (0)