Skip to content

Commit e3710a3

Browse files
Fix OOM error in gr.load (#12928)
* Fix OOM error in gr.load * add changeset * Fix flaky test * Serial * Fix client tests * Add code * Back out js/css/ load from hub --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
1 parent ca84f3e commit e3710a3

File tree

6 files changed

+51
-58
lines changed

6 files changed

+51
-58
lines changed

.changeset/busy-heads-bet.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"gradio": patch
3+
---
4+
5+
fix:Fix OOM error in gr.load

client/python/test/test_client.py

Lines changed: 14 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -282,34 +282,26 @@ def test_raises_exception(self, calculator_demo):
282282
def test_job_output_video(self, video_component):
283283
with connect(video_component) as client:
284284
job = client.submit(
285-
{
286-
"video": handle_file(
287-
"https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4"
288-
)
289-
},
290-
fn_index=0,
285+
handle_file(
286+
"https://huggingface.co/datasets/freddyaboulton/bucket/resolve/main/ProgressNotifications.mp4"
287+
),
288+
api_name="/predict",
291289
)
292-
assert Path(job.result()["video"]).exists()
290+
assert Path(job.result()).exists()
293291
assert (
294-
Path(DEFAULT_TEMP_DIR).resolve()
295-
in Path(job.result()["video"]).resolve().parents
292+
Path(DEFAULT_TEMP_DIR).resolve() in Path(job.result()).resolve().parents
296293
)
297294

298295
temp_dir = tempfile.mkdtemp()
299296
with connect(video_component, download_files=temp_dir) as client:
300297
job = client.submit(
301-
{
302-
"video": handle_file(
303-
"https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4"
304-
)
305-
},
298+
handle_file(
299+
"https://huggingface.co/spaces/gradio/video_component/resolve/main/files/a.mp4"
300+
),
306301
fn_index=0,
307302
)
308-
assert Path(job.result()["video"]).exists()
309-
assert (
310-
Path(temp_dir).resolve()
311-
in Path(job.result()["video"]).resolve().parents
312-
)
303+
assert Path(job.result()).exists()
304+
assert Path(temp_dir).resolve() in Path(job.result()).resolve().parents
313305

314306
def test_progress_updates(self, progress_demo):
315307
with connect(progress_demo) as client:
@@ -1047,23 +1039,10 @@ def test_upload(self):
10471039
with patch("builtins.open", MagicMock()):
10481040
with patch.object(pathlib.Path, "name") as mock_name:
10491041
mock_name.side_effect = lambda x: x
1050-
results = client.endpoints[0]._upload(
1051-
["pre1", ["pre2", "pre3", "pre4"], ["pre5", "pre6"], "pre7"]
1042+
results = client.endpoints[0]._upload_file(
1043+
handle_file(__file__), data_index=0
10521044
)
1053-
1054-
res = []
1055-
for re in results:
1056-
if isinstance(re, list):
1057-
res.append([r["name"] for r in re])
1058-
else:
1059-
res.append(re["name"])
1060-
1061-
assert res == [
1062-
"file1",
1063-
["file2", "file3", "file4"],
1064-
["file5", "file6"],
1065-
"file7",
1066-
]
1045+
assert results["path"] == "file1"
10671046

10681047
@pytest.mark.flaky
10691048
def test_download_private_file(self, gradio_temp_dir):

gradio/blocks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1261,7 +1261,7 @@ def iterate_over_children(children_list):
12611261

12621262
derived_fields = ["types"]
12631263

1264-
with Blocks() as blocks:
1264+
with Blocks(theme=config.get("theme", None)) as blocks:
12651265
# ID 0 should be the root Blocks component
12661266
original_mapping[0] = root_block = Context.root_block or blocks
12671267

gradio/external.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -501,11 +501,13 @@ def query_huggingface_inference_endpoints(*data):
501501

502502
kwargs = dict(interface_info, **kwargs)
503503

504-
fn = kwargs.pop("fn", None)
504+
interface_fn = kwargs.pop("fn", None)
505505
inputs = kwargs.pop("inputs", None)
506506
outputs = kwargs.pop("outputs", None)
507507

508-
interface = gr.Interface(fn, inputs, outputs, **kwargs, api_name="predict")
508+
interface = gr.Interface(
509+
interface_fn, inputs, outputs, **kwargs, api_name="predict"
510+
)
509511
return interface
510512

511513

gradio/external_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,8 +326,13 @@ def handle_hf_error(e: Exception):
326326
raise TooManyRequestsError() from e
327327
elif "401" in str(e) or "You must provide an api_key" in str(e):
328328
raise Error("Unauthorized, please make sure you are signed in.") from e
329+
elif isinstance(e, StopIteration):
330+
raise Error(
331+
"This model is not supported by any Hugging Face Inference Provider. "
332+
"Please check the supported models at https://huggingface.co/docs/inference-providers."
333+
) from e
329334
else:
330-
raise Error(str(e)) from e
335+
raise Error(str(e) or f"An error occurred: {type(e).__name__}") from e
331336

332337

333338
def create_endpoint_fn(

test/test_external.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
included in a separate file because of the above-mentioned dependency.
2323
"""
2424

25-
# Mark the whole module as flaky
26-
pytestmark = pytest.mark.flaky
25+
# Mark the whole module as flaky and serial
26+
pytestmark = [pytest.mark.flaky, pytest.mark.serial]
2727

2828
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
2929

@@ -72,7 +72,7 @@ def test_text_generation(self):
7272
def test_summarization(self):
7373
model_type = "summarization"
7474
interface = gr.load(
75-
"models/facebook/bart-large-cnn", hf_token=HF_TOKEN, alias=model_type
75+
"models/facebook/bart-large-cnn", token=HF_TOKEN, alias=model_type
7676
)
7777
assert interface.__name__ == model_type
7878
assert interface.input_components and interface.output_components
@@ -82,7 +82,7 @@ def test_summarization(self):
8282
def test_translation(self):
8383
model_type = "translation"
8484
interface = gr.load(
85-
"models/facebook/bart-large-cnn", hf_token=HF_TOKEN, alias=model_type
85+
"models/facebook/bart-large-cnn", token=HF_TOKEN, alias=model_type
8686
)
8787
assert interface.__name__ == model_type
8888
assert interface.input_components and interface.output_components
@@ -93,7 +93,7 @@ def test_text_classification(self):
9393
model_type = "text-classification"
9494
interface = gr.load(
9595
"models/distilbert-base-uncased-finetuned-sst-2-english",
96-
hf_token=HF_TOKEN,
96+
token=HF_TOKEN,
9797
alias=model_type,
9898
)
9999
assert interface.__name__ == model_type
@@ -104,7 +104,7 @@ def test_text_classification(self):
104104
def test_fill_mask(self):
105105
model_type = "fill-mask"
106106
interface = gr.load(
107-
"models/bert-base-uncased", hf_token=HF_TOKEN, alias=model_type
107+
"models/bert-base-uncased", token=HF_TOKEN, alias=model_type
108108
)
109109
assert interface.__name__ == model_type
110110
assert interface.input_components and interface.output_components
@@ -114,7 +114,7 @@ def test_fill_mask(self):
114114
def test_zero_shot_classification(self):
115115
model_type = "zero-shot-classification"
116116
interface = gr.load(
117-
"models/facebook/bart-large-mnli", hf_token=HF_TOKEN, alias=model_type
117+
"models/facebook/bart-large-mnli", token=HF_TOKEN, alias=model_type
118118
)
119119
assert interface.__name__ == model_type
120120
assert interface.input_components and interface.output_components
@@ -126,7 +126,7 @@ def test_zero_shot_classification(self):
126126
def test_automatic_speech_recognition(self):
127127
model_type = "automatic-speech-recognition"
128128
interface = gr.load(
129-
"models/facebook/wav2vec2-base-960h", hf_token=HF_TOKEN, alias=model_type
129+
"models/facebook/wav2vec2-base-960h", token=HF_TOKEN, alias=model_type
130130
)
131131
assert interface.__name__ == model_type
132132
assert interface.input_components and interface.output_components
@@ -136,7 +136,7 @@ def test_automatic_speech_recognition(self):
136136
def test_image_classification(self):
137137
model_type = "image-classification"
138138
interface = gr.load(
139-
"models/google/vit-base-patch16-224", hf_token=HF_TOKEN, alias=model_type
139+
"models/google/vit-base-patch16-224", token=HF_TOKEN, alias=model_type
140140
)
141141
assert interface.__name__ == model_type
142142
assert interface.input_components and interface.output_components
@@ -147,7 +147,7 @@ def test_feature_extraction(self):
147147
model_type = "feature-extraction"
148148
interface = gr.load(
149149
"models/sentence-transformers/distilbert-base-nli-mean-tokens",
150-
hf_token=HF_TOKEN,
150+
token=HF_TOKEN,
151151
alias=model_type,
152152
)
153153
assert interface.__name__ == model_type
@@ -159,7 +159,7 @@ def test_sentence_similarity(self):
159159
model_type = "text-to-speech"
160160
interface = gr.load(
161161
"models/julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train",
162-
hf_token=HF_TOKEN,
162+
token=HF_TOKEN,
163163
alias=model_type,
164164
)
165165
assert interface.__name__ == model_type
@@ -171,7 +171,7 @@ def test_text_to_speech(self):
171171
model_type = "text-to-speech"
172172
interface = gr.load(
173173
"models/julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train",
174-
hf_token=HF_TOKEN,
174+
token=HF_TOKEN,
175175
alias=model_type,
176176
)
177177
assert interface.__name__ == model_type
@@ -187,7 +187,7 @@ def test_multiple_spaces_one_private(self):
187187
with gr.Blocks():
188188
gr.load(
189189
"spaces/gradio-tests/not-actually-private-spacev4-sse",
190-
hf_token=HF_TOKEN,
190+
token=HF_TOKEN,
191191
)
192192
gr.load(
193193
"spaces/gradio/test-loading-examplesv4-sse",
@@ -197,12 +197,12 @@ def test_multiple_spaces_one_private(self):
197197
def test_private_space_v4_sse_v1(self):
198198
io = gr.load(
199199
"spaces/gradio-tests/not-actually-private-spacev4-sse-v1",
200-
hf_token=HF_TOKEN,
200+
token=HF_TOKEN,
201201
)
202202
try:
203203
output = io("abc")
204204
assert output == "abc"
205-
assert io.theme.name == "gradio/monochrome"
205+
assert io._deprecated_theme == "gradio/monochrome"
206206
except TooManyRequestsError:
207207
pass
208208

@@ -227,7 +227,7 @@ def test_interface_load_cache_examples(self, tmp_path):
227227
name="models/google/vit-base-patch16-224",
228228
examples=[Path(test_file_dir, "cheetah1.jpg")],
229229
cache_examples=True,
230-
hf_token=HF_TOKEN,
230+
token=HF_TOKEN,
231231
)
232232
except TooManyRequestsError:
233233
pass
@@ -302,8 +302,8 @@ def check_dataframe(config):
302302
c for c in config["components"] if c["props"].get("label", "") == "Input Rows"
303303
)
304304
assert input_df["props"]["headers"] == ["a", "b"]
305-
assert input_df["props"]["row_count"] == (1, "dynamic")
306-
assert input_df["props"]["col_count"] == (2, "fixed")
305+
assert input_df["props"]["row_count"] == [3, "dynamic"]
306+
assert input_df["props"]["col_count"] == [2, "dynamic"]
307307

308308

309309
def check_dataset(config, readme_examples):
@@ -352,6 +352,8 @@ def test_use_api_name_in_call_method():
352352

353353

354354
def test_load_custom_component():
355+
pytest.skip("Custom components not supported yet")
356+
355357
from gradio_pdf import PDF # noqa
356358

357359
demo = gr.load("spaces/freddyaboulton/gradiopdf")
@@ -363,7 +365,7 @@ def test_load_custom_component():
363365

364366
def test_load_inside_blocks():
365367
demo = gr.load("spaces/abidlabs/en2fr")
366-
output = demo("Hello")
368+
output = demo("Hello", api_name="predict")
367369
assert isinstance(output, str)
368370

369371

0 commit comments

Comments
 (0)