11# Copyright 2024 IBM Inc. All rights reserved
22# SPDX-License-Identifier: Apache-2.0
33
4+ from typing import Dict , Optional
5+
46import torch
7+
58from .. import cpp as fstcpp
6- from typing import Dict
7- from ..common import alloc_tensor_memory , free_tensor_memory , SafeTensorsMetadata , ALIGN , CUDA_PTR_ALIGN , paddle_loaded
8- if paddle_loaded :
9- import paddle
9+ from ..common import (
10+ ALIGN ,
11+ CUDA_PTR_ALIGN ,
12+ CUDA_VER ,
13+ SafeTensorsMetadata ,
14+ alloc_tensor_memory ,
15+ free_tensor_memory ,
16+ )
17+ from ..st_types import STDevice , STDeviceType , STDType
18+
1019
1120class GdsFileCopier :
12- def __init__ (self , metadata : SafeTensorsMetadata , device : torch .device , reader : fstcpp .gds_file_reader , debug_log : bool = False ):
21+ def __init__ (
22+ self ,
23+ metadata : SafeTensorsMetadata ,
24+ device : STDevice ,
25+ reader : fstcpp .gds_file_reader ,
26+ debug_log : bool = False ,
27+ ):
1328 self .metadata = metadata
1429 self .device = device
1530 self .reader = reader
1631 self .debug_log = debug_log
1732 self .gbuf = None
18- self .fh = 0
33+ self .fh : Optional [ fstcpp . gds_file_handle ] = None
1934 self .copy_reqs : Dict [int , int ] = {}
2035 self .aligned_length = 0
21- try :
22- if self .metadata .framework == "pytorch" :
23- cuda_vers_list = torch .version .cuda .split ('.' )
24- elif paddle_loaded and self .metadata .framework == "paddle" :
25- cuda_vers_list = paddle .version .cuda ().split ('.' )
26- cudavers = list (map (int , cuda_vers_list ))
27- # CUDA 12.2 (GDS version 1.7) introduces support for non O_DIRECT file descriptors
28- # Compatible with CUDA 11.x
29- self .o_direct = not (cudavers [0 ] > 12 or (cudavers [0 ] == 12 and cudavers [1 ] >= 2 ))
30- except :
31- self .o_direct = True
36+ cudavers = list (map (int , CUDA_VER .split ("." )))
37+ # CUDA 12.2 (GDS version 1.7) introduces support for non O_DIRECT file descriptors
38+ # Compatible with CUDA 11.x
39+ self .o_direct = not (
40+ cudavers [0 ] > 12 or (cudavers [0 ] == 12 and cudavers [1 ] >= 2 )
41+ )
3242
3343 def set_o_direct (self , enable : bool ):
3444 self .o_direct = enable
3545
36- def submit_io (self , use_buf_register : bool , max_copy_block_size : int )-> fstcpp .gds_device_buffer :
37- dev_is_cuda = (self .metadata .framework == "pytorch" and self .device .type == 'cuda' ) or (paddle_loaded and self .metadata .framework == "paddle" and "gpu" in self .device )
46+ def submit_io (
47+ self , use_buf_register : bool , max_copy_block_size : int
48+ ) -> fstcpp .gds_device_buffer :
49+ dev_is_cuda = (
50+ self .device .type == STDeviceType .CUDA
51+ or self .device .type == STDeviceType .GPU
52+ )
3853 self .fh = fstcpp .gds_file_handle (self .metadata .src , self .o_direct , dev_is_cuda )
3954 offset = self .metadata .header_length
4055 length = self .metadata .size_bytes - self .metadata .header_length
@@ -55,7 +70,11 @@ def submit_io(self, use_buf_register: bool, max_copy_block_size: int)->fstcpp.gd
5570 if req_len > max_copy_block_size :
5671 req_len = max_copy_block_size
5772 if gbuf .cufile_register (count , req_len ) < 0 :
58- raise Exception ("submit_io: register_buffer failed, ptr=0x{:x}, count={}, len={}" .format (gbuf .get_base_address (), count , req_len ))
73+ raise Exception (
74+ "submit_io: register_buffer failed, ptr=0x{:x}, count={}, len={}" .format (
75+ gbuf .get_base_address (), count , req_len
76+ )
77+ )
5978 count += req_len
6079
6180 count = 0
@@ -64,40 +83,63 @@ def submit_io(self, use_buf_register: bool, max_copy_block_size: int)->fstcpp.gd
6483 if req_len > max_copy_block_size :
6584 req_len = max_copy_block_size
6685 # TODO: pass timeout so that wait_copy_tensors can recognize too slow pread()
67- req = self .reader .submit_read (self .fh , gbuf , aligned_offset + count , req_len , count , self .metadata .size_bytes )
86+ req = self .reader .submit_read (
87+ self .fh ,
88+ gbuf ,
89+ aligned_offset + count ,
90+ req_len ,
91+ count ,
92+ self .metadata .size_bytes ,
93+ )
6894 self .copy_reqs [req ] = - 1 if not use_buf_register else count
6995 count += req_len
7096 self .aligned_offset = aligned_offset
7197 self .aligned_length = aligned_length
7298 return gbuf
7399
74- def wait_io (self , gbuf : fstcpp .gds_device_buffer , dtype : torch .dtype = None , noalign : bool = False )-> Dict [str , torch .Tensor ]:
100+ def wait_io (
101+ self ,
102+ gbuf : fstcpp .gds_device_buffer ,
103+ dtype : STDType = STDType .AUTO ,
104+ noalign : bool = False ,
105+ ) -> Dict [str , torch .Tensor ]:
75106 failed = []
76- for req , c in sorted (self .copy_reqs .items (), key = lambda x :x [0 ]):
107+ for req , c in sorted (self .copy_reqs .items (), key = lambda x : x [0 ]):
77108 count = self .reader .wait_read (req )
78109 if count < 0 :
79110 failed .append (req )
80111 if c != - 1 :
81112 gbuf .cufile_deregister (c )
82- if self .fh != 0 :
113+ if self .fh is not None :
83114 del self .fh
84- self .fh = 0
115+ self .fh = None
85116 if len (failed ) > 0 :
86- raise Exception (f"wait_io: wait_gds_read failed, failed={ failed } , reqs={ self .copy_reqs } " )
117+ raise Exception (
118+ f"wait_io: wait_gds_read failed, failed={ failed } , reqs={ self .copy_reqs } "
119+ )
87120 self .copy_reqs = {}
88121 if not noalign and not self .metadata .aligned and self .aligned_length > 0 :
89122 misaligned_bytes = self .metadata .header_length % CUDA_PTR_ALIGN
90- length = 1024 * 1024 * 1024
123+ length = 1024 * 1024 * 1024
91124 tmp_gbuf = alloc_tensor_memory (length , self .device , self .metadata .framework )
92125 count = 0
93126 while count + misaligned_bytes < self .aligned_length :
94127 l = self .aligned_length - misaligned_bytes - count
95128 if l > length :
96129 l = length
97130 if self .debug_log :
98- print ("wait_io: fix misalignment, src=0x{:x}, misaligned_bytes={}, count={}, tmp=0x{:x}" .format (gbuf .get_base_address (), misaligned_bytes , count , tmp_gbuf .get_base_address ()))
131+ print (
132+ "wait_io: fix misalignment, src=0x{:x}, misaligned_bytes={}, count={}, tmp=0x{:x}" .format (
133+ gbuf .get_base_address (),
134+ misaligned_bytes ,
135+ count ,
136+ tmp_gbuf .get_base_address (),
137+ )
138+ )
99139 gbuf .memmove (count , misaligned_bytes + count , tmp_gbuf , l )
100140 count += l
101141 free_tensor_memory (tmp_gbuf , self .device , self .metadata .framework )
102142 self .aligned_offset += misaligned_bytes
103- return self .metadata .get_tensors (gbuf , self .device , self .aligned_offset , dtype = dtype )
143+ return self .metadata .get_tensors (
144+ gbuf , self .device , self .aligned_offset , dtype = dtype
145+ )
0 commit comments