Skip to content

Commit 2a61cfc

Browse files
Don't block on tensor access in postprocessing (#245)
1 parent aa0fa87 commit 2a61cfc

File tree

5 files changed: 18 additions and 1 deletion (+18 -1 lines changed)

lib/bumblebee/diffusion/stable_diffusion.ex

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -315,6 +315,9 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
315 315     end
316 316
317 317     defp client_postprocessing({outputs, _metadata}, multi?, safety_checker?) do
    318 +     # We use binary backend so we are not blocked by the serving computation
    319 +     outputs = Nx.backend_transfer(outputs, Nx.BinaryBackend)
    320 +
318 321       for outputs <- Bumblebee.Utils.Nx.batch_to_list(outputs) do
319 322         results =
320 323           for outputs = %{image: image} <- Bumblebee.Utils.Nx.batch_to_list(outputs) do
@@ -336,7 +339,7 @@ defmodule Bumblebee.Diffusion.StableDiffusion do
336 339
337 340     defp zeroed(tensor) do
338 341       0
339     -     |> Nx.tensor(type: Nx.type(tensor))
    342 +     |> Nx.tensor(type: Nx.type(tensor), backend: Nx.BinaryBackend)
340 343       |> Nx.broadcast(Nx.shape(tensor))
341 344     end
342 345

lib/bumblebee/text/question_answering.ex

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -102,6 +102,10 @@ defmodule Bumblebee.Text.QuestionAnswering do
102 102         {batch, {all_inputs, raw_inputs, multi?}}
103 103       end)
104 104       |> Nx.Serving.client_postprocessing(fn {outputs, _metadata}, {inputs, raw_inputs, multi?} ->
    105 +       # We use binary backend so we are not blocked by the serving computation
    106 +       inputs = Nx.backend_transfer(inputs, Nx.BinaryBackend)
    107 +       outputs = Nx.backend_transfer(outputs, Nx.BinaryBackend)
    108 +
105 109         Enum.zip_with(
106 110           [raw_inputs, Utils.Nx.batch_to_list(inputs), Utils.Nx.batch_to_list(outputs)],
107 111           fn [{_question_text, context_text}, inputs, outputs] ->

lib/bumblebee/text/text_embedding.ex

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -121,6 +121,9 @@ defmodule Bumblebee.Text.TextEmbedding do
121 121         {batch, multi?}
122 122       end)
123 123       |> Nx.Serving.client_postprocessing(fn {embeddings, _metadata}, multi? ->
    124 +       # We use binary backend so we are not blocked by the serving computation
    125 +       embeddings = Nx.backend_transfer(embeddings, Nx.BinaryBackend)
    126 +
124 127         for embedding <- Bumblebee.Utils.Nx.batch_to_list(embeddings) do
125 128           %{embedding: embedding}
126 129         end

lib/bumblebee/text/token_classification.ex

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,10 @@ defmodule Bumblebee.Text.TokenClassification do
87 87         {batch, {all_inputs, multi?}}
88 88       end)
89 89       |> Nx.Serving.client_postprocessing(fn {scores, _metadata}, {inputs, multi?} ->
   90 +       # We use binary backend so we are not blocked by the serving computation
   91 +       scores = Nx.backend_transfer(scores, Nx.BinaryBackend)
   92 +       inputs = Nx.backend_transfer(inputs, Nx.BinaryBackend)
   93 +
90 94         Enum.zip_with(
91 95           Utils.Nx.batch_to_list(inputs),
92 96           Utils.Nx.batch_to_list(scores),

lib/bumblebee/vision/image_embedding.ex

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -85,6 +85,9 @@ defmodule Bumblebee.Vision.ImageEmbedding do
85 85         {Nx.Batch.concatenate([inputs]), multi?}
86 86       end)
87 87       |> Nx.Serving.client_postprocessing(fn {embeddings, _metadata}, multi? ->
   88 +       # We use binary backend so we are not blocked by the serving computation
   89 +       embeddings = Nx.backend_transfer(embeddings, Nx.BinaryBackend)
   90 +
88 91         for embedding <- Bumblebee.Utils.Nx.batch_to_list(embeddings) do
89 92           %{embedding: embedding}
90 93         end

0 commit comments

Comments
 (0)