@@ -50,7 +50,7 @@ def to(self, *args: Any, **kwargs: Any) -> Self:
5050 """See :meth:`torch.nn.Module.to`."""
5151 # this converts `str` device to `torch.device`
5252 device , dtype = torch ._C ._nn ._parse_to (* args , ** kwargs )[:2 ]
53- self . __update_properties ( device = device , dtype = dtype )
53+ _update_properties ( self , device = device , dtype = dtype )
5454 return super ().to (* args , ** kwargs )
5555
5656 def cuda (self , device : Optional [Union [torch .device , int ]] = None ) -> Self :
@@ -70,43 +70,46 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:
7070 device = torch .device ("cuda" , torch .cuda .current_device ())
7171 elif isinstance (device , int ):
7272 device = torch .device ("cuda" , index = device )
73- self . __update_properties ( device = device )
73+ _update_properties ( self , device = device )
7474 return super ().cuda (device = device )
7575
7676 def cpu (self ) -> Self :
7777 """See :meth:`torch.nn.Module.cpu`."""
78- self . __update_properties ( device = torch .device ("cpu" ))
78+ _update_properties ( self , device = torch .device ("cpu" ))
7979 return super ().cpu ()
8080
8181 def type (self , dst_type : Union [str , torch .dtype ]) -> Self :
8282 """See :meth:`torch.nn.Module.type`."""
83- self . __update_properties ( dtype = dst_type )
83+ _update_properties ( self , dtype = dst_type )
8484 return super ().type (dst_type = dst_type )
8585
8686 def float (self ) -> Self :
8787 """See :meth:`torch.nn.Module.float`."""
88- self . __update_properties ( dtype = torch .float )
88+ _update_properties ( self , dtype = torch .float )
8989 return super ().float ()
9090
9191 def double (self ) -> Self :
9292 """See :meth:`torch.nn.Module.double`."""
93- self . __update_properties ( dtype = torch .double )
93+ _update_properties ( self , dtype = torch .double )
9494 return super ().double ()
9595
9696 def half (self ) -> Self :
9797 """See :meth:`torch.nn.Module.half`."""
98- self . __update_properties ( dtype = torch .half )
98+ _update_properties ( self , dtype = torch .half )
9999 return super ().half ()
100100
101- def __update_properties (
102- self , device : Optional [torch .device ] = None , dtype : Optional [Union [str , torch .dtype ]] = None
103- ) -> None :
104- def apply_fn (module : Union [_DeviceDtypeModuleMixin , Module ]) -> None :
105- if not isinstance (module , _DeviceDtypeModuleMixin ):
106- return
107- if device is not None :
108- module ._device = device
109- if dtype is not None :
110- module ._dtype = dtype
111-
112- self .apply (apply_fn )
101+
102+ def _update_properties (
103+ root : torch .nn .Module , device : Optional [torch .device ] = None , dtype : Optional [Union [str , torch .dtype ]] = None
104+ ) -> None :
105+ def apply_fn (module : Union [_DeviceDtypeModuleMixin , Module ]) -> None :
106+ if not isinstance (module , _DeviceDtypeModuleMixin ):
107+ return
108+ # cannot use `module.to()` because we don't actually want to move the model in case there are multiple
109+ # device types (such as partial meta parameters)
110+ if device is not None :
111+ module ._device = device
112+ if dtype is not None :
113+ module ._dtype = dtype
114+
115+ root .apply (apply_fn )
0 commit comments