@@ -31,7 +31,8 @@ class ZstdFile(_streams.BaseStream):
3131
3232    def  __init__ (
3333        self ,
34-         filename ,
34+         file ,
35+         / ,
3536        mode = "r" ,
3637        * ,
3738        level = None ,
@@ -40,7 +41,7 @@ def __init__(
4041    ):
4142        """Open a zstd compressed file in binary mode. 
4243
43-         filename  can be either an actual file name (given as a str, bytes, or 
44+         file  can be either an actual file name (given as a str, bytes, or 
4445        PathLike object), in which case the named file is opened, or it can be 
4546        an existing file object to read from or write to. 
4647
@@ -58,29 +59,23 @@ def __init__(
5859        See the function train_dict for how to train a ZstdDict on sample data. 
5960        """ 
6061        self ._fp  =  None 
61-         self ._closefp  =  False 
62+         self ._close_fp  =  False 
6263        self ._mode  =  _MODE_CLOSED 
6364
65+         if  not  isinstance (mode , str ):
66+             raise  ValueError ("mode must be a str" )
6467        # Read or write mode 
65-         if  mode  in  ("r" , "rb" ):
66-             if  not  isinstance (options , (type (None ), dict )):
67-                 raise  TypeError (
68-                     (
69-                         "In read mode (decompression), options argument " 
70-                         "should be a dict object, that represents " 
71-                         "decompression options." 
72-                     )
73-                 )
68+         if  options  is  not None  and  not  isinstance (options , dict ):
69+             raise  TypeError ("options must be a dict or None" )
70+         mode  =  mode .removesuffix ("b" )  # handle rb, wb, xb, ab 
71+         if  mode  ==  "r" :
7472            if  level  is  not None :
75-                 raise  TypeError ("level argument should only be passed when " 
76-                                 "writing." )
77-             mode_code  =  _MODE_READ 
78-         elif  mode  in  ("w" , "wb" , "a" , "ab" , "x" , "xb" ):
79-             if  not  isinstance (level , (type (None ), int )):
80-                 raise  TypeError ("level argument should be an int object." )
81-             if  not  isinstance (options , (type (None ), dict )):
82-                 raise  TypeError ("options argument should be an dict object." )
83-             mode_code  =  _MODE_WRITE 
73+                 raise  TypeError ("level is illegal in read mode" )
74+             self ._mode  =  _MODE_READ 
75+         elif  mode  in  {"w" , "a" , "x" }:
76+             if  level  is  not None  and  not  isinstance (level , int ):
77+                 raise  TypeError ("level must be int or None" )
78+             self ._mode  =  _MODE_WRITE 
8479            self ._compressor  =  ZstdCompressor (
8580                level = level , options = options , zstd_dict = zstd_dict 
8681            )
@@ -89,17 +84,15 @@ def __init__(
8984            raise  ValueError (f"Invalid mode: { mode !r}  )
9085
9186        # File object 
92-         if  isinstance (filename , (str , bytes , PathLike )):
93-             if  "b"  not  in mode :
94-                 mode  +=  "b" 
95-             self ._fp  =  io .open (filename , mode )
96-             self ._closefp  =  True 
97-         elif  hasattr (filename , "read" ) or  hasattr (filename , "write" ):
98-             self ._fp  =  filename 
87+         if  isinstance (file , (str , bytes , PathLike )):
88+             self ._fp  =  io .open (file , f'{ mode }  )
89+             self ._close_fp  =  True 
90+         elif  ((mode  ==  'r'  and  hasattr (file , "read" ))
91+                 or  (mode  !=  'r'  and  hasattr (file , "write" ))):
92+             self ._fp  =  file 
9993        else :
100-             raise  TypeError ("filename must be a str, bytes, file or PathLike " 
101-                             "object" )
102-         self ._mode  =  mode_code 
94+             raise  TypeError ("file must be a file-like object " 
95+                             "or a str, bytes, or PathLike object" )
10396
10497        if  self ._mode  ==  _MODE_READ :
10598            raw  =  _streams .DecompressReader (
@@ -114,15 +107,14 @@ def __init__(
114107    def  close (self ):
115108        """Flush and close the file. 
116109
117-         May be called more than once without error . Once the file is  
118-         closed,  any other operation on it will raise a  ValueError. 
110+         May be called multiple times . Once the file has been closed,  
111+         any other operation on it will raise ValueError. 
119112        """ 
120-         # Nop if already closed 
121113        if  self ._fp  is  None :
122114            return 
123115        try :
124116            if  self ._mode  ==  _MODE_READ :
125-                 if  hasattr (self , " _buffer" )  and   self . _buffer :
117+                 if  getattr (self , ' _buffer' ,  None ) :
126118                    self ._buffer .close ()
127119                    self ._buffer  =  None 
128120            elif  self ._mode  ==  _MODE_WRITE :
@@ -131,11 +123,11 @@ def close(self):
131123        finally :
132124            self ._mode  =  _MODE_CLOSED 
133125            try :
134-                 if  self ._closefp :
126+                 if  self ._close_fp :
135127                    self ._fp .close ()
136128            finally :
137129                self ._fp  =  None 
138-                 self ._closefp  =  False 
130+                 self ._close_fp  =  False 
139131
140132    def  write (self , data ):
141133        """Write a bytes-like object *data* to the file. 
@@ -161,9 +153,8 @@ def write(self, data):
161153    def  flush (self , mode = FLUSH_BLOCK ):
162154        """Flush remaining data to the underlying stream. 
163155
164-         The mode argument can be ZstdFile.FLUSH_BLOCK or ZstdFile.FLUSH_FRAME. 
165-         Abuse of this method will reduce compression ratio, use it only when 
166-         necessary. 
156+         The mode argument can be FLUSH_BLOCK or FLUSH_FRAME. Abuse of this 
157+         method will reduce compression ratio, use it only when necessary. 
167158
168159        If the program is interrupted afterwards, all data can be recovered. 
169160        To ensure saving to disk, also need to use os.fsync(fd). 
@@ -173,10 +164,10 @@ def flush(self, mode=FLUSH_BLOCK):
173164        if  self ._mode  ==  _MODE_READ :
174165            return 
175166        self ._check_not_closed ()
176-         if  mode  not  in ( self .FLUSH_BLOCK , self .FLUSH_FRAME ) :
177-             raise  ValueError ("mode argument wrong value, it should be  " 
178-                              "ZstdCompressor .FLUSH_FRAME or " 
179-                              "ZstdCompressor .FLUSH_BLOCK. " )
167+         if  mode  not  in { self .FLUSH_BLOCK , self .FLUSH_FRAME } :
168+             raise  ValueError ("Invalid  mode argument, expected either  " 
169+                              "ZstdFile .FLUSH_FRAME or " 
170+                              "ZstdFile .FLUSH_BLOCK" )
180171        if  self ._compressor .last_mode  ==  mode :
181172            return 
182173        # Flush zstd block/frame, and write. 
@@ -270,8 +261,7 @@ def peek(self, size=-1):
270261        return  self ._buffer .peek (size )
271262
272263    def  __next__ (self ):
273-         ret  =  self ._buffer .readline ()
274-         if  ret :
264+         if  ret  :=  self ._buffer .readline ():
275265            return  ret 
276266        raise  StopIteration 
277267
@@ -319,7 +309,8 @@ def writable(self):
319309
320310# Copied from lzma module 
321311def  open (
322-     filename ,
312+     file ,
313+     / ,
323314    mode = "rb" ,
324315    * ,
325316    level = None ,
@@ -331,9 +322,9 @@ def open(
331322):
332323    """Open a zstd compressed file in binary or text mode. 
333324
334-     filename  can be either an actual  file name (given as a str, bytes, or 
335-     PathLike object),  in which case the named file is opened, or it can be an 
336-     existing file object  to read from or write to. 
325+     file  can be either a  file name (given as a str, bytes, or PathLike object),  
326+     in which case the named file is opened, or it can be an existing file object  
327+     to read from or write to. 
337328
338329    The mode parameter can be "r", "rb" (default), "w", "wb", "x", "xb", "a", 
339330    "ab" for binary mode, or "rt", "wt", "xt", "at" for text mode. 
@@ -370,7 +361,7 @@ def open(
370361
371362    zstd_mode  =  mode .replace ("t" , "" )
372363    binary_file  =  ZstdFile (
373-         filename , zstd_mode , level = level , options = options , zstd_dict = zstd_dict 
364+         file , zstd_mode , level = level , options = options , zstd_dict = zstd_dict 
374365    )
375366
376367    if  "t"  in  mode :
0 commit comments