File tree Expand file tree Collapse file tree 1 file changed +19
-0
lines changed
Expand file tree Collapse file tree 1 file changed +19
-0
lines changed Original file line number Diff line number Diff line change 1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import functools
1415import importlib
16+ import inspect
17+
18+ from torch .utils .data import _utils as torch_data_utils
19+
20+ # Monkey patch pin_memory to optionally accept a device argument.
21+ # The device argument was removed in some newer torch versions but we
22+ # need it for compatibility with torchdata.
23+ _original_pin_memory = torch_data_utils .pin_memory .pin_memory
24+ _original_pin_memory_sig = inspect .signature (_original_pin_memory )
25+
26+ if "device" not in _original_pin_memory_sig .parameters :
27+
28+ @functools .wraps (_original_pin_memory )
29+ def _patched_pin_memory (data , device = None ):
30+ """Patched pin_memory that accepts an optional device argument."""
31+ return _original_pin_memory (data )
32+
33+ torch_data_utils .pin_memory .pin_memory = _patched_pin_memory
1534
1635from .package_info import __package_name__ , __version__
1736
You can’t perform that action at this time.
0 commit comments