Skip to content

Commit b6331dd

Browse files
committed
Patch torch pin_memory to accept device param if needed
Signed-off-by: Charlie Truong <[email protected]>
1 parent 441ac1a commit b6331dd

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

nemo_automodel/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,26 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import functools
1415
import importlib
16+
import inspect
17+
18+
from torch.utils.data import _utils as torch_data_utils
19+
20+
# Monkey patch pin_memory to optionally accept a device argument.
21+
# The device argument was removed in some newer torch versions but we
22+
# need it for compatibility with torchdata.
23+
_original_pin_memory = torch_data_utils.pin_memory.pin_memory
24+
_original_pin_memory_sig = inspect.signature(_original_pin_memory)
25+
26+
if "device" not in _original_pin_memory_sig.parameters:
27+
28+
@functools.wraps(_original_pin_memory)
29+
def _patched_pin_memory(data, device=None):
30+
"""Patched pin_memory that accepts an optional device argument."""
31+
return _original_pin_memory(data)
32+
33+
torch_data_utils.pin_memory.pin_memory = _patched_pin_memory
1534

1635
from .package_info import __package_name__, __version__
1736

0 commit comments

Comments
 (0)