Skip to content

Commit 8c3dbe4

Browse files
authored
Merge pull request #20754 from nilchia/safetensors_dt
[25.0] Add safetensors datatype
2 parents f9032e4 + 9896e9b commit 8c3dbe4

File tree

3 files changed

+100
-0
lines changed

3 files changed

+100
-0
lines changed

lib/galaxy/config/sample/datatypes_conf.xml.sample

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,6 +1173,7 @@
11731173
<datatype extension="bcsp" type="galaxy.datatypes.binary:Binary" mimetype="application/octet-stream" display_in_upload="true" subclass="true" description="Binary format of k-mer hash table which is only compatible with Fairy"/>
11741174
<!-- rdeval types -->
11751175
<datatype extension="rd" type="galaxy.datatypes.binary:Binary" mimetype="application/octet-stream" display_in_upload="true" subclass="true" description="Rdeval read sketch"/>
1176+
<datatype extension="safetensors" type="galaxy.datatypes.binary:Safetensors" mimetype="application/octet-stream" display_in_upload="true" description="A simple format for storing tensors safely (as opposed to pickle) and that is still fast (zero-copy)" description_url="https://huggingface.co/docs/safetensors/index"/>
11761177
</registration>
11771178

11781179
<sniffers>

lib/galaxy/datatypes/binary.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4864,3 +4864,102 @@ def set_meta(self, dataset: DatasetProtocol, overwrite: bool = True, **kwd) -> N
48644864
with open(dataset.get_file_name(), "rb") as handle:
48654865
header_bytes = handle.read(8)
48664866
dataset.metadata.version = struct.unpack("<i", header_bytes[4:8])[0]
4867+
4868+
4869+
@build_sniff_from_prefix
class Safetensors(Binary):
    """
    Safetensors is a simple format for storing tensors safely (as opposed to pickle) that is still fast (zero-copy).
    It provides a way to store and load tensors without the security risks associated with pickle-based formats.

    File layout:
      * 8 bytes: N, an unsigned little-endian 64-bit integer holding the size of the header
      * N bytes: a UTF-8 JSON object (the header); it must start with '{' and may be
        trailing padded with spaces (0x20)
      * the rest of the file: raw tensor data

    more info at: https://github.com/huggingface/safetensors
    """

    file_ext = "safetensors"

    def sniff_prefix(self, file_prefix: FilePrefix) -> bool:
        """
        Determine whether the file is in safetensors format.

        The check validates the 8-byte header length, the opening '{' of the JSON header,
        the closing '}' (ignoring optional trailing space padding), and finally that the
        parsed header is a JSON object with at least one tensor entry carrying
        ``data_offsets``. Any error while reading or parsing yields ``False``.

        >>> from galaxy.datatypes.sniff import get_test_fname
        >>> fname = get_test_fname('cellpose_model_safetensors.safetensors')
        >>> Safetensors().sniff(fname)
        True
        >>> fname = get_test_fname('test_charmm.vel')
        >>> Safetensors().sniff(fname)
        False
        """
        try:
            prefix = file_prefix.contents_header_bytes

            # We need the 8-byte little-endian length field plus at least the
            # first byte of the JSON header.
            if len(prefix) < 9:
                return False

            header_size = int.from_bytes(prefix[:8], "little")

            # Currently, there's a limit on the size of the header of 100MB to prevent parsing extremely large JSON headers
            # In practice, safetensors headers are typically just a few KB to MB
            # (containing tensor names, shapes, dtypes, and offsets - rarely exceeds 1-10MB even for large models)
            # But in theory it is possible to have 100 MB header
            # more info here: https://github.com/huggingface/safetensors?tab=readme-ov-file#benefits
            if header_size == 0 or header_size > 10**8:  # 100MB max for JSON header
                return False

            # Check if file is large enough to contain the full header
            if file_prefix.file_size < 8 + header_size:
                return False

            # The header MUST begin with '{' (0x7B) as per the safetensors format description.
            # Checking this single byte first rejects most other binary formats before any
            # (potentially large) header is read.
            # more info here: https://github.com/huggingface/safetensors?tab=readme-ov-file#format
            if prefix[8] != 0x7B:
                return False

            header_end = 8 + header_size
            if header_end <= len(prefix):
                # Entire header is in the prefix
                header_bytes = prefix[8:header_end]
            else:
                # Header extends beyond the prefix, so it has to be read from the file.
                with open(file_prefix.filename, "rb") as handle:
                    # Cheap pre-check on the last header byte before reading up to 100MB:
                    # it must be the closing brace or a padding space.
                    handle.seek(header_end - 1)
                    if handle.read(1) not in (b"}", b" "):
                        return False
                    handle.seek(8)
                    header_bytes = handle.read(header_size)
                if len(header_bytes) != header_size:
                    return False

            # The header MAY be trailing padded with spaces (0x20), so the closing
            # brace is the last byte only after that padding has been removed.
            if not header_bytes.rstrip(b" ").endswith(b"}"):
                return False

            # Parse the JSON header (json.loads tolerates the trailing padding)
            header = json.loads(header_bytes.decode("utf-8"))
            if not isinstance(header, dict):
                return False

            # Basic validation: besides the special "__metadata__" key, a safetensors
            # header should have at least one tensor entry with data_offsets.
            return any(
                isinstance(value, dict) and "data_offsets" in value
                for key, value in header.items()
                if key != "__metadata__"
            )

        except Exception:
            # Sniffers must never raise: any failure while reading or parsing
            # means this is not a valid safetensors file.
            return False
76 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)