diff --git a/lib/pythonx.ex b/lib/pythonx.ex index cfd981d..a4c82e8 100644 --- a/lib/pythonx.ex +++ b/lib/pythonx.ex @@ -10,6 +10,8 @@ defmodule Pythonx do alias Pythonx.Object + @type encoder :: (term(), encoder() -> Object.t()) + @doc """ Initializes the Python interpreter. @@ -202,7 +204,7 @@ defmodule Pythonx do raise ArgumentError, "expected globals keys to be strings, got: #{inspect(key)}" end - {key, Pythonx.Encoder.encode(value)} + {key, encode!(value)} end code_md5 = :erlang.md5(code) @@ -299,9 +301,9 @@ defmodule Pythonx do > """ - @spec encode!(term()) :: Object.t() - def encode!(term) do - Pythonx.Encoder.encode(term) + @spec encode!(term(), encoder()) :: Object.t() + def encode!(term, encoder \\ &Pythonx.Encoder.encode/2) do + encoder.(term, encoder) end @doc """ diff --git a/lib/pythonx/encoder.ex b/lib/pythonx/encoder.ex index fae98fc..2ec44bc 100644 --- a/lib/pythonx/encoder.ex +++ b/lib/pythonx/encoder.ex @@ -15,7 +15,7 @@ defprotocol Pythonx.Encoder do The protocol implementation could look like this: defimpl Pythonx.Encoder, for: Complex do - def encode(complex) do + def encode(complex, _encoder) do {result, %{}} = Pythonx.eval( """ @@ -39,7 +39,7 @@ defprotocol Pythonx.Encoder do example, here is one possible implementation for `Explorer.DataFrame`: defimpl Pythonx.Encoder, for: Explorer.DataFrame do - def encode(df) do + def encode(df, _encoder) do {result, %{}} = Pythonx.eval( """ @@ -73,30 +73,30 @@ defprotocol Pythonx.Encoder do @doc """ A function invoked to encode the given term to `Pythonx.Object`. """ - @spec encode(term :: term()) :: Pythonx.Object.t() - def encode(term) + @spec encode(term :: term(), Pythonx.encoder()) :: Pythonx.Object.t() + def encode(term, encoder) end defimpl Pythonx.Encoder, for: Pythonx.Object do - def encode(object) do + def encode(object, _encoder) do object end end defimpl Pythonx.Encoder, for: Atom do - def encode(nil) do + def encode(nil, _encoder) do Pythonx.NIF.none_new() end - def encode(false) do + def encode(false, _encoder) do Pythonx.NIF.false_new() end - def encode(true) do + def encode(true, _encoder) do Pythonx.NIF.true_new() end - def encode(term) do + def encode(term, _encoder) do term |> Atom.to_string() |> Pythonx.NIF.unicode_from_string() @@ -104,7 +104,7 @@ defimpl Pythonx.Encoder, for: Atom do end defimpl Pythonx.Encoder, for: Float do - def encode(term) do + def encode(term, _encoder) do Pythonx.NIF.float_new(term) end end @@ -113,11 +113,11 @@ defimpl Pythonx.Encoder, for: Integer do @max_int64 2 ** 63 - 1 @min_int64 Kernel.-(2 ** 63) - def encode(term) when @min_int64 <= term and term <= @max_int64 do + def encode(term, _encoder) when @min_int64 <= term and term <= @max_int64 do Pythonx.NIF.long_from_int64(term) end - def encode(term) do + def encode(term, _encoder) do # Technically we could do an object call on Python int.from_bytes, # however given that this is a rare path (integers over 64 bits) # and that Python C API has a specific function to create integer @@ -127,11 +127,11 @@ defimpl Pythonx.Encoder, for: Integer do end defimpl Pythonx.Encoder, for: BitString do - def encode(term) when is_binary(term) do + def encode(term, _encoder) when is_binary(term) do Pythonx.NIF.bytes_from_binary(term) end - def encode(term) do + def encode(term, _encoder) do raise Protocol.UndefinedError, protocol: @protocol, value: term, @@ -140,15 +140,11 @@ defimpl Pythonx.Encoder, for: BitString do end defimpl Pythonx.Encoder, for: Map do - def encode(term) do + def encode(term, encoder) do dict = Pythonx.NIF.dict_new() for {key, value} <- term do - Pythonx.NIF.dict_set_item( - dict, - Pythonx.Encoder.encode(key), - Pythonx.Encoder.encode(value) - ) + Pythonx.NIF.dict_set_item(dict, encoder.(key, encoder), encoder.(value, encoder)) end dict @@ -156,13 +152,13 @@ defimpl Pythonx.Encoder, for: Map do end defimpl Pythonx.Encoder, for: Tuple do - def encode(term) do + def encode(term, encoder) do size = tuple_size(term) tuple = Pythonx.NIF.tuple_new(size) for index <- 0..(size - 1)//1 do - value = Pythonx.Encoder.encode(elem(term, index)) + value = encoder.(elem(term, index), encoder) Pythonx.NIF.tuple_set_item(tuple, index, value) end @@ -171,7 +167,7 @@ defimpl Pythonx.Encoder, for: Tuple do end defimpl Pythonx.Encoder, for: List do - def encode(term) do + def encode(term, encoder) do # Note that to compute length we need to traverse the list, but # otherwise we cannot preallocate the Python list and we would # need to use append (which could result in many reallocations). @@ -180,7 +176,7 @@ defimpl Pythonx.Encoder, for: List do list = Pythonx.NIF.list_new(size) Enum.with_index(term, fn item, index -> - value = Pythonx.Encoder.encode(item) + value = encoder.(item, encoder) Pythonx.NIF.list_set_item(list, index, value) end) @@ -189,11 +185,11 @@ defimpl Pythonx.Encoder, for: List do end defimpl Pythonx.Encoder, for: MapSet do - def encode(term) do + def encode(term, encoder) do set = Pythonx.NIF.set_new() for item <- term do - key = Pythonx.Encoder.encode(item) + key = encoder.(item, encoder) Pythonx.NIF.set_add(set, key) end diff --git a/test/pythonx_test.exs b/test/pythonx_test.exs index 61cd69f..86bb14d 100644 --- a/test/pythonx_test.exs +++ b/test/pythonx_test.exs @@ -63,6 +63,20 @@ defmodule PythonxTest do object = Pythonx.encode!(1) assert Pythonx.encode!(object) == object end + + test "custom encoder" do + # Contrived example where we encode tuples as lists. + + encoder = fn + tuple, encoder when is_tuple(tuple) -> + Pythonx.Encoder.encode(Tuple.to_list(tuple), encoder) + + other, encoder -> + Pythonx.Encoder.encode(other, encoder) + end + + assert repr(Pythonx.encode!({1, 2}, encoder)) == "[1, 2]" + end end describe "decode/1" do