Skip to content

Commit 2ae4c34

Browse files
Support custom encoder function in encode! (#2)
1 parent 9bd084c commit 2ae4c34

File tree

3 files changed

+42
-30
lines changed

3 files changed

+42
-30
lines changed

lib/pythonx.ex

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ defmodule Pythonx do
1010

1111
alias Pythonx.Object
1212

13+
@type encoder :: (term(), encoder() -> Object.t())
14+
1315
@doc """
1416
Initializes the Python interpreter.
1517
@@ -202,7 +204,7 @@ defmodule Pythonx do
202204
raise ArgumentError, "expected globals keys to be strings, got: #{inspect(key)}"
203205
end
204206

205-
{key, Pythonx.Encoder.encode(value)}
207+
{key, encode!(value)}
206208
end
207209

208210
code_md5 = :erlang.md5(code)
@@ -299,9 +301,9 @@ defmodule Pythonx do
299301
>
300302
301303
"""
302-
@spec encode!(term()) :: Object.t()
303-
def encode!(term) do
304-
Pythonx.Encoder.encode(term)
304+
@spec encode!(term(), encoder()) :: Object.t()
305+
def encode!(term, encoder \\ &Pythonx.Encoder.encode/2) do
306+
encoder.(term, encoder)
305307
end
306308

307309
@doc """

lib/pythonx/encoder.ex

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ defprotocol Pythonx.Encoder do
1515
The protocol implementation could look like this:
1616
1717
defimpl Pythonx.Encoder, for: Complex do
18-
def encode(complex) do
18+
def encode(complex, _encoder) do
1919
{result, %{}} =
2020
Pythonx.eval(
2121
"""
@@ -39,7 +39,7 @@ defprotocol Pythonx.Encoder do
3939
example, here is one possible implementation for `Explorer.DataFrame`:
4040
4141
defimpl Pythonx.Encoder, for: Explorer.DataFrame do
42-
def encode(df) do
42+
def encode(df, _encoder) do
4343
{result, %{}} =
4444
Pythonx.eval(
4545
"""
@@ -73,38 +73,38 @@ defprotocol Pythonx.Encoder do
7373
@doc """
7474
A function invoked to encode the given term to `Pythonx.Object`.
7575
"""
76-
@spec encode(term :: term()) :: Pythonx.Object.t()
77-
def encode(term)
76+
@spec encode(term :: term(), Pythonx.encoder()) :: Pythonx.Object.t()
77+
def encode(term, encoder)
7878
end
7979

8080
defimpl Pythonx.Encoder, for: Pythonx.Object do
81-
def encode(object) do
81+
def encode(object, _encoder) do
8282
object
8383
end
8484
end
8585

8686
defimpl Pythonx.Encoder, for: Atom do
87-
def encode(nil) do
87+
def encode(nil, _encoder) do
8888
Pythonx.NIF.none_new()
8989
end
9090

91-
def encode(false) do
91+
def encode(false, _encoder) do
9292
Pythonx.NIF.false_new()
9393
end
9494

95-
def encode(true) do
95+
def encode(true, _encoder) do
9696
Pythonx.NIF.true_new()
9797
end
9898

99-
def encode(term) do
99+
def encode(term, _encoder) do
100100
term
101101
|> Atom.to_string()
102102
|> Pythonx.NIF.unicode_from_string()
103103
end
104104
end
105105

106106
defimpl Pythonx.Encoder, for: Float do
107-
def encode(term) do
107+
def encode(term, _encoder) do
108108
Pythonx.NIF.float_new(term)
109109
end
110110
end
@@ -113,11 +113,11 @@ defimpl Pythonx.Encoder, for: Integer do
113113
@max_int64 2 ** 63 - 1
114114
@min_int64 Kernel.-(2 ** 63)
115115

116-
def encode(term) when @min_int64 <= term and term <= @max_int64 do
116+
def encode(term, _encoder) when @min_int64 <= term and term <= @max_int64 do
117117
Pythonx.NIF.long_from_int64(term)
118118
end
119119

120-
def encode(term) do
120+
def encode(term, _encoder) do
121121
# Technically we could do an object call on Python int.from_bytes,
122122
# however given that this is a rare path (integers over 64 bits)
123123
# and that Python C API has a specific function to create integer
@@ -127,11 +127,11 @@ defimpl Pythonx.Encoder, for: Integer do
127127
end
128128

129129
defimpl Pythonx.Encoder, for: BitString do
130-
def encode(term) when is_binary(term) do
130+
def encode(term, _encoder) when is_binary(term) do
131131
Pythonx.NIF.bytes_from_binary(term)
132132
end
133133

134-
def encode(term) do
134+
def encode(term, _encoder) do
135135
raise Protocol.UndefinedError,
136136
protocol: @protocol,
137137
value: term,
@@ -140,29 +140,25 @@ defimpl Pythonx.Encoder, for: BitString do
140140
end
141141

142142
defimpl Pythonx.Encoder, for: Map do
143-
def encode(term) do
143+
def encode(term, encoder) do
144144
dict = Pythonx.NIF.dict_new()
145145

146146
for {key, value} <- term do
147-
Pythonx.NIF.dict_set_item(
148-
dict,
149-
Pythonx.Encoder.encode(key),
150-
Pythonx.Encoder.encode(value)
151-
)
147+
Pythonx.NIF.dict_set_item(dict, encoder.(key, encoder), encoder.(value, encoder))
152148
end
153149

154150
dict
155151
end
156152
end
157153

158154
defimpl Pythonx.Encoder, for: Tuple do
159-
def encode(term) do
155+
def encode(term, encoder) do
160156
size = tuple_size(term)
161157

162158
tuple = Pythonx.NIF.tuple_new(size)
163159

164160
for index <- 0..(size - 1)//1 do
165-
value = Pythonx.Encoder.encode(elem(term, index))
161+
value = encoder.(elem(term, index), encoder)
166162
Pythonx.NIF.tuple_set_item(tuple, index, value)
167163
end
168164

@@ -171,7 +167,7 @@ defimpl Pythonx.Encoder, for: Tuple do
171167
end
172168

173169
defimpl Pythonx.Encoder, for: List do
174-
def encode(term) do
170+
def encode(term, encoder) do
175171
# Note that to compute length we need to traverse the list, but
176172
# otherwise we cannot preallocate the Python list and we would
177173
# need to use append (which could result in many reallocations).
@@ -180,7 +176,7 @@ defimpl Pythonx.Encoder, for: List do
180176
list = Pythonx.NIF.list_new(size)
181177

182178
Enum.with_index(term, fn item, index ->
183-
value = Pythonx.Encoder.encode(item)
179+
value = encoder.(item, encoder)
184180
Pythonx.NIF.list_set_item(list, index, value)
185181
end)
186182

@@ -189,11 +185,11 @@ defimpl Pythonx.Encoder, for: List do
189185
end
190186

191187
defimpl Pythonx.Encoder, for: MapSet do
192-
def encode(term) do
188+
def encode(term, encoder) do
193189
set = Pythonx.NIF.set_new()
194190

195191
for item <- term do
196-
key = Pythonx.Encoder.encode(item)
192+
key = encoder.(item, encoder)
197193
Pythonx.NIF.set_add(set, key)
198194
end
199195

test/pythonx_test.exs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,20 @@ defmodule PythonxTest do
6363
object = Pythonx.encode!(1)
6464
assert Pythonx.encode!(object) == object
6565
end
66+
67+
test "custom encoder" do
68+
# Contrived example where we encode tuples as lists.
69+
70+
encoder = fn
71+
tuple, encoder when is_tuple(tuple) ->
72+
Pythonx.Encoder.encode(Tuple.to_list(tuple), encoder)
73+
74+
other, encoder ->
75+
Pythonx.Encoder.encode(other, encoder)
76+
end
77+
78+
assert repr(Pythonx.encode!({1, 2}, encoder)) == "[1, 2]"
79+
end
6680
end
6781

6882
describe "decode/1" do

0 commit comments

Comments
 (0)