diff --git a/python/python/async_tiff/_decoder.pyi b/python/python/async_tiff/_decoder.pyi new file mode 100644 index 0000000..81f0e93 --- /dev/null +++ b/python/python/async_tiff/_decoder.pyi @@ -0,0 +1,12 @@ +from typing import Protocol +from collections.abc import Buffer + +from .enums import CompressionMethod + +class Decoder(Protocol): + @staticmethod + def __call__(buffer: Buffer) -> Buffer: ... + +class DecoderRegistry: + def __init__(self) -> None: ... + def add(self, compression: CompressionMethod | int, decoder: Decoder) -> None: ... diff --git a/python/src/decoder.rs b/python/src/decoder.rs new file mode 100644 index 0000000..63146b1 --- /dev/null +++ b/python/src/decoder.rs @@ -0,0 +1,63 @@ +use async_tiff::decoder::{Decoder, DecoderRegistry}; +use async_tiff::error::AiocogeoError; +use bytes::Bytes; +use pyo3::exceptions::PyTypeError; +use pyo3::intern; +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyTuple}; +use pyo3_bytes::PyBytes; + +use crate::enums::PyCompressionMethod; + +#[pyclass(name = "DecoderRegistry")] +pub(crate) struct PyDecoderRegistry(DecoderRegistry); + +#[pymethods] +impl PyDecoderRegistry { + #[new] + fn new() -> Self { + Self(DecoderRegistry::default()) + } + + fn add(&mut self, compression: PyCompressionMethod, decoder: PyDecoder) { + self.0 + .as_mut() + .insert(compression.into(), Box::new(decoder)); + } +} + +#[derive(Debug)] +struct PyDecoder(PyObject); + +impl PyDecoder { + fn call(&self, py: Python, buffer: Bytes) -> PyResult { + let kwargs = PyDict::new(py); + kwargs.set_item(intern!(py, "buffer"), PyBytes::new(buffer))?; + let result = self.0.call(py, PyTuple::empty(py), Some(&kwargs))?; + result.extract(py) + } +} + +impl<'py> FromPyObject<'py> for PyDecoder { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + if !ob.hasattr(intern!(ob.py(), "__call__"))? { + return Err(PyTypeError::new_err( + "Expected callable object for custom decoder.", + )); + } + Ok(Self(ob.clone().unbind())) + } +} + +impl Decoder for PyDecoder { + fn decode_tile( + &self, + buffer: bytes::Bytes, + _photometric_interpretation: tiff::tags::PhotometricInterpretation, + _jpeg_tables: Option<&[u8]>, + ) -> async_tiff::error::Result { + let decoded_buffer = Python::with_gil(|py| self.call(py, buffer)) + .map_err(|err| AiocogeoError::General(err.to_string()))?; + Ok(decoded_buffer.into_inner()) + } +} diff --git a/python/src/enums.rs b/python/src/enums.rs index d94ef4b..90f8afe 100644 --- a/python/src/enums.rs +++ b/python/src/enums.rs @@ -14,6 +14,18 @@ impl From for PyCompressionMethod { } } +impl From for CompressionMethod { + fn from(value: PyCompressionMethod) -> Self { + value.0 + } +} + +impl<'py> FromPyObject<'py> for PyCompressionMethod { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + Ok(Self(CompressionMethod::from_u16_exhaustive(ob.extract()?))) + } +} + impl<'py> IntoPyObject<'py> for PyCompressionMethod { type Target = PyAny; type Output = Bound<'py, PyAny>; diff --git a/python/src/lib.rs b/python/src/lib.rs index 3e6dc63..e35cec0 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1,5 +1,6 @@ #![deny(clippy::undocumented_unsafe_blocks)] +mod decoder; mod enums; mod geo; mod ifd; @@ -7,6 +8,7 @@ mod tiff; use pyo3::prelude::*; +use crate::decoder::PyDecoderRegistry; use crate::geo::PyGeoKeyDirectory; use crate::ifd::PyImageFileDirectory; use crate::tiff::PyTIFF; @@ -43,6 +45,7 @@ fn _async_tiff(py: Python, m: &Bound) -> PyResult<()> { check_debug_build(py)?; m.add_wrapped(wrap_pyfunction!(___version))?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/cog.rs b/src/cog.rs index e387f8b..8ebd6f0 100644 --- a/src/cog.rs +++ b/src/cog.rs @@ -49,6 +49,7 @@ mod test { use std::io::BufReader; use std::sync::Arc; + use crate::decoder::DecoderRegistry; use crate::ObjectReader; use super::*; @@ -66,7 +67,11 @@ mod test { let cog_reader = COGReader::try_open(Box::new(reader.clone())).await.unwrap(); let ifd = &cog_reader.ifds.as_ref()[1]; - let tile = ifd.get_tile(0, 0, Box::new(reader)).await.unwrap(); + let decoder_registry = DecoderRegistry::default(); + let tile = ifd + .get_tile(0, 0, Box::new(reader), &decoder_registry) + .await + .unwrap(); std::fs::write("img.buf", tile).unwrap(); } diff --git a/src/decoder.rs b/src/decoder.rs index ad007be..a3964a3 100644 --- a/src/decoder.rs +++ b/src/decoder.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; +use std::fmt::Debug; use std::io::{Cursor, Read}; use bytes::Bytes; @@ -7,47 +9,138 @@ use tiff::{TiffError, TiffUnsupportedError}; use crate::error::Result; +/// A registry of decoders. +#[derive(Debug)] +pub struct DecoderRegistry(HashMap>); + +impl DecoderRegistry { + /// Create a new decoder registry with no decoders registered + pub fn new() -> Self { + Self(HashMap::new()) + } +} + +impl AsRef>> for DecoderRegistry { + fn as_ref(&self) -> &HashMap> { + &self.0 + } +} + +impl AsMut>> for DecoderRegistry { + fn as_mut(&mut self) -> &mut HashMap> { + &mut self.0 + } +} + +impl Default for DecoderRegistry { + fn default() -> Self { + let mut registry = HashMap::with_capacity(5); + registry.insert(CompressionMethod::None, Box::new(UncompressedDecoder) as _); + registry.insert(CompressionMethod::Deflate, Box::new(DeflateDecoder) as _); + registry.insert(CompressionMethod::OldDeflate, Box::new(DeflateDecoder) as _); + registry.insert(CompressionMethod::LZW, Box::new(LZWDecoder) as _); + registry.insert(CompressionMethod::ModernJPEG, Box::new(JPEGDecoder) as _); + Self(registry) + } +} + +/// A trait to decode a TIFF tile. +pub trait Decoder: Debug + Send + Sync { + fn decode_tile( + &self, + buffer: Bytes, + photometric_interpretation: PhotometricInterpretation, + jpeg_tables: Option<&[u8]>, + ) -> Result; +} + +#[derive(Debug, Clone)] +pub struct DeflateDecoder; + +impl Decoder for DeflateDecoder { + fn decode_tile( + &self, + buffer: Bytes, + _photometric_interpretation: PhotometricInterpretation, + _jpeg_tables: Option<&[u8]>, + ) -> Result { + let mut decoder = ZlibDecoder::new(Cursor::new(buffer)); + let mut buf = Vec::new(); + decoder.read_to_end(&mut buf)?; + Ok(buf.into()) + } +} + +#[derive(Debug, Clone)] +pub struct JPEGDecoder; + +impl Decoder for JPEGDecoder { + fn decode_tile( + &self, + buffer: Bytes, + photometric_interpretation: PhotometricInterpretation, + jpeg_tables: Option<&[u8]>, + ) -> Result { + decode_modern_jpeg(buffer, photometric_interpretation, jpeg_tables) + } +} + +#[derive(Debug, Clone)] +pub struct LZWDecoder; + +impl Decoder for LZWDecoder { + fn decode_tile( + &self, + buffer: Bytes, + _photometric_interpretation: PhotometricInterpretation, + _jpeg_tables: Option<&[u8]>, + ) -> Result { + // https://github.com/image-rs/image-tiff/blob/90ae5b8e54356a35e266fb24e969aafbcb26e990/src/decoder/stream.rs#L147 + let mut decoder = weezl::decode::Decoder::with_tiff_size_switch(weezl::BitOrder::Msb, 8); + let decoded = decoder.decode(&buffer).expect("failed to decode LZW data"); + Ok(decoded.into()) + } +} + +#[derive(Debug, Clone)] +pub struct UncompressedDecoder; + +impl Decoder for UncompressedDecoder { + fn decode_tile( + &self, + buffer: Bytes, + _photometric_interpretation: PhotometricInterpretation, + _jpeg_tables: Option<&[u8]>, + ) -> Result { + Ok(buffer) + } +} + // https://github.com/image-rs/image-tiff/blob/3bfb43e83e31b0da476832067ada68a82b378b7b/src/decoder/image.rs#L370 pub(crate) fn decode_tile( buf: Bytes, photometric_interpretation: PhotometricInterpretation, compression_method: CompressionMethod, // compressed_length: u64, - jpeg_tables: Option<&Vec>, + jpeg_tables: Option<&[u8]>, + decoder_registry: &DecoderRegistry, ) -> Result { - match compression_method { - CompressionMethod::None => Ok(buf), - CompressionMethod::LZW => decode_lzw(buf), - CompressionMethod::Deflate | CompressionMethod::OldDeflate => decode_deflate(buf), - CompressionMethod::ModernJPEG => { - decode_modern_jpeg(buf, photometric_interpretation, jpeg_tables) - } - method => Err(TiffError::UnsupportedError( - TiffUnsupportedError::UnsupportedCompressionMethod(method), - ) - .into()), - } -} - -fn decode_lzw(buf: Bytes) -> Result { - // https://github.com/image-rs/image-tiff/blob/90ae5b8e54356a35e266fb24e969aafbcb26e990/src/decoder/stream.rs#L147 - let mut decoder = weezl::decode::Decoder::with_tiff_size_switch(weezl::BitOrder::Msb, 8); - let decoded = decoder.decode(&buf).expect("failed to decode LZW data"); - Ok(decoded.into()) -} + let decoder = + decoder_registry + .0 + .get(&compression_method) + .ok_or(TiffError::UnsupportedError( + TiffUnsupportedError::UnsupportedCompressionMethod(compression_method), + ))?; -fn decode_deflate(buf: Bytes) -> Result { - let mut decoder = ZlibDecoder::new(Cursor::new(buf)); - let mut buf = Vec::new(); - decoder.read_to_end(&mut buf)?; - Ok(buf.into()) + decoder.decode_tile(buf, photometric_interpretation, jpeg_tables) } // https://github.com/image-rs/image-tiff/blob/3bfb43e83e31b0da476832067ada68a82b378b7b/src/decoder/image.rs#L389-L450 fn decode_modern_jpeg( buf: Bytes, photometric_interpretation: PhotometricInterpretation, - jpeg_tables: Option<&Vec>, + jpeg_tables: Option<&[u8]>, ) -> Result { // Construct new jpeg_reader wrapping a SmartReader. // @@ -76,13 +169,9 @@ fn decode_modern_jpeg( match photometric_interpretation { PhotometricInterpretation::RGB => decoder.set_color_transform(jpeg::ColorTransform::RGB), - PhotometricInterpretation::WhiteIsZero => { - decoder.set_color_transform(jpeg::ColorTransform::None) - } - PhotometricInterpretation::BlackIsZero => { - decoder.set_color_transform(jpeg::ColorTransform::None) - } - PhotometricInterpretation::TransparencyMask => { + PhotometricInterpretation::WhiteIsZero + | PhotometricInterpretation::BlackIsZero + | PhotometricInterpretation::TransparencyMask => { decoder.set_color_transform(jpeg::ColorTransform::None) } PhotometricInterpretation::CMYK => decoder.set_color_transform(jpeg::ColorTransform::CMYK), diff --git a/src/ifd.rs b/src/ifd.rs index 48fa572..ec2560c 100644 --- a/src/ifd.rs +++ b/src/ifd.rs @@ -13,7 +13,7 @@ use tiff::tags::{ use tiff::TiffError; use crate::async_reader::AsyncCursor; -use crate::decoder::decode_tile; +use crate::decoder::{decode_tile, DecoderRegistry}; use crate::error::{AiocogeoError, Result}; use crate::geo::{AffineTransform, GeoKeyDirectory, GeoKeyTag}; use crate::AsyncFileReader; @@ -681,6 +681,7 @@ impl ImageFileDirectory { x: usize, y: usize, mut reader: Box, + decoder_registry: &DecoderRegistry, ) -> Result { let range = self.get_tile_byte_range(x, y); let buf = reader.get_bytes(range).await?; @@ -688,7 +689,8 @@ impl ImageFileDirectory { buf, self.photometric_interpretation, self.compression, - self.jpeg_tables.as_ref(), + self.jpeg_tables.as_deref(), + decoder_registry, ) } @@ -697,6 +699,7 @@ impl ImageFileDirectory { x: &[usize], y: &[usize], mut reader: Box, + decoder_registry: &DecoderRegistry, ) -> Result> { assert_eq!(x.len(), y.len(), "x and y should have same len"); @@ -717,7 +720,8 @@ impl ImageFileDirectory { buf, self.photometric_interpretation, self.compression, - self.jpeg_tables.as_ref(), + self.jpeg_tables.as_deref(), + decoder_registry, )?; decoded_tiles.push(decoded); } diff --git a/src/lib.rs b/src/lib.rs index 48bb2a4..60776f9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,7 @@ mod async_reader; mod cog; -mod decoder; +pub mod decoder; pub mod error; pub mod geo; mod ifd;