Skip to content

Commit a9d4197

Browse files
committed
fix
1 parent de5fad1 commit a9d4197

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tests/pipelines/sana/test_sana_sprint_img2img.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def get_dummy_components(self):
120120
num_attention_heads=2,
121121
num_hidden_layers=1,
122122
num_key_value_heads=2,
123-
vocab_size=1000,
123+
vocab_size=8,
124124
attn_implementation="eager",
125125
)
126126
text_encoder = Gemma2Model(text_encoder_config)
@@ -169,6 +169,9 @@ def test_inference(self):
169169
generated_image = image[0]
170170

171171
self.assertEqual(generated_image.shape, (3, 32, 32))
172+
expected_image = torch.randn(3, 32, 32)
173+
max_diff = np.abs(generated_image - expected_image).max()
174+
self.assertLessEqual(max_diff, 1e10)
172175

173176
def test_callback_inputs(self):
174177
sig = inspect.signature(self.pipeline_class.__call__)

0 commit comments

Comments
 (0)