@@ -4864,3 +4864,102 @@ def set_meta(self, dataset: DatasetProtocol, overwrite: bool = True, **kwd) -> N
48644864 with open (dataset .get_file_name (), "rb" ) as handle :
48654865 header_bytes = handle .read (8 )
48664866 dataset .metadata .version = struct .unpack ("<i" , header_bytes [4 :8 ])[0 ]
4867+
4868+
4869+ @build_sniff_from_prefix
4870+ class Safetensors (Binary ):
4871+ """
4872+ safetensors is a new simple format for storing tensors safely (as opposed to pickle) and that is still fast (zero-copy).
4873+ It provides a secure way to store and load tensors without the security risks associated with pickle-based formats.
4874+ Safetensors files consist of a JSON header followed by tensor data.
4875+ more info at: https://github.com/huggingface/safetensors
4876+ """
4877+
4878+ file_ext = "safetensors"
4879+
4880+ def sniff_prefix (self , file_prefix : FilePrefix ) -> bool :
4881+ """
4882+ Determining if the file is in safetensors format
4883+ >>> from galaxy.datatypes.sniff import get_test_fname
4884+ >>> fname = get_test_fname('cellpose_model_safetensors.safetensors')
4885+ >>> Safetensors().sniff(fname)
4886+ True
4887+ >>> fname = get_test_fname('test_charmm.vel')
4888+ >>> Safetensors().sniff(fname)
4889+ False
4890+ """
4891+ try :
4892+ # Safetensors files start with an 8-byte little-endian integer
4893+ # indicating the size of the JSON header
4894+ if len (file_prefix .contents_header_bytes ) < 8 :
4895+ return False
4896+
4897+ header_size = int .from_bytes (file_prefix .contents_header_bytes [:8 ], "little" )
4898+
4899+ # Currently, there's a limit on the size of the header of 100MB to prevent parsing extremely large JSON headers
4900+ # In practice, safetensors headers are typically just a few KB to MB
4901+ # (containing tensor names, shapes, dtypes, and offsets - rarely exceeds 1-10MB even for large models)
4902+ # But in theory it is possible to have 100 MB header
4903+ # more info here: https://github.com/huggingface/safetensors?tab=readme-ov-file#benefits
4904+ if header_size == 0 or header_size > 10 ** 8 : # 100MB max for JSON header
4905+ return False
4906+
4907+ # Check if file is large enough to contain the full header
4908+ if file_prefix .file_size < 8 + header_size :
4909+ return False
4910+
4911+ # CRITICAL: Check if header begins with '{' character (0x7B) as per safetensors spec
4912+ # This is required by the format and helps distinguish from other binary formats
4913+ # Only check 1 byte to avoid issues with malicious header_size values
4914+ # more info here: https://github.com/huggingface/safetensors?tab=readme-ov-file#format
4915+ if file_prefix .contents_header_bytes [8 ] != 0x7B :
4916+ return False
4917+
4918+ # Check if header ends with '}' character (0x7D) as per safetensors spec
4919+ # This requires reading more data if header extends beyond the prefix
4920+ header_end_pos = 8 + header_size - 1
4921+ if header_end_pos < len (file_prefix .contents_header_bytes ):
4922+ # Header end is within the prefix
4923+ if file_prefix .contents_header_bytes [header_end_pos ] != 0x7D :
4924+ return False
4925+ else :
4926+ # Header extends beyond prefix, need to check from file
4927+ with open (file_prefix .filename , "rb" ) as f :
4928+ f .seek (header_end_pos )
4929+ last_header_byte = f .read (1 )
4930+ if len (last_header_byte ) != 1 or last_header_byte [0 ] != 0x7D :
4931+ return False
4932+
4933+ # Read the full header for JSON parsing
4934+ if 8 + header_size <= len (file_prefix .contents_header_bytes ):
4935+ # Entire header is in the prefix
4936+ header_bytes = file_prefix .contents_header_bytes [8 : 8 + header_size ]
4937+ else :
4938+ # Need to read full header from file
4939+ with open (file_prefix .filename , "rb" ) as f :
4940+ f .seek (8 )
4941+ header_bytes = f .read (header_size )
4942+
4943+ if len (header_bytes ) != header_size :
4944+ return False
4945+
4946+ # Parse the validated JSON header
4947+ header = json .loads (header_bytes .decode ("utf-8" ))
4948+ # check if header is a dict
4949+ if not isinstance (header , dict ):
4950+ return False
4951+ # Basic validation: check if it looks like safetensors metadata
4952+ # Safetensors headers should have entries with data_offsets
4953+ has_valid_entries = False
4954+ for key , value in header .items ():
4955+ if key == "__metadata__" : # Special metadata key
4956+ continue
4957+ if isinstance (value , dict ) and "data_offsets" in value :
4958+ has_valid_entries = True
4959+ break
4960+
4961+ return has_valid_entries
4962+
4963+ except Exception :
4964+ # Any exception during parsing means it's not a valid safetensors file
4965+ return False
0 commit comments