Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions lib/pythonx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ defmodule Pythonx do

alias Pythonx.Object

@type encoder :: (term(), encoder() -> Object.t())

@doc """
Initializes the Python interpreter.

Expand Down Expand Up @@ -202,7 +204,7 @@ defmodule Pythonx do
raise ArgumentError, "expected globals keys to be strings, got: #{inspect(key)}"
end

{key, Pythonx.Encoder.encode(value)}
{key, encode!(value)}
end

code_md5 = :erlang.md5(code)
Expand Down Expand Up @@ -299,9 +301,9 @@ defmodule Pythonx do
>

"""
@spec encode!(term()) :: Object.t()
def encode!(term) do
Pythonx.Encoder.encode(term)
@spec encode!(term(), encoder()) :: Object.t()
def encode!(term, encoder \\ &Pythonx.Encoder.encode/2) do
encoder.(term, encoder)
end

@doc """
Expand Down
48 changes: 22 additions & 26 deletions lib/pythonx/encoder.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ defprotocol Pythonx.Encoder do
The protocol implementation could look like this:

defimpl Pythonx.Encoder, for: Complex do
def encode(complex) do
def encode(complex, _encoder) do
{result, %{}} =
Pythonx.eval(
"""
Expand All @@ -39,7 +39,7 @@ defprotocol Pythonx.Encoder do
example, here is one possible implementation for `Explorer.DataFrame`:

defimpl Pythonx.Encoder, for: Explorer.DataFrame do
def encode(df) do
def encode(df, _encoder) do
{result, %{}} =
Pythonx.eval(
"""
Expand Down Expand Up @@ -73,38 +73,38 @@ defprotocol Pythonx.Encoder do
@doc """
A function invoked to encode the given term to `Pythonx.Object`.
"""
@spec encode(term :: term()) :: Pythonx.Object.t()
def encode(term)
@spec encode(term :: term(), Pythonx.encoder()) :: Pythonx.Object.t()
def encode(term, encoder)
end

defimpl Pythonx.Encoder, for: Pythonx.Object do
def encode(object) do
def encode(object, _encoder) do
object
end
end

defimpl Pythonx.Encoder, for: Atom do
def encode(nil) do
def encode(nil, _encoder) do
Pythonx.NIF.none_new()
end

def encode(false) do
def encode(false, _encoder) do
Pythonx.NIF.false_new()
end

def encode(true) do
def encode(true, _encoder) do
Pythonx.NIF.true_new()
end

def encode(term) do
def encode(term, _encoder) do
term
|> Atom.to_string()
|> Pythonx.NIF.unicode_from_string()
end
end

defimpl Pythonx.Encoder, for: Float do
def encode(term) do
def encode(term, _encoder) do
Pythonx.NIF.float_new(term)
end
end
Expand All @@ -113,11 +113,11 @@ defimpl Pythonx.Encoder, for: Integer do
@max_int64 2 ** 63 - 1
@min_int64 Kernel.-(2 ** 63)

def encode(term) when @min_int64 <= term and term <= @max_int64 do
def encode(term, _encoder) when @min_int64 <= term and term <= @max_int64 do
Pythonx.NIF.long_from_int64(term)
end

def encode(term) do
def encode(term, _encoder) do
# Technically we could do an object call on Python int.from_bytes,
# however given that this is a rare path (integers over 64 bits)
# and that Python C API has a specific function to create integer
Expand All @@ -127,11 +127,11 @@ defimpl Pythonx.Encoder, for: Integer do
end

defimpl Pythonx.Encoder, for: BitString do
def encode(term) when is_binary(term) do
def encode(term, _encoder) when is_binary(term) do
Pythonx.NIF.bytes_from_binary(term)
end

def encode(term) do
def encode(term, _encoder) do
raise Protocol.UndefinedError,
protocol: @protocol,
value: term,
Expand All @@ -140,29 +140,25 @@ defimpl Pythonx.Encoder, for: BitString do
end

defimpl Pythonx.Encoder, for: Map do
def encode(term) do
def encode(term, encoder) do
dict = Pythonx.NIF.dict_new()

for {key, value} <- term do
Pythonx.NIF.dict_set_item(
dict,
Pythonx.Encoder.encode(key),
Pythonx.Encoder.encode(value)
)
Pythonx.NIF.dict_set_item(dict, encoder.(key, encoder), encoder.(value, encoder))
end

dict
end
end

defimpl Pythonx.Encoder, for: Tuple do
def encode(term) do
def encode(term, encoder) do
size = tuple_size(term)

tuple = Pythonx.NIF.tuple_new(size)

for index <- 0..(size - 1)//1 do
value = Pythonx.Encoder.encode(elem(term, index))
value = encoder.(elem(term, index), encoder)
Pythonx.NIF.tuple_set_item(tuple, index, value)
end

Expand All @@ -171,7 +167,7 @@ defimpl Pythonx.Encoder, for: Tuple do
end

defimpl Pythonx.Encoder, for: List do
def encode(term) do
def encode(term, encoder) do
# Note that to compute length we need to traverse the list, but
# otherwise we cannot preallocate the Python list and we would
# need to use append (which could result in many reallocations).
Expand All @@ -180,7 +176,7 @@ defimpl Pythonx.Encoder, for: List do
list = Pythonx.NIF.list_new(size)

Enum.with_index(term, fn item, index ->
value = Pythonx.Encoder.encode(item)
value = encoder.(item, encoder)
Pythonx.NIF.list_set_item(list, index, value)
end)

Expand All @@ -189,11 +185,11 @@ defimpl Pythonx.Encoder, for: List do
end

defimpl Pythonx.Encoder, for: MapSet do
def encode(term) do
def encode(term, encoder) do
set = Pythonx.NIF.set_new()

for item <- term do
key = Pythonx.Encoder.encode(item)
key = encoder.(item, encoder)
Pythonx.NIF.set_add(set, key)
end

Expand Down
14 changes: 14 additions & 0 deletions test/pythonx_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,20 @@ defmodule PythonxTest do
object = Pythonx.encode!(1)
assert Pythonx.encode!(object) == object
end

test "custom encoder" do
# Contrived example where we encode tuples as lists.

encoder = fn
tuple, encoder when is_tuple(tuple) ->
Pythonx.Encoder.encode(Tuple.to_list(tuple), encoder)

other, encoder ->
Pythonx.Encoder.encode(other, encoder)
end

assert repr(Pythonx.encode!({1, 2}, encoder)) == "[1, 2]"
end
end

describe "decode/1" do
Expand Down