@@ -1100,7 +1100,7 @@ def calib_func(prepared_model):
11001100 assert os .path .exists (pathname ), f"Checkpoint file does not exist: { pathname } "
11011101 if os .path .isfile (pathname ):
11021102 low_precision_checkpoint = None
1103- if pathname .endswith (".pt" ) or pathname . endswith ( ".pth" ):
1103+ if pathname .endswith (( ".pt" , ".pth" , ".bin" ) ):
11041104 low_precision_checkpoint = torch .load (pathname , weights_only = True )
11051105 elif pathname .endswith (".safetensors" ):
11061106 try :
@@ -1113,13 +1113,13 @@ def calib_func(prepared_model):
11131113 low_precision_checkpoint = safetensors .torch .load_file (pathname )
11141114 assert (
11151115 low_precision_checkpoint is not None
1116- ), f"Invalid checkpoint file: { pathname } . Should be a .pt, .pth or .safetensors file."
1116+ ), f"Invalid checkpoint file: { pathname } . Should be a .pt, .pth, .bin or .safetensors file."
11171117
11181118 quant_method = {"quant_method" : "gptq" }
11191119
11201120 elif os .path .isdir (pathname ):
11211121 low_precision_checkpoint = {}
1122- for pattern in ["*.pt" , "*.pth" ]:
1122+ for pattern in ["*.pt" , "*.pth" , "*.bin" ]:
11231123 files = list (pathlib .Path (pathname ).glob (pattern ))
11241124 if files :
11251125 for f in files :
@@ -1141,7 +1141,7 @@ def calib_func(prepared_model):
11411141 low_precision_checkpoint .update (data_f )
11421142 assert (
11431143 len (low_precision_checkpoint ) > 0
1144- ), f"Cannot find checkpoint (.pt/.pth/.safetensors) files in path { pathname } ."
1144+ ), f"Cannot find checkpoint (.pt/.pth/.bin/. safetensors) files in path { pathname } ."
11451145
11461146 try :
11471147 with open (pathname + "/config.json" ) as f :
0 commit comments