Skip to content

Commit 3bca4c5

Browse files
committed
Handle nil attributes when loading generation config
1 parent e2f5b4c commit 3bca4c5

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

lib/bumblebee/text/generation_config.ex

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -287,10 +287,10 @@ defmodule Bumblebee.Text.GenerationConfig do
287287

288288
opts =
289289
convert!(data,
290-
max_new_tokens: {"max_new_tokens", number()},
291-
min_new_tokens: {"min_new_tokens", number()},
292-
max_length: {"max_length", number()},
293-
min_length: {"min_length", number()},
290+
max_new_tokens: {"max_new_tokens", optional(number())},
291+
min_new_tokens: {"min_new_tokens", optional(number())},
292+
max_length: {"max_length", optional(number())},
293+
min_length: {"min_length", optional(number())},
294294
decoder_start_token_id: {"decoder_start_token_id", optional(number())},
295295
bos_token_id: {"bos_token_id", optional(number())},
296296
eos_token_id: {"eos_token_id", optional(number())},
@@ -306,10 +306,11 @@ defmodule Bumblebee.Text.GenerationConfig do
306306
data
307307
|> convert!(
308308
sample: {"do_sample", boolean()},
309-
top_k: {"top_k", number()},
310-
top_p: {"top_p", number()},
311-
alpha: {"penalty_alpha", number()}
309+
top_k: {"top_k", optional(number())},
310+
top_p: {"top_p", optional(number())},
311+
alpha: {"penalty_alpha", optional(number())}
312312
)
313+
|> Enum.reject(fn {_key, value} -> value == nil end)
313314
|> Map.new()
314315
|> case do
315316
%{sample: true} = opts ->

0 commit comments

Comments
 (0)