Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 186 additions & 0 deletions library/jpeg_xl_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# Modified from https://github.com/Fraetor/jxl_decode
# Added partial read support for up to 200x speedup

import os
from typing import List, Tuple

class JXLBitstream:
"""
A stream of bits with methods for easy handling.
"""

def __init__(self, file, offset: int = 0, offsets: List[List[int]] = None):
self.shift = 0
self.bitstream = bytearray()
self.file = file
self.offset = offset
self.offsets = offsets
if self.offsets:
self.offset = self.offsets[0][1]
self.previous_data_len = 0
self.index = 0
self.file.seek(self.offset)

def get_bits(self, length: int = 1) -> int:
if self.offsets and self.shift + length > self.previous_data_len + self.offsets[self.index][2]:
self.partial_to_read_length = length
if self.shift < self.previous_data_len + self.offsets[self.index][2]:
self.partial_read(0, length)
self.bitstream.extend(self.file.read(self.partial_to_read_length))
else:
self.bitstream.extend(self.file.read(length))
bitmask = 2**length - 1
bits = (int.from_bytes(self.bitstream, "little") >> self.shift) & bitmask
self.shift += length
return bits

def partial_read(self, current_length: int, length: int) -> None:
self.previous_data_len += self.offsets[self.index][2]
to_read_length = self.previous_data_len - (self.shift + current_length)
self.bitstream.extend(self.file.read(to_read_length))
current_length += to_read_length
self.partial_to_read_length -= to_read_length
self.index += 1
self.file.seek(self.offsets[self.index][1])
if self.shift + length > self.previous_data_len + self.offsets[self.index][2]:
self.partial_read(current_length, length)


def decode_codestream(file, offset: int = 0, offsets: List[List[int]] = None) -> Tuple[int,int]:
"""
Decodes the actual codestream.
JXL codestream specification: http://www-internal/2022/18181-1
"""

# Convert codestream to int within an object to get some handy methods.
codestream = JXLBitstream(file, offset=offset, offsets=offsets)

# Skip signature
codestream.get_bits(16)

# SizeHeader
div8 = codestream.get_bits(1)
if div8:
height = 8 * (1 + codestream.get_bits(5))
else:
distribution = codestream.get_bits(2)
match distribution:
case 0:
height = 1 + codestream.get_bits(9)
case 1:
height = 1 + codestream.get_bits(13)
case 2:
height = 1 + codestream.get_bits(18)
case 3:
height = 1 + codestream.get_bits(30)
ratio = codestream.get_bits(3)
if div8 and not ratio:
width = 8 * (1 + codestream.get_bits(5))
elif not ratio:
distribution = codestream.get_bits(2)
match distribution:
case 0:
width = 1 + codestream.get_bits(9)
case 1:
width = 1 + codestream.get_bits(13)
case 2:
width = 1 + codestream.get_bits(18)
case 3:
width = 1 + codestream.get_bits(30)
else:
match ratio:
case 1:
width = height
case 2:
width = (height * 12) // 10
case 3:
width = (height * 4) // 3
case 4:
width = (height * 3) // 2
case 5:
width = (height * 16) // 9
case 6:
width = (height * 5) // 4
case 7:
width = (height * 2) // 1
return width, height


def decode_container(file) -> Tuple[int,int]:
"""
Parses the ISOBMFF container, extracts the codestream, and decodes it.
JXL container specification: http://www-internal/2022/18181-2
"""

def parse_box(file, file_start: int) -> dict:
file.seek(file_start)
LBox = int.from_bytes(file.read(4), "big")
XLBox = None
if 1 < LBox <= 8:
raise ValueError(f"Invalid LBox at byte {file_start}.")
if LBox == 1:
file.seek(file_start + 8)
XLBox = int.from_bytes(file.read(8), "big")
if XLBox <= 16:
raise ValueError(f"Invalid XLBox at byte {file_start}.")
if XLBox:
header_length = 16
box_length = XLBox
else:
header_length = 8
if LBox == 0:
box_length = os.fstat(file.fileno()).st_size - file_start
else:
box_length = LBox
file.seek(file_start + 4)
box_type = file.read(4)
file.seek(file_start)
return {
"length": box_length,
"type": box_type,
"offset": header_length,
}

file.seek(0)
# Reject files missing required boxes. These two boxes are required to be at
# the start and contain no values, so we can manually check there presence.
# Signature box. (Redundant as has already been checked.)
if file.read(12) != bytes.fromhex("0000000C 4A584C20 0D0A870A"):
raise ValueError("Invalid signature box.")
# File Type box.
if file.read(20) != bytes.fromhex(
"00000014 66747970 6A786C20 00000000 6A786C20"
):
raise ValueError("Invalid file type box.")

offset = 0
offsets = []
data_offset_not_found = True
container_pointer = 32
file_size = os.fstat(file.fileno()).st_size
while data_offset_not_found:
box = parse_box(file, container_pointer)
match box["type"]:
case b"jxlc":
offset = container_pointer + box["offset"]
data_offset_not_found = False
case b"jxlp":
file.seek(container_pointer + box["offset"])
index = int.from_bytes(file.read(4), "big")
offsets.append([index, container_pointer + box["offset"] + 4, box["length"] - box["offset"] - 4])
container_pointer += box["length"]
if container_pointer >= file_size:
data_offset_not_found = False

if offsets:
offsets.sort(key=lambda i: i[0])
file.seek(0)

return decode_codestream(file, offset=offset, offsets=offsets)


def get_jxl_size(path: str) -> Tuple[int,int]:
with open(path, "rb") as file:
if file.read(2) == bytes.fromhex("FF0A"):
return decode_codestream(file)
return decode_container(file)
6 changes: 5 additions & 1 deletion library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,16 @@
# JPEG-XL on Linux
try:
from jxlpy import JXLImagePlugin
from library.jpeg_xl_util import get_jxl_size

IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except:
pass

# JPEG-XL on Windows
# JPEG-XL on Linux and Windows
try:
import pillow_jxl
from library.jpeg_xl_util import get_jxl_size

IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except:
Expand Down Expand Up @@ -1156,6 +1158,8 @@ def cache_text_encoder_outputs(
)

def get_image_size(self, image_path):
if image_path.endswith(".jxl") or image_path.endswith(".JXL"):
return get_jxl_size(image_path)
return imagesize.get(image_path)

def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False):
Expand Down