Skip to content

Commit 9a7ceb0

Browse files
authored
Default to jit_compile=True for SD (#2054)
1 parent a83393b commit 9a7ceb0

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

keras_cv/models/stable_diffusion/stable_diffusion.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __init__(
5252
self,
5353
img_height=512,
5454
img_width=512,
55-
jit_compile=False,
55+
jit_compile=True,
5656
):
5757
# UNet requires multiples of 2**7 = 128
5858
img_height = round(img_height / 128) * 128
@@ -396,7 +396,7 @@ def __init__(
396396
self,
397397
img_height=512,
398398
img_width=512,
399-
jit_compile=False,
399+
jit_compile=True,
400400
):
401401
super().__init__(img_height, img_width, jit_compile)
402402
print(
@@ -482,7 +482,7 @@ def __init__(
482482
self,
483483
img_height=512,
484484
img_width=512,
485-
jit_compile=False,
485+
jit_compile=True,
486486
):
487487
super().__init__(img_height, img_width, jit_compile)
488488
print(

keras_cv/models/stable_diffusion/stable_diffusion_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def test_generate_image_rejects_noise_and_seed(self):
100100

101101
@pytest.mark.extra_large
102102
class StableDiffusionMultiFrameworkTest(TestCase):
103+
@pytest.mark.filterwarnings("ignore::UserWarning") # Torch + jit_compile
103104
def test_end_to_end(self):
104105
prompt = "a caterpillar smoking a hookah while sitting on a mushroom"
105106
stablediff = StableDiffusion(128, 128)

0 commit comments

Comments
 (0)