1- # Copyright 2024 IBM Inc. All rights reserved
21# SPDX-License-Identifier: Apache-2.0
32
43import warnings
54from typing import Dict , Optional
65
76from .. import cpp as fstcpp
8- from ..common import SafeTensorsMetadata , is_gpu_found
7+ from ..common import SafeTensorsMetadata , init_logger , is_gpu_found
98from ..frameworks import FrameworkOpBase , TensorBase
109from ..st_types import Device , DeviceType , DType
1110from .base import CopierInterface
1211from .nogds import NoGdsFileCopier
1312
13+ logger = init_logger (__name__ )
14+
1415
1516class GdsFileCopier (CopierInterface ):
1617 def __init__ (
@@ -19,13 +20,11 @@ def __init__(
1920 device : Device ,
2021 reader : fstcpp .gds_file_reader ,
2122 framework : FrameworkOpBase ,
22- debug_log : bool = False ,
2323 ):
2424 self .framework = framework
2525 self .metadata = metadata
2626 self .device = device
2727 self .reader = reader
28- self .debug_log = debug_log
2928 self .gbuf = None
3029 self .fh : Optional [fstcpp .gds_file_handle ] = None
3130 self .copy_reqs : Dict [int , int ] = {}
@@ -143,15 +142,13 @@ def wait_io(
143142 l = self .aligned_length - misaligned_bytes - count
144143 if l > length :
145144 l = length
146- if self .debug_log :
147- print (
148- "wait_io: fix misalignment, src=0x{:x}, misaligned_bytes={}, count={}, tmp=0x{:x}" .format (
149- gbuf .get_base_address (),
150- misaligned_bytes ,
151- count ,
152- tmp_gbuf .get_base_address (),
153- )
154- )
145+ logger .debug (
146+ "wait_io: fix misalignment, src=0x%x, misaligned_bytes=%d, count=%d, tmp=0x%x" ,
147+ gbuf .get_base_address (),
148+ misaligned_bytes ,
149+ count ,
150+ tmp_gbuf .get_base_address (),
151+ )
155152 gbuf .memmove (count , misaligned_bytes + count , tmp_gbuf , l )
156153 count += l
157154 self .framework .free_tensor_memory (tmp_gbuf , self .device )
@@ -200,9 +197,8 @@ def construct_nogds_copier(
200197 metadata : SafeTensorsMetadata ,
201198 device : Device ,
202199 framework : FrameworkOpBase ,
203- debug_log : bool = False ,
204200 ) -> CopierInterface :
205- return NoGdsFileCopier (metadata , device , nogds_reader , framework , debug_log )
201+ return NoGdsFileCopier (metadata , device , nogds_reader , framework )
206202
207203 return construct_nogds_copier
208204
@@ -212,8 +208,7 @@ def construct_copier(
212208 metadata : SafeTensorsMetadata ,
213209 device : Device ,
214210 framework : FrameworkOpBase ,
215- debug_log : bool = False ,
216211 ) -> CopierInterface :
217- return GdsFileCopier (metadata , device , reader , framework , debug_log )
212+ return GdsFileCopier (metadata , device , reader , framework )
218213
219214 return construct_copier
0 commit comments